1- import { AsyncStateService , StateEntry } from "@proto-kit/sequencer" ;
1+ import { AsyncStateService , MaskName , StateEntry } from "@proto-kit/sequencer" ;
22import { Field } from "o1js" ;
33import { Prisma } from "@prisma/client" ;
44import { noop } from "@proto-kit/common" ;
@@ -20,53 +20,85 @@ const Decimal = Prisma.Decimal.clone({
2020export class PrismaStateService implements AsyncStateService {
2121 private cache : StateEntry [ ] = [ ] ;
2222
23- private maskId ?: number ;
23+ private maskId ?: { id : number ; parentId ?: number } ;
2424
2525 /**
2626 * @param connection
2727 * @param mask A indicator to which masking level the values belong.
2828 * This name has to be unique
29- * @param parent
29+ * @param parentName
3030 */
3131 public constructor (
3232 private readonly connection : PrismaConnection ,
3333 private readonly mask : string ,
34- private readonly parent ?: number
34+ private readonly parentName ?: string
3535 ) { }
3636
37- private async getMaskId ( ) : Promise < number > {
37+ private async getMaskId ( ) : Promise < { id : number ; parentId ?: number } > {
3838 if ( this . maskId === undefined ) {
39- this . maskId = await this . initializeMask ( this . mask , this . parent ) ;
39+ this . maskId = await this . initializeMask ( this . mask , this . parentName ) ;
4040 }
4141 return this . maskId ;
4242 }
4343
44- private async initializeMask ( mask : string , parent ?: number ) : Promise < number > {
44+ private async initializeMask (
45+ mask : string ,
46+ parentName ?: string
47+ ) : Promise < { id : number ; parentId ?: number } > {
4548 const { prismaClient } = this . connection ;
4649
4750 const found = await prismaClient . mask . findFirst ( {
4851 where : {
4952 name : mask ,
50- parent,
53+ } ,
54+ include : {
55+ parentMask : true ,
5156 } ,
5257 } ) ;
5358
5459 if ( found === null ) {
60+ // Find parent id
61+ let parentId : number | undefined = undefined ;
62+ if ( parentName !== undefined ) {
63+ const parent = await prismaClient . mask . findFirst ( {
64+ where : {
65+ name : parentName ,
66+ } ,
67+ } ) ;
68+ if ( parent === null ) {
69+ throw new Error ( `Parent mask with name ${ parentName } not found` ) ;
70+ }
71+
72+ parentId = parent . id ;
73+ } else if ( mask !== MaskName . base ( ) ) {
74+ throw new Error (
75+ "Can't initialize mask that's not the base using a null-parent "
76+ ) ;
77+ }
78+
79+ // Create mask
5580 const createdMask = await prismaClient . mask . create ( {
5681 data : {
57- parent,
82+ parent : parentId ,
5883 name : mask ,
5984 } ,
6085 } ) ;
61- return createdMask . id ;
86+ return {
87+ id : createdMask . id ,
88+ parentId : parentId ,
89+ } ;
6290 }
63- return found . id ;
91+
92+ return {
93+ id : found . id ,
94+ parentId : found . parent ?? undefined ,
95+ } ;
6496 }
6597
6698 public async commit ( ) : Promise < void > {
6799 const { prismaClient } = this . connection ;
68100
69- const maskId = await this . getMaskId ( ) ;
101+ const { id : maskId } = await this . getMaskId ( ) ;
70102
71103 const data = this . cache
72104 . filter ( ( entry ) => entry . value !== undefined )
@@ -92,7 +124,7 @@ export class PrismaStateService implements AsyncStateService {
92124 }
93125
94126 public async getMany ( keys : Field [ ] ) : Promise < StateEntry [ ] > {
95- const maskId = await this . getMaskId ( ) ;
127+ const { id : maskId } = await this . getMaskId ( ) ;
96128 const paths = keys . map ( ( key ) => new Decimal ( key . toString ( ) ) ) ;
97129
98130 const records : {
@@ -123,16 +155,18 @@ export class PrismaStateService implements AsyncStateService {
123155 }
124156
125157 public async createMask ( name : string ) : Promise < AsyncStateService > {
126- const maskId = await this . getMaskId ( ) ;
127- return new PrismaStateService ( this . connection , name , maskId ) ;
158+ // We only call this to make sure this mask actually exists, therefore that the
159+ // relation can be satisfied
160+ await this . getMaskId ( ) ;
161+ return new PrismaStateService ( this . connection , name , this . mask ) ;
128162 }
129163
130164 public async mergeIntoParent ( ) : Promise < void > {
131- const maskId = await this . getMaskId ( ) ;
165+ const { id : maskId , parentId } = await this . getMaskId ( ) ;
132166
133167 const client = this . connection . prismaClient ;
134168
135- if ( this . parent !== undefined ) {
169+ if ( parentId !== undefined ) {
136170 // Rough strategy here:
137171 // 1. Delete all entries that are bound to be overwritten from the parent mask
138172 // 2. Update this mask's entries to parent mask id
@@ -150,7 +184,7 @@ export class PrismaStateService implements AsyncStateService {
150184 parent : maskId ,
151185 } ,
152186 data : {
153- parent : this . parent ,
187+ parent : parentId ,
154188 } ,
155189 } ) ,
156190 client . mask . delete ( {
@@ -165,7 +199,7 @@ export class PrismaStateService implements AsyncStateService {
165199 }
166200
167201 public async drop ( ) : Promise < void > {
168- const maskId = await this . getMaskId ( ) ;
202+ const { id : maskId } = await this . getMaskId ( ) ;
169203
170204 await this . connection . prismaClient . state . deleteMany ( {
171205 where : {
0 commit comments