@@ -190,7 +190,7 @@ pub const Strobe128 = struct {
190190pub const Transcript = struct {
191191 strobe : Strobe128 ,
192192
193- const DomainSeperator = enum {
193+ pub const DomainSeperator = enum {
194194 @"zero-ciphertext-instruction" ,
195195 @"zero-ciphertext-proof" ,
196196 @"pubkey-validity-instruction" ,
@@ -230,21 +230,26 @@ pub const Transcript = struct {
230230 ciphertext : zksdk.elgamal.Ciphertext ,
231231 commitment : zksdk.pedersen.Commitment ,
232232 u64 : u64 ,
233+ domsep : DomainSeperator ,
233234
234235 grouped_2 : zksdk .elgamal .GroupedElGamalCiphertext (2 ),
235236 grouped_3 : zksdk .elgamal .GroupedElGamalCiphertext (3 ),
236237 };
237238
238- pub fn init (comptime seperator : DomainSeperator , inputs : []const TranscriptInput ) Transcript {
239+ /// [agave] https://github.com/solana-program/zk-elgamal-proof/blob/zk-sdk%40v5.0.0/zk-sdk/src/lib.rs#L36
240+ const TRANSCRIPT_DOMAIN = "solana-zk-elgamal-proof-program-v1" ;
241+
242+ pub fn init (comptime seperator : DomainSeperator ) Transcript {
239243 var transcript : Transcript = .{ .strobe = Strobe128 .init ("Merlin v1.0" ) };
240- transcript .appendDomSep ( seperator );
241- for ( inputs ) | input | transcript .appendMessage ( input . label , input . message );
244+ transcript .appendBytes ( "dom-sep" , TRANSCRIPT_DOMAIN );
245+ transcript .appendBytes ( "dom-sep" , @tagName ( seperator ) );
242246 return transcript ;
243247 }
244248
245249 pub fn initTest (label : []const u8 ) Transcript {
246250 comptime if (! builtin .is_test ) @compileError ("should only be used during tests" );
247251 var transcript : Transcript = .{ .strobe = Strobe128 .init ("Merlin v1.0" ) };
252+ transcript .appendBytes ("dom-sep" , TRANSCRIPT_DOMAIN );
248253 transcript .appendBytes ("dom-sep" , label );
249254 return transcript ;
250255 }
@@ -264,6 +269,7 @@ pub const Transcript = struct {
264269 .point = > | * point | & point .toBytes (),
265270 .pubkey = > | * pubkey | & pubkey .toBytes (),
266271 .scalar = > | * scalar | & scalar .toBytes (),
272+ .domsep = > | t | @tagName (t ),
267273 .ciphertext = > | * ct | b : {
268274 @memcpy (buffer [0.. 32], & ct .commitment .point .toBytes ());
269275 @memcpy (buffer [32.. 64], & ct .handle .point .toBytes ());
@@ -284,26 +290,30 @@ pub const Transcript = struct {
284290 comptime session : * Session ,
285291 comptime t : Input.Type ,
286292 comptime label : []const u8 ,
287- data : t . Data ( ),
288- ) if (t == .validate_point ) error {IdentityElement }! void else void {
289- // if validate_point fails to validate, we no longer want to check the contract
293+ data : @FieldType ( Message , @tagName ( t.base ()) ),
294+ ) if (t . validates () ) error {IdentityElement }! void else void {
295+ // If validate_point fails to validate, we no longer want to check the contract
290296 // because the function calling append will now return early.
291297 errdefer session .cancel ();
292298
293- if (t == .bytes and ! builtin .is_test )
294- @compileError ("message type `bytes` only allowed in tests" );
295-
296- // assert correctness
299+ // Get the next expected input, and inside we verify that it matches
300+ // the type we're about to append to the transcript.
297301 const input = comptime session .nextInput (t , label );
298- if (t == .validate_point ) try data .rejectIdentity ();
302+ // If the input requires validation, we perform it here.
303+ if (comptime t .validates ()) try data .rejectIdentity ();
304+ // Ensure that the domain seperators are added with the correct label.
305+ // They should always be added through the `appendDomSep` helper function.
306+ switch (t ) {
307+ .domsep = > comptime {
308+ std .debug .assert (input .seperator .? == data );
309+ std .debug .assert (std .mem .eql (u8 , label , "dom-sep" ));
310+ },
311+ else = > {},
312+ }
299313
300- // add the message
301314 self .appendMessage (input .label , @unionInit (
302315 Message ,
303- @tagName (switch (t ) {
304- .validate_point = > .point ,
305- else = > t ,
306- }),
316+ @tagName (t .base ()),
307317 data ,
308318 ));
309319 }
@@ -314,22 +324,26 @@ pub const Transcript = struct {
314324 pub inline fn appendNoValidate (
315325 self : * Transcript ,
316326 comptime session : * Session ,
327+ comptime t : Input.Type ,
317328 comptime label : []const u8 ,
318- point : Ristretto255 ,
329+ data : @FieldType ( Message , @tagName ( t.base ())) ,
319330 ) void {
320- const input = comptime session .nextInput (.validate_point , label );
321- point .rejectIdentity () catch {}; // ignore the error
322- self .appendMessage (input .label , .{ .point = point });
331+ const input = comptime session .nextInput (
332+ @field (Input .Type , "validate_" ++ @tagName (t )),
333+ label ,
334+ );
335+ data .rejectIdentity () catch {}; // ignore the error
336+ self .appendMessage (input .label , @unionInit (Message , @tagName (t ), data ));
323337 }
324338
325- fn challengeBytes (
339+ /// NOTE: This is only meant for `challengeScalar` and tests.
340+ pub fn challengeBytes (
326341 self : * Transcript ,
327342 label : []const u8 ,
328343 destination : []u8 ,
329344 ) void {
330345 var data_len : [4 ]u8 = undefined ;
331346 std .mem .writeInt (u32 , & data_len , @intCast (destination .len ), .little );
332-
333347 self .strobe .metaAd (label , false );
334348 self .strobe .metaAd (& data_len , true );
335349 self .strobe .prf (destination , false );
@@ -351,64 +365,89 @@ pub const Transcript = struct {
351365
352366 // domain seperation helpers
353367
354- pub fn appendDomSep (self : * Transcript , comptime seperator : DomainSeperator ) void {
355- self .appendBytes ("dom-sep" , @tagName (seperator ));
356- }
357-
358- pub fn appendHandleDomSep (
368+ pub inline fn appendDomSep (
359369 self : * Transcript ,
360- comptime mode : enum { batched , unbatched } ,
361- comptime handles : enum { two , three } ,
370+ comptime session : * Session ,
371+ comptime seperator : DomainSeperator ,
362372 ) void {
363- self .appendDomSep (switch (mode ) {
364- .batched = > .@"batched-validity-proof" ,
365- .unbatched = > .@"validity-proof" ,
366- });
367- self .appendMessage ("handles" , .{ .u64 = switch (handles ) {
368- .two = > 2 ,
369- .three = > 3 ,
370- } });
373+ self .append (session , .domsep , "dom-sep" , seperator );
371374 }
372375
373- pub fn appendRangeProof (
376+ pub inline fn appendRangeProof (
374377 self : * Transcript ,
378+ comptime session : * Session ,
375379 comptime mode : enum { range , inner },
376380 n : comptime_int ,
377381 ) void {
378- self .appendDomSep (switch (mode ) {
382+ self .appendDomSep (session , switch (mode ) {
379383 .range = > .@"range-proof" ,
380384 .inner = > .@"inner-product" ,
381385 });
382- self .appendMessage ( "n" , .{ . u64 = n } );
386+ self .append ( session , .u64 , "n" , n );
383387 }
384388
385389 // sessions
386390
387391 pub const Input = struct {
388392 label : []const u8 ,
389393 type : Type ,
394+ seperator : ? DomainSeperator = null ,
390395
391396 const Type = enum {
392397 bytes ,
393398 scalar ,
394- challenge ,
399+ u64 ,
400+
395401 point ,
396- validate_point ,
397402 pubkey ,
403+ ciphertext ,
404+ commitment ,
405+ grouped_2 ,
406+ grouped_3 ,
407+
408+ validate_point ,
409+ validate_pubkey ,
410+ validate_ciphertext ,
411+ validate_commitment ,
412+ validate_grouped_2 ,
413+ validate_grouped_3 ,
414+
415+ domsep ,
416+ challenge ,
398417
399- pub fn Data (comptime t : Type ) type {
418+ /// Returns whether this input type performs identity validation.
419+ fn validates (t : Type ) bool {
400420 return switch (t ) {
401- .bytes = > []const u8 ,
402- .scalar = > Scalar ,
403- .validate_point , .point = > Ristretto255 ,
404- .pubkey = > zksdk .elgamal .Pubkey ,
405- .challenge = > unreachable , // call `challenge*`
421+ .validate_point ,
422+ .validate_pubkey ,
423+ .validate_ciphertext ,
424+ .validate_commitment ,
425+ .validate_grouped_2 ,
426+ .validate_grouped_3 ,
427+ = > true ,
428+ else = > false ,
406429 };
407430 }
431+
432+ /// For a given input type, returns the base type.
433+ /// E.g. `validate_point` -> `point`
434+ /// E.g. `point` -> `point`
435+ fn base (t : Type ) Type {
436+ if (t .validates ()) {
437+ return @field (Type , @tagName (t )["validate_" .len .. ]);
438+ }
439+ return t ;
440+ }
408441 };
409442
443+ pub fn domain (sep : DomainSeperator ) Input {
444+ return .{ .label = "dom-sep" , .type = .domsep , .seperator = sep };
445+ }
446+
410447 fn check (self : Input , t : Type , label : []const u8 ) void {
411- std .debug .assert (self .type == t );
448+ if (self .type != t ) {
449+ @compileError ("expected: " ++ @tagName (self .type ) ++ ", found: " ++ @tagName (t ));
450+ }
412451 std .debug .assert (std .mem .eql (u8 , self .label , label ));
413452 }
414453 };
@@ -418,7 +457,8 @@ pub const Transcript = struct {
418457 pub const Session = struct {
419458 i : u8 ,
420459 contract : Contract ,
421- err : bool , // if validate_point errors, we skip the finish() check
460+ // If an identity validation errors, we skip the finish() check.
461+ err : bool ,
422462
423463 pub inline fn nextInput (comptime self : * Session , t : Input.Type , label : []const u8 ) Input {
424464 comptime {
@@ -453,6 +493,15 @@ pub const Transcript = struct {
453493 return .{ .i = 0 , .contract = contract , .err = false };
454494 }
455495 }
496+
497+ /// The same as `getSession`, but does not check that it ends with a challenge.
498+ ///
499+ /// Only used in certain cases when we need an "init" contract, such as `percentage_with_cap`.
500+ pub inline fn getInitSession (comptime contract : []const Input ) Session {
501+ comptime {
502+ return .{ .i = 0 , .contract = contract , .err = false };
503+ }
504+ }
456505};
457506
458507test "equivalence" {
@@ -468,9 +517,9 @@ test "equivalence" {
468517 transcript .challengeBytes ("challenge" , & bytes );
469518
470519 try std .testing .expectEqualSlices (u8 , &.{
471- 0xd5 , 0xa2 , 0x19 , 0x72 , 0xd0 , 0xd5 , 0xfe , 0x32 ,
472- 0xc , 0xd , 0x26 , 0x3f , 0xac , 0x7f , 0xff , 0xb8 ,
473- 0x14 , 0x5a , 0xa6 , 0x40 , 0xaf , 0x6e , 0x9b , 0xca ,
474- 0x17 , 0x7c , 0x3 , 0xc7 , 0xef , 0xcf , 0x6 , 0x15 ,
520+ 159 , 115 , 74 , 116 , 119 , 227 , 89 , 42 ,
521+ 108 , 83 , 69 , 218 , 43 , 29 , 11 , 79 ,
522+ 117 , 141 , 121 , 172 , 163 , 50 , 123 , 92 ,
523+ 25 , 21 , 111 , 177 , 11 , 232 , 4 , 35 ,
475524 }, & bytes );
476525}
0 commit comments