@@ -190,7 +190,7 @@ pub const Strobe128 = struct {
190190pub const Transcript = struct {
191191 strobe : Strobe128 ,
192192
193- const DomainSeperator = enum {
193+ pub const DomainSeperator = enum {
194194 @"zero-ciphertext-instruction" ,
195195 @"zero-ciphertext-proof" ,
196196 @"pubkey-validity-instruction" ,
@@ -230,21 +230,26 @@ pub const Transcript = struct {
230230 ciphertext : zksdk.elgamal.Ciphertext ,
231231 commitment : zksdk.pedersen.Commitment ,
232232 u64 : u64 ,
233+ domsep : DomainSeperator ,
233234
234235 grouped_2 : zksdk .elgamal .GroupedElGamalCiphertext (2 ),
235236 grouped_3 : zksdk .elgamal .GroupedElGamalCiphertext (3 ),
236237 };
237238
238- pub fn init (comptime seperator : DomainSeperator , inputs : []const TranscriptInput ) Transcript {
239+ /// [agave] https://github.com/solana-program/zk-elgamal-proof/blob/zk-sdk%40v5.0.0/zk-sdk/src/lib.rs#L36
240+ const TRANSCRIPT_DOMAIN = "solana-zk-elgamal-proof-program-v1" ;
241+
242+ pub fn init (comptime seperator : DomainSeperator ) Transcript {
239243 var transcript : Transcript = .{ .strobe = Strobe128 .init ("Merlin v1.0" ) };
240- transcript .appendDomSep ( seperator );
241- for ( inputs ) | input | transcript .appendMessage ( input . label , input . message );
244+ transcript .appendBytes ( "dom-sep" , TRANSCRIPT_DOMAIN );
245+ transcript .appendBytes ( "dom-sep" , @tagName ( seperator ) );
242246 return transcript ;
243247 }
244248
245249 pub fn initTest (label : []const u8 ) Transcript {
246250 comptime if (! builtin .is_test ) @compileError ("should only be used during tests" );
247251 var transcript : Transcript = .{ .strobe = Strobe128 .init ("Merlin v1.0" ) };
252+ transcript .appendBytes ("dom-sep" , TRANSCRIPT_DOMAIN );
248253 transcript .appendBytes ("dom-sep" , label );
249254 return transcript ;
250255 }
@@ -264,6 +269,7 @@ pub const Transcript = struct {
264269 .point = > | * point | & point .toBytes (),
265270 .pubkey = > | * pubkey | & pubkey .toBytes (),
266271 .scalar = > | * scalar | & scalar .toBytes (),
272+ .domsep = > | t | @tagName (t ),
267273 .ciphertext = > | * ct | b : {
268274 @memcpy (buffer [0.. 32], & ct .commitment .point .toBytes ());
269275 @memcpy (buffer [32.. 64], & ct .handle .point .toBytes ());
@@ -284,26 +290,33 @@ pub const Transcript = struct {
284290 comptime session : * Session ,
285291 comptime t : Input.Type ,
286292 comptime label : []const u8 ,
287- data : t . Data ( ),
288- ) if (t == .validate_point ) error {IdentityElement }! void else void {
289- // if validate_point fails to validate, we no longer want to check the contract
293+ data : @FieldType ( Message , @tagName ( t.base ()) ),
294+ ) if (t . validates () ) error {IdentityElement }! void else void {
295+ // If validate_point fails to validate, we no longer want to check the contract
290296 // because the function calling append will now return early.
291297 errdefer session .cancel ();
292298
293299 if (t == .bytes and ! builtin .is_test )
294300 @compileError ("message type `bytes` only allowed in tests" );
295301
296- // assert correctness
302+ // Get the next expected input, and inside we verify that it matches
303+ // the type we're about to append to the transcript.
297304 const input = comptime session .nextInput (t , label );
298- if (t == .validate_point ) try data .rejectIdentity ();
305+ // If the input requires validation, we perform it here.
306+ if (comptime t .validates ()) try data .rejectIdentity ();
307+ // Ensure that the domain seperators are added with the correct label.
308+ // They should always be added through the `appendDomSep` helper function.
309+ switch (t ) {
310+ .domsep = > comptime {
311+ std .debug .assert (input .seperator .? == data );
312+ std .debug .assert (std .mem .eql (u8 , label , "dom-sep" ));
313+ },
314+ else = > {},
315+ }
299316
300- // add the message
301317 self .appendMessage (input .label , @unionInit (
302318 Message ,
303- @tagName (switch (t ) {
304- .validate_point = > .point ,
305- else = > t ,
306- }),
319+ @tagName (t .base ()),
307320 data ,
308321 ));
309322 }
@@ -314,12 +327,16 @@ pub const Transcript = struct {
314327 pub inline fn appendNoValidate (
315328 self : * Transcript ,
316329 comptime session : * Session ,
330+ comptime t : Input.Type ,
317331 comptime label : []const u8 ,
318- point : Ristretto255 ,
332+ data : @FieldType ( Message , @tagName ( t.base ())) ,
319333 ) void {
320- const input = comptime session .nextInput (.validate_point , label );
321- point .rejectIdentity () catch {}; // ignore the error
322- self .appendMessage (input .label , .{ .point = point });
334+ const input = comptime session .nextInput (
335+ @field (Input .Type , "validate_" ++ @tagName (t )),
336+ label ,
337+ );
338+ data .rejectIdentity () catch {}; // ignore the error
339+ self .appendMessage (input .label , @unionInit (Message , @tagName (t ), data ));
323340 }
324341
325342 fn challengeBytes (
@@ -329,7 +346,6 @@ pub const Transcript = struct {
329346 ) void {
330347 var data_len : [4 ]u8 = undefined ;
331348 std .mem .writeInt (u32 , & data_len , @intCast (destination .len ), .little );
332-
333349 self .strobe .metaAd (label , false );
334350 self .strobe .metaAd (& data_len , true );
335351 self .strobe .prf (destination , false );
@@ -351,64 +367,89 @@ pub const Transcript = struct {
351367
352368 // domain seperation helpers
353369
354- pub fn appendDomSep (self : * Transcript , comptime seperator : DomainSeperator ) void {
355- self .appendBytes ("dom-sep" , @tagName (seperator ));
356- }
357-
358- pub fn appendHandleDomSep (
370+ pub inline fn appendDomSep (
359371 self : * Transcript ,
360- comptime mode : enum { batched , unbatched } ,
361- comptime handles : enum { two , three } ,
372+ comptime session : * Session ,
373+ comptime seperator : DomainSeperator ,
362374 ) void {
363- self .appendDomSep (switch (mode ) {
364- .batched = > .@"batched-validity-proof" ,
365- .unbatched = > .@"validity-proof" ,
366- });
367- self .appendMessage ("handles" , .{ .u64 = switch (handles ) {
368- .two = > 2 ,
369- .three = > 3 ,
370- } });
375+ self .append (session , .domsep , "dom-sep" , seperator );
371376 }
372377
373- pub fn appendRangeProof (
378+ pub inline fn appendRangeProof (
374379 self : * Transcript ,
380+ comptime session : * Session ,
375381 comptime mode : enum { range , inner },
376382 n : comptime_int ,
377383 ) void {
378- self .appendDomSep (switch (mode ) {
384+ self .appendDomSep (session , switch (mode ) {
379385 .range = > .@"range-proof" ,
380386 .inner = > .@"inner-product" ,
381387 });
382- self .appendMessage ( "n" , .{ . u64 = n } );
388+ self .append ( session , .u64 , "n" , n );
383389 }
384390
385391 // sessions
386392
387393 pub const Input = struct {
388394 label : []const u8 ,
389395 type : Type ,
396+ seperator : ? DomainSeperator = null ,
390397
391398 const Type = enum {
392399 bytes ,
393400 scalar ,
394- challenge ,
401+ u64 ,
402+
395403 point ,
396- validate_point ,
397404 pubkey ,
405+ ciphertext ,
406+ commitment ,
407+ grouped_2 ,
408+ grouped_3 ,
398409
399- pub fn Data (comptime t : Type ) type {
410+ validate_point ,
411+ validate_pubkey ,
412+ validate_ciphertext ,
413+ validate_commitment ,
414+ validate_grouped_2 ,
415+ validate_grouped_3 ,
416+
417+ domsep ,
418+ challenge ,
419+
420+ /// Returns whether this input type performs identity validation.
421+ fn validates (t : Type ) bool {
400422 return switch (t ) {
401- .bytes = > []const u8 ,
402- .scalar = > Scalar ,
403- .validate_point , .point = > Ristretto255 ,
404- .pubkey = > zksdk .elgamal .Pubkey ,
405- .challenge = > unreachable , // call `challenge*`
423+ .validate_point ,
424+ .validate_pubkey ,
425+ .validate_ciphertext ,
426+ .validate_commitment ,
427+ .validate_grouped_2 ,
428+ .validate_grouped_3 ,
429+ = > true ,
430+ else = > false ,
406431 };
407432 }
433+
434+ /// For a given input type, returns the base type.
435+ /// E.g. `validate_point` -> `point`
436+ /// E.g. `point` -> `point`
437+ fn base (t : Type ) Type {
438+ if (t .validates ()) {
439+ return @field (Type , @tagName (t )["validate_" .len .. ]);
440+ }
441+ return t ;
442+ }
408443 };
409444
445+ pub fn domain (sep : DomainSeperator ) Input {
446+ return .{ .label = "dom-sep" , .type = .domsep , .seperator = sep };
447+ }
448+
410449 fn check (self : Input , t : Type , label : []const u8 ) void {
411- std .debug .assert (self .type == t );
450+ if (self .type != t ) {
451+ @compileError ("expected: " ++ @tagName (self .type ) ++ ", found: " ++ @tagName (t ));
452+ }
412453 std .debug .assert (std .mem .eql (u8 , self .label , label ));
413454 }
414455 };
@@ -418,7 +459,8 @@ pub const Transcript = struct {
418459 pub const Session = struct {
419460 i : u8 ,
420461 contract : Contract ,
421- err : bool , // if validate_point errors, we skip the finish() check
462+ // If an identity validation errors, we skip the finish() check.
463+ err : bool ,
422464
423465 pub inline fn nextInput (comptime self : * Session , t : Input.Type , label : []const u8 ) Input {
424466 comptime {
@@ -453,6 +495,14 @@ pub const Transcript = struct {
453495 return .{ .i = 0 , .contract = contract , .err = false };
454496 }
455497 }
498+
499+ /// The same as `getSession`, but does not check that it ends with a challenge.
500+ /// Only used in certain cases when we need an "init" contract, such as `percentage_with_cap`.
501+ pub inline fn getInitSession (comptime contract : []const Input ) Session {
502+ comptime {
503+ return .{ .i = 0 , .contract = contract , .err = false };
504+ }
505+ }
456506};
457507
458508test "equivalence" {
@@ -468,9 +518,9 @@ test "equivalence" {
468518 transcript .challengeBytes ("challenge" , & bytes );
469519
470520 try std .testing .expectEqualSlices (u8 , &.{
471- 0xd5 , 0xa2 , 0x19 , 0x72 , 0xd0 , 0xd5 , 0xfe , 0x32 ,
472- 0xc , 0xd , 0x26 , 0x3f , 0xac , 0x7f , 0xff , 0xb8 ,
473- 0x14 , 0x5a , 0xa6 , 0x40 , 0xaf , 0x6e , 0x9b , 0xca ,
474- 0x17 , 0x7c , 0x3 , 0xc7 , 0xef , 0xcf , 0x6 , 0x15 ,
521+ 159 , 115 , 74 , 116 , 119 , 227 , 89 , 42 ,
522+ 108 , 83 , 69 , 218 , 43 , 29 , 11 , 79 ,
523+ 117 , 141 , 121 , 172 , 163 , 50 , 123 , 92 ,
524+ 25 , 21 , 111 , 177 , 11 , 232 , 4 , 35 ,
475525 }, & bytes );
476526}
0 commit comments