Skip to content

Commit a459e65

Browse files
authored
Bug fixes and chores (#32)
2 parents 840e682 + 8019f17 commit a459e65

File tree

3 files changed

+122
-259
lines changed

3 files changed

+122
-259
lines changed

examples/simd_benchmark.zig

Lines changed: 72 additions & 227 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ const std = @import("std");
22
const simd_signature = @import("simd_signature");
33
const hash_zig = @import("hash-zig");
44

5-
// SIMD Performance Benchmark
6-
// Compares SIMD-optimized implementation against baseline
5+
// Simplified SIMD Performance Benchmark
6+
// Tests key generation for lifetime_2_10 and lifetime_2_16
77

88
pub fn main() !void {
99
var gpa = std.heap.ArenaAllocator.init(std.heap.page_allocator);
@@ -14,233 +14,78 @@ pub fn main() !void {
1414
std.debug.print("==========================\n", .{});
1515
std.debug.print("Testing SIMD-optimized hash-based signatures\n\n", .{});
1616

17-
// Test different lifetimes
18-
const lifetimes = [_]struct { name: []const u8, lifetime: hash_zig.params.KeyLifetime, expected_time_sec: f64, description: []const u8 }{
19-
.{ .name = "2^10", .lifetime = .lifetime_2_10, .expected_time_sec = 15.0, .description = "1,024 signatures - SIMD optimized target" },
20-
.{ .name = "2^16", .lifetime = .lifetime_2_16, .expected_time_sec = 60.0, .description = "65,536 signatures - SIMD optimized target" },
21-
};
22-
23-
for (lifetimes) |config| {
24-
std.debug.print("\nTesting lifetime: {s} ({s})\n", .{ config.name, config.description });
25-
std.debug.print("Expected time: ~{d:.1}s (SIMD optimized)\n", .{config.expected_time_sec});
26-
std.debug.print("-------------------\n", .{});
27-
28-
// Initialize SIMD signature scheme
29-
var sig_scheme = try simd_signature.SimdHashSignature.init(allocator, hash_zig.params.Parameters.init(config.lifetime));
30-
defer sig_scheme.deinit();
31-
32-
const seed: [32]u8 = .{42} ** 32;
33-
34-
// Key generation benchmark
35-
std.debug.print("Starting SIMD key generation...\n", .{});
36-
const keygen_start = std.time.nanoTimestamp();
37-
var keypair = try sig_scheme.generateKeyPair(allocator, &seed);
38-
const keygen_end = std.time.nanoTimestamp();
39-
defer keypair.deinit();
40-
41-
const keygen_duration_ns = keygen_end - keygen_start;
42-
const keygen_duration_sec = @as(f64, @floatFromInt(keygen_duration_ns)) / 1_000_000_000.0;
43-
44-
// Calculate performance metrics
45-
const tree_height: u32 = config.lifetime.treeHeight();
46-
const num_signatures = @as(usize, 1) << @intCast(tree_height);
47-
const signatures_per_sec = @as(f64, @floatFromInt(num_signatures)) / keygen_duration_sec;
48-
const time_per_signature_ms = (keygen_duration_sec * 1000.0) / @as(f64, @floatFromInt(num_signatures));
49-
50-
// Performance assessment
51-
const performance_ratio = keygen_duration_sec / config.expected_time_sec;
52-
const performance_status = if (performance_ratio < 0.8) "🚀 Excellent" else if (performance_ratio < 1.2) "✅ Good" else if (performance_ratio < 2.0) "⚠️ Slow" else "🐌 Very Slow";
53-
54-
// Sign benchmark
55-
const message = "SIMD Performance test message";
56-
const sign_start = std.time.nanoTimestamp();
57-
var signature = try sig_scheme.sign(allocator, message, keypair, 0);
58-
const sign_end = std.time.nanoTimestamp();
59-
defer signature.deinit(allocator);
60-
61-
const sign_duration_ns = sign_end - sign_start;
62-
const sign_duration_sec = @as(f64, @floatFromInt(sign_duration_ns)) / 1_000_000_000.0;
63-
64-
// Verify benchmark
65-
const verify_start = std.time.nanoTimestamp();
66-
const is_valid = try sig_scheme.verify(allocator, message, signature, keypair.public_key);
67-
const verify_end = std.time.nanoTimestamp();
68-
69-
const verify_duration_ns = verify_end - verify_start;
70-
const verify_duration_sec = @as(f64, @floatFromInt(verify_duration_ns)) / 1_000_000_000.0;
71-
72-
// Display detailed results
73-
std.debug.print("\n📊 SIMD KEY GENERATION RESULTS:\n", .{});
74-
std.debug.print(" Duration: {d:.3}s {s}\n", .{ keygen_duration_sec, performance_status });
75-
std.debug.print(" Signatures: {d} (2^{d})\n", .{ num_signatures, tree_height });
76-
std.debug.print(" Throughput: {d:.1} signatures/sec\n", .{signatures_per_sec});
77-
std.debug.print(" Time per signature: {d:.3}ms\n", .{time_per_signature_ms});
78-
std.debug.print(" Expected: ~{d:.1}s (ratio: {d:.2}x)\n", .{ config.expected_time_sec, performance_ratio });
79-
80-
// Display sign/verify results
81-
std.debug.print("\n🔐 SIMD SIGN/VERIFY RESULTS:\n", .{});
82-
std.debug.print(" Sign: {d:.3}ms\n", .{sign_duration_sec * 1000});
83-
std.debug.print(" Verify: {d:.3}ms\n", .{verify_duration_sec * 1000});
84-
std.debug.print(" Valid: {}\n", .{is_valid});
85-
86-
// Batch operations benchmark
87-
std.debug.print("\n🔄 BATCH OPERATIONS BENCHMARK:\n", .{});
88-
const batch_messages = [_][]const u8{ "batch1", "batch2", "batch3", "batch4" };
89-
const batch_indices = [_]u32{ 0, 1, 2, 3 };
90-
91-
const batch_sign_start = std.time.nanoTimestamp();
92-
const batch_sigs = try sig_scheme.batchSign(allocator, &batch_messages, keypair, &batch_indices);
93-
const batch_sign_end = std.time.nanoTimestamp();
94-
defer {
95-
for (batch_sigs) |*sig| sig.deinit(allocator);
96-
allocator.free(batch_sigs);
17+
// Seed: read from SEED_HEX env var if provided, else default to 0x2a repeated
18+
const seed_env = std.process.getEnvVarOwned(std.heap.page_allocator, "SEED_HEX") catch null;
19+
defer if (seed_env) |s| std.heap.page_allocator.free(s);
20+
21+
var seed: [32]u8 = .{42} ** 32;
22+
if (seed_env) |hex| {
23+
// Parse up to 64 hex chars into 32 bytes; non-hex characters and a missing final nibble are read as 0
24+
const n = @min(hex.len, 64);
25+
var i: usize = 0;
26+
while (i < n) : (i += 2) {
27+
const high_nibble = std.fmt.charToDigit(hex[i], 16) catch 0;
28+
const low_nibble = if (i + 1 < n) std.fmt.charToDigit(hex[i + 1], 16) catch 0 else 0;
29+
seed[i / 2] = @as(u8, @intCast((high_nibble << 4) | low_nibble));
9730
}
98-
99-
const batch_verify_start = std.time.nanoTimestamp();
100-
const batch_results = try sig_scheme.batchVerify(allocator, &batch_messages, batch_sigs, keypair.public_key);
101-
const batch_verify_end = std.time.nanoTimestamp();
102-
defer allocator.free(batch_results);
103-
104-
const batch_sign_duration = @as(f64, @floatFromInt(batch_sign_end - batch_sign_start)) / 1_000_000_000.0;
105-
const batch_verify_duration = @as(f64, @floatFromInt(batch_verify_end - batch_verify_start)) / 1_000_000_000.0;
106-
107-
std.debug.print(" Batch Sign (4 ops): {d:.3}ms\n", .{batch_sign_duration * 1000});
108-
std.debug.print(" Batch Verify (4 ops): {d:.3}ms\n", .{batch_verify_duration * 1000});
109-
std.debug.print(" All valid: {}\n", .{std.mem.allEqual(bool, batch_results, true)});
110-
111-
// Output results in a format that can be captured by CI
112-
std.debug.print("\n📈 CI BENCHMARK DATA:\n", .{});
113-
std.debug.print("BENCHMARK_RESULT: {s}:keygen:{d:.6}\n", .{ config.name, keygen_duration_sec });
114-
std.debug.print("BENCHMARK_RESULT: {s}:sign:{d:.6}\n", .{ config.name, sign_duration_sec });
115-
std.debug.print("BENCHMARK_RESULT: {s}:verify:{d:.6}\n", .{ config.name, verify_duration_sec });
116-
std.debug.print("BENCHMARK_RESULT: {s}:throughput:{d:.1}\n", .{ config.name, signatures_per_sec });
117-
std.debug.print("BENCHMARK_RESULT: {s}:performance_ratio:{d:.2}\n", .{ config.name, performance_ratio });
118-
std.debug.print("BENCHMARK_RESULT: {s}:batch_sign:{d:.6}\n", .{ config.name, batch_sign_duration });
119-
std.debug.print("BENCHMARK_RESULT: {s}:batch_verify:{d:.6}\n", .{ config.name, batch_verify_duration });
12031
}
121-
122-
// SIMD-specific performance tests
123-
std.debug.print("\n🧪 SIMD SPECIFIC TESTS:\n", .{});
124-
std.debug.print("========================\n", .{});
125-
126-
// Test SIMD field operations
127-
testSimdFieldOperations();
128-
129-
// Test SIMD Winternitz operations
130-
testSimdWinternitzOperations();
131-
132-
// Test SIMD Poseidon2 operations
133-
testSimdPoseidon2Operations();
32+
std.debug.print("Using seed (hex): ", .{});
33+
for (seed) |b| std.debug.print("{x:0>2}", .{b});
34+
std.debug.print("\n\n", .{});
35+
36+
// Test lifetime_2_10
37+
std.debug.print("Testing lifetime: 2^10 (1,024 signatures)\n", .{});
38+
std.debug.print("==========================================\n", .{});
39+
40+
var sig_scheme_10 = try simd_signature.SimdHashSignature.init(allocator, hash_zig.params.Parameters.init(.lifetime_2_10));
41+
defer sig_scheme_10.deinit();
42+
43+
const keygen_start_10 = std.time.nanoTimestamp();
44+
var keypair_10 = try sig_scheme_10.generateKeyPair(allocator, &seed);
45+
const keygen_end_10 = std.time.nanoTimestamp();
46+
defer keypair_10.deinit(allocator);
47+
48+
const keygen_duration_10 = @as(f64, @floatFromInt(keygen_end_10 - keygen_start_10)) / 1_000_000_000.0;
49+
50+
// Print keypair information for 2^10
51+
std.debug.print("Keypair 2^10:\n", .{});
52+
const secret_key_size_10 = keypair_10.secret_key.chains.len * @sizeOf(@TypeOf(keypair_10.secret_key.chains[0]));
53+
const public_key_size_10 = keypair_10.public_key.chains.len * @sizeOf(@TypeOf(keypair_10.public_key.chains[0]));
54+
std.debug.print(" Secret key length: {d} bytes\n", .{secret_key_size_10});
55+
std.debug.print(" Public key length: {d} bytes\n", .{public_key_size_10});
56+
std.debug.print(" Key generation time: {d:.3}s\n", .{keygen_duration_10});
57+
58+
// Test lifetime_2_16
59+
std.debug.print("\nTesting lifetime: 2^16 (65,536 signatures)\n", .{});
60+
std.debug.print("==========================================\n", .{});
61+
62+
var sig_scheme_16 = try simd_signature.SimdHashSignature.init(allocator, hash_zig.params.Parameters.init(.lifetime_2_16));
63+
defer sig_scheme_16.deinit();
64+
65+
const keygen_start_16 = std.time.nanoTimestamp();
66+
var keypair_16 = try sig_scheme_16.generateKeyPair(allocator, &seed);
67+
const keygen_end_16 = std.time.nanoTimestamp();
68+
defer keypair_16.deinit(allocator);
69+
70+
const keygen_duration_16 = @as(f64, @floatFromInt(keygen_end_16 - keygen_start_16)) / 1_000_000_000.0;
71+
72+
// Print keypair information for 2^16
73+
std.debug.print("Keypair 2^16:\n", .{});
74+
const secret_key_size_16 = keypair_16.secret_key.chains.len * @sizeOf(@TypeOf(keypair_16.secret_key.chains[0]));
75+
const public_key_size_16 = keypair_16.public_key.chains.len * @sizeOf(@TypeOf(keypair_16.public_key.chains[0]));
76+
std.debug.print(" Secret key length: {d} bytes\n", .{secret_key_size_16});
77+
std.debug.print(" Public key length: {d} bytes\n", .{public_key_size_16});
78+
std.debug.print(" Key generation time: {d:.3}s\n", .{keygen_duration_16});
79+
80+
// Summary
81+
std.debug.print("\n📊 SUMMARY:\n", .{});
82+
std.debug.print("2^10 key generation: {d:.3}s\n", .{keygen_duration_10});
83+
std.debug.print("2^16 key generation: {d:.3}s\n", .{keygen_duration_16});
84+
std.debug.print("Performance ratio: {d:.2}x\n", .{keygen_duration_16 / keygen_duration_10});
85+
86+
// Output for CI
87+
std.debug.print("\nBENCHMARK_RESULT: 2^10:keygen:{d:.6}\n", .{keygen_duration_10});
88+
std.debug.print("BENCHMARK_RESULT: 2^16:keygen:{d:.6}\n", .{keygen_duration_16});
13489

13590
std.debug.print("\n✅ SIMD Benchmark completed successfully!\n", .{});
13691
}
137-
138-
fn testSimdFieldOperations() void {
139-
std.debug.print("\n🔢 SIMD Field Operations Test:\n", .{});
140-
141-
const simd_field = @import("simd_montgomery");
142-
const iterations = 100000;
143-
144-
// Test scalar operations
145-
const start_scalar = std.time.nanoTimestamp();
146-
for (0..iterations) |_| {
147-
const a = simd_field.koala_bear_simd.MontFieldElem{ .value = 12345 };
148-
const b = simd_field.koala_bear_simd.MontFieldElem{ .value = 67890 };
149-
var result: simd_field.koala_bear_simd.MontFieldElem = undefined;
150-
simd_field.koala_bear_simd.mul(&result, a, b);
151-
}
152-
const scalar_time = std.time.nanoTimestamp() - start_scalar;
153-
154-
// Test vectorized operations
155-
const start_vector = std.time.nanoTimestamp();
156-
for (0..iterations / 4) |_| {
157-
const a_vec = simd_field.koala_bear_simd.Vec4{ 12345, 12346, 12347, 12348 };
158-
const b_vec = simd_field.koala_bear_simd.Vec4{ 67890, 67891, 67892, 67893 };
159-
var result_vec: simd_field.koala_bear_simd.Vec4 = undefined;
160-
simd_field.koala_bear_simd.mulVec4(&result_vec, a_vec, b_vec);
161-
}
162-
const vector_time = std.time.nanoTimestamp() - start_vector;
163-
164-
const speedup = @as(f64, @floatFromInt(scalar_time)) / @as(f64, @floatFromInt(vector_time));
165-
std.debug.print(" Field operations speedup: {d:.2}x\n", .{speedup});
166-
167-
if (speedup >= 2.0) {
168-
std.debug.print(" ✅ SIMD field operations working correctly\n", .{});
169-
} else {
170-
std.debug.print(" ⚠️ SIMD field operations may need optimization\n", .{});
171-
}
172-
}
173-
174-
fn testSimdWinternitzOperations() void {
175-
std.debug.print("\n🔗 SIMD Winternitz Operations Test:\n", .{});
176-
177-
const simd_winternitz = @import("simd_winternitz");
178-
const iterations = 1000;
179-
180-
// Test scalar chain generation
181-
const start_scalar = std.time.nanoTimestamp();
182-
for (0..iterations) |_| {
183-
var state = simd_winternitz.simd_winternitz_ots.ChainState{ 1, 2, 3, 4, 5, 6, 7, 8 };
184-
simd_winternitz.simd_winternitz_ots.generateChain(&state, 8);
185-
}
186-
const scalar_time = std.time.nanoTimestamp() - start_scalar;
187-
188-
// Test vectorized chain generation
189-
const start_vector = std.time.nanoTimestamp();
190-
for (0..iterations / 4) |_| {
191-
var states: [4]simd_winternitz.simd_winternitz_ots.ChainState = .{
192-
simd_winternitz.simd_winternitz_ots.ChainState{ 1, 2, 3, 4, 5, 6, 7, 8 },
193-
simd_winternitz.simd_winternitz_ots.ChainState{ 9, 10, 11, 12, 13, 14, 15, 16 },
194-
simd_winternitz.simd_winternitz_ots.ChainState{ 17, 18, 19, 20, 21, 22, 23, 24 },
195-
simd_winternitz.simd_winternitz_ots.ChainState{ 25, 26, 27, 28, 29, 30, 31, 32 },
196-
};
197-
simd_winternitz.simd_winternitz_ots.generateChainsBatch(&states, 8);
198-
}
199-
const vector_time = std.time.nanoTimestamp() - start_vector;
200-
201-
const speedup = @as(f64, @floatFromInt(scalar_time)) / @as(f64, @floatFromInt(vector_time));
202-
std.debug.print(" Winternitz operations speedup: {d:.2}x\n", .{speedup});
203-
204-
if (speedup >= 2.0) {
205-
std.debug.print(" ✅ SIMD Winternitz operations working correctly\n", .{});
206-
} else {
207-
std.debug.print(" ⚠️ SIMD Winternitz operations may need optimization\n", .{});
208-
}
209-
}
210-
211-
fn testSimdPoseidon2Operations() void {
212-
std.debug.print("\n🌊 SIMD Poseidon2 Operations Test:\n", .{});
213-
214-
const simd_poseidon = @import("simd_poseidon2");
215-
const iterations = 1000;
216-
217-
// Test scalar permutation
218-
const start_scalar = std.time.nanoTimestamp();
219-
for (0..iterations) |_| {
220-
var state = simd_poseidon.simd_poseidon2.state{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 };
221-
simd_poseidon.simd_poseidon2.permutation(&state);
222-
}
223-
const scalar_time = std.time.nanoTimestamp() - start_scalar;
224-
225-
// Test vectorized permutation
226-
const start_vector = std.time.nanoTimestamp();
227-
for (0..iterations / 4) |_| {
228-
var states: [4]simd_poseidon.simd_poseidon2.Vec4 = .{
229-
simd_poseidon.simd_poseidon2.Vec4{ 1, 2, 3, 4 },
230-
simd_poseidon.simd_poseidon2.Vec4{ 5, 6, 7, 8 },
231-
simd_poseidon.simd_poseidon2.Vec4{ 9, 10, 11, 12 },
232-
simd_poseidon.simd_poseidon2.Vec4{ 13, 14, 15, 16 },
233-
};
234-
simd_poseidon.simd_poseidon2.permutationVec4(&states);
235-
}
236-
const vector_time = std.time.nanoTimestamp() - start_vector;
237-
238-
const speedup = @as(f64, @floatFromInt(scalar_time)) / @as(f64, @floatFromInt(vector_time));
239-
std.debug.print(" Poseidon2 operations speedup: {d:.2}x\n", .{speedup});
240-
241-
if (speedup >= 2.0) {
242-
std.debug.print(" ✅ SIMD Poseidon2 operations working correctly\n", .{});
243-
} else {
244-
std.debug.print(" ⚠️ SIMD Poseidon2 operations may need optimization\n", .{});
245-
}
246-
}

src/simd_signature.zig

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@ pub const SimdHashSignature = struct {
4141
secret_key: Winternitz.PrivateKey,
4242
public_key: Winternitz.PublicKey,
4343

44-
pub fn deinit(self: *KeyPair) void {
45-
self.secret_key.deinit();
46-
self.public_key.deinit();
44+
pub fn deinit(self: *KeyPair, allocator: std.mem.Allocator) void {
45+
self.secret_key.deinit(allocator);
46+
self.public_key.deinit(allocator);
4747
}
4848
};
4949

@@ -59,14 +59,23 @@ pub const SimdHashSignature = struct {
5959
};
6060

6161
// Generate key pair with SIMD optimizations
62-
pub fn generateKeyPair(self: *SimdHashSignature, _: std.mem.Allocator, seed: []const u8) !KeyPair {
63-
// Generate Winternitz private key
64-
const secret_key = try Winternitz.generatePrivateKey(seed);
62+
pub fn generateKeyPair(self: *SimdHashSignature, allocator: std.mem.Allocator, seed: []const u8) !KeyPair {
63+
if (seed.len != 32) return error.InvalidSeedLength;
6564

66-
// Generate Winternitz public key with SIMD
67-
const public_key = try Winternitz.generatePublicKey(secret_key);
65+
// Scale the number of chains with the lifetime: 64 chains for tree heights up to 10, doubling for each level above 10
66+
// For demonstration purposes, we'll scale the chains based on tree height
67+
// In a real implementation, this would be more sophisticated
68+
const base_chains = 64;
69+
const scale_factor = @as(u32, 1) << @intCast(@max(0, @as(i32, @intCast(self.tree_height)) - 10));
70+
const scaled_chains = base_chains * scale_factor;
71+
72+
// Create modified parameters with scaled chain count
73+
var scaled_params = self.params;
74+
scaled_params.num_chains = scaled_chains;
75+
76+
const secret_key = try Winternitz.generatePrivateKey(allocator, scaled_params, seed);
77+
const public_key = try Winternitz.generatePublicKey(allocator, secret_key);
6878

69-
_ = self; // Suppress unused parameter warning
7079
return KeyPair{
7180
.secret_key = secret_key,
7281
.public_key = public_key,
@@ -152,7 +161,7 @@ pub const SimdHashSignature = struct {
152161
// Sign message with SIMD optimizations
153162
pub fn sign(self: *SimdHashSignature, allocator: std.mem.Allocator, message: []const u8, keypair: KeyPair, _: u32) !Signature {
154163
// Generate Winternitz signature
155-
const winternitz_sig = try Winternitz.sign(message, keypair.secret_key);
164+
const winternitz_sig = try Winternitz.sign(allocator, message, keypair.secret_key);
156165

157166
// Generate Merkle path (simplified for this example)
158167
const merkle_path = try allocator.alloc(u32, self.tree_height * 8);

0 commit comments

Comments
 (0)