Skip to content

Commit 2498fe3

Browse files
committed
zksdk 5.0 conformance
1 parent 20ffa22 commit 2498fe3

File tree

15 files changed

+995
-905
lines changed

15 files changed

+995
-905
lines changed

src/runtime/program/zk_elgamal/execute.zig

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ fn processVerifyProof(
156156
break :cd proof_data.context;
157157
};
158158

159-
// create context state if additional accounts are provided with the instruction
159+
// Create context state if additional accounts are provided with the instruction.
160160
if (ic.ixn_info.account_metas.items.len >= accessed_accounts + 2) {
161161
const context_authority_key = blk: {
162162
const context_state_authority = try ic.borrowInstructionAccount(accessed_accounts + 1);

src/runtime/program/zk_elgamal/lib.zig

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
const sig = @import("../../../sig.zig");
22

3+
/// [agave] https://github.com/solana-program/zk-elgamal-proof/blob/zk-sdk%40v5.0.0/zk-sdk/src/zk_elgamal_proof_program/proof_data/mod.rs#L48
34
pub const ProofType = enum(u8) {
45
/// Empty proof type used to distinguish if a proof context account is initialized
56
uninitialized,
@@ -35,6 +36,7 @@ pub fn ProofContextState(C: type) type {
3536

3637
pub const ID: sig.core.Pubkey = .parse("ZkE1Gama1Proof11111111111111111111111111111");
3738

39+
// [agave] https://github.com/anza-xyz/agave/blob/master/programs/zk-elgamal-proof/src/lib.rs#L19-L31
3840
pub const CLOSE_CONTEXT_STATE_COMPUTE_UNITS: u64 = 3_300;
3941
pub const VERIFY_ZERO_CIPHERTEXT_COMPUTE_UNITS: u64 = 6_000;
4042
pub const VERIFY_CIPHERTEXT_CIPHERTEXT_EQUALITY_COMPUTE_UNITS: u64 = 8_000;

src/zksdk/elgamal.zig

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ pub const Pubkey = struct {
5454
);
5555
return fromBytes(buffer);
5656
}
57+
58+
pub fn rejectIdentity(self: *const Pubkey) error{IdentityElement}!void {
59+
try self.point.rejectIdentity();
60+
}
5761
};
5862

5963
pub const Keypair = struct {
@@ -107,6 +111,11 @@ pub const Ciphertext = struct {
107111
);
108112
return fromBytes(buffer);
109113
}
114+
115+
pub fn rejectIdentity(self: *const Ciphertext) error{IdentityElement}!void {
116+
try self.commitment.point.rejectIdentity();
117+
try self.handle.point.rejectIdentity();
118+
}
110119
};
111120

112121
pub fn encrypt(comptime T: type, value: T, pubkey: *const Pubkey) Ciphertext {
@@ -168,13 +177,27 @@ pub fn GroupedElGamalCiphertext(comptime N: u64) type {
168177
};
169178
}
170179

180+
pub fn fromBase64(string: []const u8) !Self {
181+
const base64 = std.base64.standard;
182+
var buffer: [BYTE_LEN]u8 = @splat(0);
183+
const decoded_length = try base64.Decoder.calcSizeForSlice(string);
184+
try std.base64.standard.Decoder.decode(
185+
buffer[0..decoded_length],
186+
string,
187+
);
188+
return fromBytes(buffer);
189+
}
190+
171191
pub fn toBytes(self: Self) [BYTE_LEN]u8 {
172192
var handles: [N * 32]u8 = undefined;
173193
for (self.handles, 0..) |handle, i| {
174-
const position = i * 32;
175-
handles[position..][0..32].* = handle.point.toBytes();
194+
handles[i * 32 ..][0..32].* = handle.point.toBytes();
176195
}
177196
return self.commitment.point.toBytes() ++ handles;
178197
}
198+
199+
pub fn rejectIdentity(self: *const Self) error{IdentityElement}!void {
200+
try self.commitment.rejectIdentity();
201+
}
179202
};
180203
}

src/zksdk/lib.zig

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,16 @@ pub const PercentageWithCapData = percentage.Data;
2222
pub const PubkeyProofData = pubkey_validity.Data;
2323
pub const ZeroCiphertextData = zero_ciphertext.Data;
2424

25-
// grouped ciphertext validity
26-
const grouped_cipher_handles_2 = @import("sigma_proofs/grouped_ciphertext/handles_2.zig");
27-
const grouped_cipher_handles_3 = @import("sigma_proofs/grouped_ciphertext/handles_3.zig");
25+
// // grouped ciphertext validity
26+
const grouped_cipher_2_handles = @import("sigma_proofs/grouped_ciphertext/2_handles.zig");
27+
const grouped_cipher_3_handles = @import("sigma_proofs/grouped_ciphertext/3_handles.zig");
2828

29-
pub const GroupedCiphertext2HandlesData = grouped_cipher_handles_2.Data;
30-
pub const BatchedGroupedCiphertext2HandlesData = grouped_cipher_handles_2.BatchedData;
31-
pub const GroupedCiphertext3HandlesData = grouped_cipher_handles_3.Data;
32-
pub const BatchedGroupedCiphertext3HandlesData = grouped_cipher_handles_3.BatchedData;
29+
pub const GroupedCiphertext2HandlesData = grouped_cipher_2_handles.Data;
30+
pub const BatchedGroupedCiphertext2HandlesData = grouped_cipher_2_handles.BatchedData;
31+
pub const GroupedCiphertext3HandlesData = grouped_cipher_3_handles.Data;
32+
pub const BatchedGroupedCiphertext3HandlesData = grouped_cipher_3_handles.BatchedData;
3333

34-
// range proof
34+
// // range proof
3535
pub const bulletproofs = @import("range_proof/bulletproofs.zig");
3636

3737
pub const RangeProofU64Data = bulletproofs.Data(64);

src/zksdk/merlin.zig

Lines changed: 101 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ pub const Strobe128 = struct {
190190
pub const Transcript = struct {
191191
strobe: Strobe128,
192192

193-
const DomainSeperator = enum {
193+
pub const DomainSeperator = enum {
194194
@"zero-ciphertext-instruction",
195195
@"zero-ciphertext-proof",
196196
@"pubkey-validity-instruction",
@@ -230,21 +230,26 @@ pub const Transcript = struct {
230230
ciphertext: zksdk.elgamal.Ciphertext,
231231
commitment: zksdk.pedersen.Commitment,
232232
u64: u64,
233+
domsep: DomainSeperator,
233234

234235
grouped_2: zksdk.elgamal.GroupedElGamalCiphertext(2),
235236
grouped_3: zksdk.elgamal.GroupedElGamalCiphertext(3),
236237
};
237238

238-
pub fn init(comptime seperator: DomainSeperator, inputs: []const TranscriptInput) Transcript {
239+
/// [agave] https://github.com/solana-program/zk-elgamal-proof/blob/zk-sdk%40v5.0.0/zk-sdk/src/lib.rs#L36
240+
const TRANSCRIPT_DOMAIN = "solana-zk-elgamal-proof-program-v1";
241+
242+
pub fn init(comptime seperator: DomainSeperator) Transcript {
239243
var transcript: Transcript = .{ .strobe = Strobe128.init("Merlin v1.0") };
240-
transcript.appendDomSep(seperator);
241-
for (inputs) |input| transcript.appendMessage(input.label, input.message);
244+
transcript.appendBytes("dom-sep", TRANSCRIPT_DOMAIN);
245+
transcript.appendBytes("dom-sep", @tagName(seperator));
242246
return transcript;
243247
}
244248

245249
pub fn initTest(label: []const u8) Transcript {
246250
comptime if (!builtin.is_test) @compileError("should only be used during tests");
247251
var transcript: Transcript = .{ .strobe = Strobe128.init("Merlin v1.0") };
252+
transcript.appendBytes("dom-sep", TRANSCRIPT_DOMAIN);
248253
transcript.appendBytes("dom-sep", label);
249254
return transcript;
250255
}
@@ -264,6 +269,7 @@ pub const Transcript = struct {
264269
.point => |*point| &point.toBytes(),
265270
.pubkey => |*pubkey| &pubkey.toBytes(),
266271
.scalar => |*scalar| &scalar.toBytes(),
272+
.domsep => |t| @tagName(t),
267273
.ciphertext => |*ct| b: {
268274
@memcpy(buffer[0..32], &ct.commitment.point.toBytes());
269275
@memcpy(buffer[32..64], &ct.handle.point.toBytes());
@@ -284,26 +290,33 @@ pub const Transcript = struct {
284290
comptime session: *Session,
285291
comptime t: Input.Type,
286292
comptime label: []const u8,
287-
data: t.Data(),
288-
) if (t == .validate_point) error{IdentityElement}!void else void {
289-
// if validate_point fails to validate, we no longer want to check the contract
293+
data: @FieldType(Message, @tagName(t.base())),
294+
) if (t.validates()) error{IdentityElement}!void else void {
295+
// If validate_point fails to validate, we no longer want to check the contract
290296
// because the function calling append will now return early.
291297
errdefer session.cancel();
292298

293299
if (t == .bytes and !builtin.is_test)
294300
@compileError("message type `bytes` only allowed in tests");
295301

296-
// assert correctness
302+
// Get the next expected input, and inside we verify that it matches
303+
// the type we're about to append to the transcript.
297304
const input = comptime session.nextInput(t, label);
298-
if (t == .validate_point) try data.rejectIdentity();
305+
// If the input requires validation, we perform it here.
306+
if (comptime t.validates()) try data.rejectIdentity();
307+
// Ensure that the domain seperators are added with the correct label.
308+
// They should always be added through the `appendDomSep` helper function.
309+
switch (t) {
310+
.domsep => comptime {
311+
std.debug.assert(input.seperator.? == data);
312+
std.debug.assert(std.mem.eql(u8, label, "dom-sep"));
313+
},
314+
else => {},
315+
}
299316

300-
// add the message
301317
self.appendMessage(input.label, @unionInit(
302318
Message,
303-
@tagName(switch (t) {
304-
.validate_point => .point,
305-
else => t,
306-
}),
319+
@tagName(t.base()),
307320
data,
308321
));
309322
}
@@ -314,12 +327,16 @@ pub const Transcript = struct {
314327
pub inline fn appendNoValidate(
315328
self: *Transcript,
316329
comptime session: *Session,
330+
comptime t: Input.Type,
317331
comptime label: []const u8,
318-
point: Ristretto255,
332+
data: @FieldType(Message, @tagName(t.base())),
319333
) void {
320-
const input = comptime session.nextInput(.validate_point, label);
321-
point.rejectIdentity() catch {}; // ignore the error
322-
self.appendMessage(input.label, .{ .point = point });
334+
const input = comptime session.nextInput(
335+
@field(Input.Type, "validate_" ++ @tagName(t)),
336+
label,
337+
);
338+
data.rejectIdentity() catch {}; // ignore the error
339+
self.appendMessage(input.label, @unionInit(Message, @tagName(t), data));
323340
}
324341

325342
fn challengeBytes(
@@ -329,7 +346,6 @@ pub const Transcript = struct {
329346
) void {
330347
var data_len: [4]u8 = undefined;
331348
std.mem.writeInt(u32, &data_len, @intCast(destination.len), .little);
332-
333349
self.strobe.metaAd(label, false);
334350
self.strobe.metaAd(&data_len, true);
335351
self.strobe.prf(destination, false);
@@ -351,64 +367,89 @@ pub const Transcript = struct {
351367

352368
// domain seperation helpers
353369

354-
pub fn appendDomSep(self: *Transcript, comptime seperator: DomainSeperator) void {
355-
self.appendBytes("dom-sep", @tagName(seperator));
356-
}
357-
358-
pub fn appendHandleDomSep(
370+
pub inline fn appendDomSep(
359371
self: *Transcript,
360-
comptime mode: enum { batched, unbatched },
361-
comptime handles: enum { two, three },
372+
comptime session: *Session,
373+
comptime seperator: DomainSeperator,
362374
) void {
363-
self.appendDomSep(switch (mode) {
364-
.batched => .@"batched-validity-proof",
365-
.unbatched => .@"validity-proof",
366-
});
367-
self.appendMessage("handles", .{ .u64 = switch (handles) {
368-
.two => 2,
369-
.three => 3,
370-
} });
375+
self.append(session, .domsep, "dom-sep", seperator);
371376
}
372377

373-
pub fn appendRangeProof(
378+
pub inline fn appendRangeProof(
374379
self: *Transcript,
380+
comptime session: *Session,
375381
comptime mode: enum { range, inner },
376382
n: comptime_int,
377383
) void {
378-
self.appendDomSep(switch (mode) {
384+
self.appendDomSep(session, switch (mode) {
379385
.range => .@"range-proof",
380386
.inner => .@"inner-product",
381387
});
382-
self.appendMessage("n", .{ .u64 = n });
388+
self.append(session, .u64, "n", n);
383389
}
384390

385391
// sessions
386392

387393
pub const Input = struct {
388394
label: []const u8,
389395
type: Type,
396+
seperator: ?DomainSeperator = null,
390397

391398
const Type = enum {
392399
bytes,
393400
scalar,
394-
challenge,
401+
u64,
402+
395403
point,
396-
validate_point,
397404
pubkey,
405+
ciphertext,
406+
commitment,
407+
grouped_2,
408+
grouped_3,
398409

399-
pub fn Data(comptime t: Type) type {
410+
validate_point,
411+
validate_pubkey,
412+
validate_ciphertext,
413+
validate_commitment,
414+
validate_grouped_2,
415+
validate_grouped_3,
416+
417+
domsep,
418+
challenge,
419+
420+
/// Returns whether this input type performs identity validation.
421+
fn validates(t: Type) bool {
400422
return switch (t) {
401-
.bytes => []const u8,
402-
.scalar => Scalar,
403-
.validate_point, .point => Ristretto255,
404-
.pubkey => zksdk.elgamal.Pubkey,
405-
.challenge => unreachable, // call `challenge*`
423+
.validate_point,
424+
.validate_pubkey,
425+
.validate_ciphertext,
426+
.validate_commitment,
427+
.validate_grouped_2,
428+
.validate_grouped_3,
429+
=> true,
430+
else => false,
406431
};
407432
}
433+
434+
/// For a given input type, returns the base type.
435+
/// E.g. `validate_point` -> `point`
436+
/// E.g. `point` -> `point`
437+
fn base(t: Type) Type {
438+
if (t.validates()) {
439+
return @field(Type, @tagName(t)["validate_".len..]);
440+
}
441+
return t;
442+
}
408443
};
409444

445+
pub fn domain(sep: DomainSeperator) Input {
446+
return .{ .label = "dom-sep", .type = .domsep, .seperator = sep };
447+
}
448+
410449
fn check(self: Input, t: Type, label: []const u8) void {
411-
std.debug.assert(self.type == t);
450+
if (self.type != t) {
451+
@compileError("expected: " ++ @tagName(self.type) ++ ", found: " ++ @tagName(t));
452+
}
412453
std.debug.assert(std.mem.eql(u8, self.label, label));
413454
}
414455
};
@@ -418,7 +459,8 @@ pub const Transcript = struct {
418459
pub const Session = struct {
419460
i: u8,
420461
contract: Contract,
421-
err: bool, // if validate_point errors, we skip the finish() check
462+
// If an identity validation errors, we skip the finish() check.
463+
err: bool,
422464

423465
pub inline fn nextInput(comptime self: *Session, t: Input.Type, label: []const u8) Input {
424466
comptime {
@@ -453,6 +495,14 @@ pub const Transcript = struct {
453495
return .{ .i = 0, .contract = contract, .err = false };
454496
}
455497
}
498+
499+
/// The same as `getSession`, but does not check that it ends with a challenge.
500+
/// Only used in certain cases when we need an "init" contract, such as `percentage_with_cap`.
501+
pub inline fn getInitSession(comptime contract: []const Input) Session {
502+
comptime {
503+
return .{ .i = 0, .contract = contract, .err = false };
504+
}
505+
}
456506
};
457507

458508
test "equivalence" {
@@ -468,9 +518,9 @@ test "equivalence" {
468518
transcript.challengeBytes("challenge", &bytes);
469519

470520
try std.testing.expectEqualSlices(u8, &.{
471-
0xd5, 0xa2, 0x19, 0x72, 0xd0, 0xd5, 0xfe, 0x32,
472-
0xc, 0xd, 0x26, 0x3f, 0xac, 0x7f, 0xff, 0xb8,
473-
0x14, 0x5a, 0xa6, 0x40, 0xaf, 0x6e, 0x9b, 0xca,
474-
0x17, 0x7c, 0x3, 0xc7, 0xef, 0xcf, 0x6, 0x15,
521+
159, 115, 74, 116, 119, 227, 89, 42,
522+
108, 83, 69, 218, 43, 29, 11, 79,
523+
117, 141, 121, 172, 163, 50, 123, 92,
524+
25, 21, 111, 177, 11, 232, 4, 35,
475525
}, &bytes);
476526
}

src/zksdk/pedersen.zig

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ pub const Commitment = struct {
6666
);
6767
return fromBytes(buffer);
6868
}
69+
70+
pub fn rejectIdentity(self: *const Commitment) error{IdentityElement}!void {
71+
try self.point.rejectIdentity();
72+
}
6973
};
7074

7175
pub const DecryptHandle = struct {

0 commit comments

Comments
 (0)