Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
299 changes: 72 additions & 227 deletions examples/simd_benchmark.zig
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ const std = @import("std");
const simd_signature = @import("simd_signature");
const hash_zig = @import("hash-zig");

// SIMD Performance Benchmark
// Compares SIMD-optimized implementation against baseline
// Simplified SIMD Performance Benchmark
// Tests key generation for lifetime_2_10 and lifetime_2_16

pub fn main() !void {
var gpa = std.heap.ArenaAllocator.init(std.heap.page_allocator);
Expand All @@ -14,233 +14,78 @@ pub fn main() !void {
std.debug.print("==========================\n", .{});
std.debug.print("Testing SIMD-optimized hash-based signatures\n\n", .{});

// Test different lifetimes
const lifetimes = [_]struct { name: []const u8, lifetime: hash_zig.params.KeyLifetime, expected_time_sec: f64, description: []const u8 }{
.{ .name = "2^10", .lifetime = .lifetime_2_10, .expected_time_sec = 15.0, .description = "1,024 signatures - SIMD optimized target" },
.{ .name = "2^16", .lifetime = .lifetime_2_16, .expected_time_sec = 60.0, .description = "65,536 signatures - SIMD optimized target" },
};

for (lifetimes) |config| {
std.debug.print("\nTesting lifetime: {s} ({s})\n", .{ config.name, config.description });
std.debug.print("Expected time: ~{d:.1}s (SIMD optimized)\n", .{config.expected_time_sec});
std.debug.print("-------------------\n", .{});

// Initialize SIMD signature scheme
var sig_scheme = try simd_signature.SimdHashSignature.init(allocator, hash_zig.params.Parameters.init(config.lifetime));
defer sig_scheme.deinit();

const seed: [32]u8 = .{42} ** 32;

// Key generation benchmark
std.debug.print("Starting SIMD key generation...\n", .{});
const keygen_start = std.time.nanoTimestamp();
var keypair = try sig_scheme.generateKeyPair(allocator, &seed);
const keygen_end = std.time.nanoTimestamp();
defer keypair.deinit();

const keygen_duration_ns = keygen_end - keygen_start;
const keygen_duration_sec = @as(f64, @floatFromInt(keygen_duration_ns)) / 1_000_000_000.0;

// Calculate performance metrics
const tree_height: u32 = config.lifetime.treeHeight();
const num_signatures = @as(usize, 1) << @intCast(tree_height);
const signatures_per_sec = @as(f64, @floatFromInt(num_signatures)) / keygen_duration_sec;
const time_per_signature_ms = (keygen_duration_sec * 1000.0) / @as(f64, @floatFromInt(num_signatures));

// Performance assessment
const performance_ratio = keygen_duration_sec / config.expected_time_sec;
const performance_status = if (performance_ratio < 0.8) "🚀 Excellent" else if (performance_ratio < 1.2) "✅ Good" else if (performance_ratio < 2.0) "⚠️ Slow" else "🐌 Very Slow";

// Sign benchmark
const message = "SIMD Performance test message";
const sign_start = std.time.nanoTimestamp();
var signature = try sig_scheme.sign(allocator, message, keypair, 0);
const sign_end = std.time.nanoTimestamp();
defer signature.deinit(allocator);

const sign_duration_ns = sign_end - sign_start;
const sign_duration_sec = @as(f64, @floatFromInt(sign_duration_ns)) / 1_000_000_000.0;

// Verify benchmark
const verify_start = std.time.nanoTimestamp();
const is_valid = try sig_scheme.verify(allocator, message, signature, keypair.public_key);
const verify_end = std.time.nanoTimestamp();

const verify_duration_ns = verify_end - verify_start;
const verify_duration_sec = @as(f64, @floatFromInt(verify_duration_ns)) / 1_000_000_000.0;

// Display detailed results
std.debug.print("\n📊 SIMD KEY GENERATION RESULTS:\n", .{});
std.debug.print(" Duration: {d:.3}s {s}\n", .{ keygen_duration_sec, performance_status });
std.debug.print(" Signatures: {d} (2^{d})\n", .{ num_signatures, tree_height });
std.debug.print(" Throughput: {d:.1} signatures/sec\n", .{signatures_per_sec});
std.debug.print(" Time per signature: {d:.3}ms\n", .{time_per_signature_ms});
std.debug.print(" Expected: ~{d:.1}s (ratio: {d:.2}x)\n", .{ config.expected_time_sec, performance_ratio });

// Display sign/verify results
std.debug.print("\n🔐 SIMD SIGN/VERIFY RESULTS:\n", .{});
std.debug.print(" Sign: {d:.3}ms\n", .{sign_duration_sec * 1000});
std.debug.print(" Verify: {d:.3}ms\n", .{verify_duration_sec * 1000});
std.debug.print(" Valid: {}\n", .{is_valid});

// Batch operations benchmark
std.debug.print("\n🔄 BATCH OPERATIONS BENCHMARK:\n", .{});
const batch_messages = [_][]const u8{ "batch1", "batch2", "batch3", "batch4" };
const batch_indices = [_]u32{ 0, 1, 2, 3 };

const batch_sign_start = std.time.nanoTimestamp();
const batch_sigs = try sig_scheme.batchSign(allocator, &batch_messages, keypair, &batch_indices);
const batch_sign_end = std.time.nanoTimestamp();
defer {
for (batch_sigs) |*sig| sig.deinit(allocator);
allocator.free(batch_sigs);
// Seed: read from SEED_HEX env var if provided, else default to 0x2a repeated
const seed_env = std.process.getEnvVarOwned(std.heap.page_allocator, "SEED_HEX") catch null;
defer if (seed_env) |s| std.heap.page_allocator.free(s);

var seed: [32]u8 = .{42} ** 32;
if (seed_env) |hex| {
// Parse up to 64 hex chars into 32 bytes
const n = @min(hex.len, 64);
var i: usize = 0;
while (i < n) : (i += 2) {
const high_nibble = std.fmt.charToDigit(hex[i], 16) catch 0;
const low_nibble = if (i + 1 < n) std.fmt.charToDigit(hex[i + 1], 16) catch 0 else 0;
seed[i / 2] = @as(u8, @intCast((high_nibble << 4) | low_nibble));
}

const batch_verify_start = std.time.nanoTimestamp();
const batch_results = try sig_scheme.batchVerify(allocator, &batch_messages, batch_sigs, keypair.public_key);
const batch_verify_end = std.time.nanoTimestamp();
defer allocator.free(batch_results);

const batch_sign_duration = @as(f64, @floatFromInt(batch_sign_end - batch_sign_start)) / 1_000_000_000.0;
const batch_verify_duration = @as(f64, @floatFromInt(batch_verify_end - batch_verify_start)) / 1_000_000_000.0;

std.debug.print(" Batch Sign (4 ops): {d:.3}ms\n", .{batch_sign_duration * 1000});
std.debug.print(" Batch Verify (4 ops): {d:.3}ms\n", .{batch_verify_duration * 1000});
std.debug.print(" All valid: {}\n", .{std.mem.allEqual(bool, batch_results, true)});

// Output results in a format that can be captured by CI
std.debug.print("\n📈 CI BENCHMARK DATA:\n", .{});
std.debug.print("BENCHMARK_RESULT: {s}:keygen:{d:.6}\n", .{ config.name, keygen_duration_sec });
std.debug.print("BENCHMARK_RESULT: {s}:sign:{d:.6}\n", .{ config.name, sign_duration_sec });
std.debug.print("BENCHMARK_RESULT: {s}:verify:{d:.6}\n", .{ config.name, verify_duration_sec });
std.debug.print("BENCHMARK_RESULT: {s}:throughput:{d:.1}\n", .{ config.name, signatures_per_sec });
std.debug.print("BENCHMARK_RESULT: {s}:performance_ratio:{d:.2}\n", .{ config.name, performance_ratio });
std.debug.print("BENCHMARK_RESULT: {s}:batch_sign:{d:.6}\n", .{ config.name, batch_sign_duration });
std.debug.print("BENCHMARK_RESULT: {s}:batch_verify:{d:.6}\n", .{ config.name, batch_verify_duration });
}

// SIMD-specific performance tests
std.debug.print("\n🧪 SIMD SPECIFIC TESTS:\n", .{});
std.debug.print("========================\n", .{});

// Test SIMD field operations
testSimdFieldOperations();

// Test SIMD Winternitz operations
testSimdWinternitzOperations();

// Test SIMD Poseidon2 operations
testSimdPoseidon2Operations();
std.debug.print("Using seed (hex): ", .{});
for (seed) |b| std.debug.print("{x:0>2}", .{b});
std.debug.print("\n\n", .{});

// Test lifetime_2_10
std.debug.print("Testing lifetime: 2^10 (1,024 signatures)\n", .{});
std.debug.print("==========================================\n", .{});

var sig_scheme_10 = try simd_signature.SimdHashSignature.init(allocator, hash_zig.params.Parameters.init(.lifetime_2_10));
defer sig_scheme_10.deinit();

const keygen_start_10 = std.time.nanoTimestamp();
var keypair_10 = try sig_scheme_10.generateKeyPair(allocator, &seed);
const keygen_end_10 = std.time.nanoTimestamp();
defer keypair_10.deinit(allocator);

const keygen_duration_10 = @as(f64, @floatFromInt(keygen_end_10 - keygen_start_10)) / 1_000_000_000.0;

// Print keypair information for 2^10
std.debug.print("Keypair 2^10:\n", .{});
const secret_key_size_10 = keypair_10.secret_key.chains.len * @sizeOf(@TypeOf(keypair_10.secret_key.chains[0]));
const public_key_size_10 = keypair_10.public_key.chains.len * @sizeOf(@TypeOf(keypair_10.public_key.chains[0]));
std.debug.print(" Secret key length: {d} bytes\n", .{secret_key_size_10});
std.debug.print(" Public key length: {d} bytes\n", .{public_key_size_10});
std.debug.print(" Key generation time: {d:.3}s\n", .{keygen_duration_10});

// Test lifetime_2_16
std.debug.print("\nTesting lifetime: 2^16 (65,536 signatures)\n", .{});
std.debug.print("==========================================\n", .{});

var sig_scheme_16 = try simd_signature.SimdHashSignature.init(allocator, hash_zig.params.Parameters.init(.lifetime_2_16));
defer sig_scheme_16.deinit();

const keygen_start_16 = std.time.nanoTimestamp();
var keypair_16 = try sig_scheme_16.generateKeyPair(allocator, &seed);
const keygen_end_16 = std.time.nanoTimestamp();
defer keypair_16.deinit(allocator);

const keygen_duration_16 = @as(f64, @floatFromInt(keygen_end_16 - keygen_start_16)) / 1_000_000_000.0;

// Print keypair information for 2^16
std.debug.print("Keypair 2^16:\n", .{});
const secret_key_size_16 = keypair_16.secret_key.chains.len * @sizeOf(@TypeOf(keypair_16.secret_key.chains[0]));
const public_key_size_16 = keypair_16.public_key.chains.len * @sizeOf(@TypeOf(keypair_16.public_key.chains[0]));
std.debug.print(" Secret key length: {d} bytes\n", .{secret_key_size_16});
std.debug.print(" Public key length: {d} bytes\n", .{public_key_size_16});
std.debug.print(" Key generation time: {d:.3}s\n", .{keygen_duration_16});

// Summary
std.debug.print("\n📊 SUMMARY:\n", .{});
std.debug.print("2^10 key generation: {d:.3}s\n", .{keygen_duration_10});
std.debug.print("2^16 key generation: {d:.3}s\n", .{keygen_duration_16});
std.debug.print("Performance ratio: {d:.2}x\n", .{keygen_duration_16 / keygen_duration_10});

// Output for CI
std.debug.print("\nBENCHMARK_RESULT: 2^10:keygen:{d:.6}\n", .{keygen_duration_10});
std.debug.print("BENCHMARK_RESULT: 2^16:keygen:{d:.6}\n", .{keygen_duration_16});

std.debug.print("\n✅ SIMD Benchmark completed successfully!\n", .{});
}

/// Micro-benchmark comparing scalar `mul` with four-lane `mulVec4` from the
/// `simd_montgomery` module; prints the measured speedup to stderr.
///
/// The scalar loop performs `iterations` single multiplications and the
/// vector loop performs `iterations / 4` four-lane multiplications, so both
/// loops touch the same number of lanes.
fn testSimdFieldOperations() void {
    std.debug.print("\n🔢 SIMD Field Operations Test:\n", .{});

    const simd_field = @import("simd_montgomery");
    const field = simd_field.koala_bear_simd;
    const iterations = 100000;

    // Scalar baseline.
    const start_scalar = std.time.nanoTimestamp();
    for (0..iterations) |_| {
        const a = field.MontFieldElem{ .value = 12345 };
        const b = field.MontFieldElem{ .value = 67890 };
        var result: field.MontFieldElem = undefined;
        field.mul(&result, a, b);
        // The result is never read otherwise; without this barrier an
        // optimizing build is free to drop the multiplication being timed.
        std.mem.doNotOptimizeAway(&result);
    }
    // nanoTimestamp is a wall clock: the delta can be 0 (coarse timer) or
    // negative (clock adjustment). Clamp to 1 ns so the ratio below stays
    // finite and positive.
    const scalar_time = @max(std.time.nanoTimestamp() - start_scalar, 1);

    // Four-lane vectorized variant.
    const start_vector = std.time.nanoTimestamp();
    for (0..iterations / 4) |_| {
        const a_vec = field.Vec4{ 12345, 12346, 12347, 12348 };
        const b_vec = field.Vec4{ 67890, 67891, 67892, 67893 };
        var result_vec: field.Vec4 = undefined;
        field.mulVec4(&result_vec, a_vec, b_vec);
        std.mem.doNotOptimizeAway(&result_vec);
    }
    const vector_time = @max(std.time.nanoTimestamp() - start_vector, 1);

    const speedup = @as(f64, @floatFromInt(scalar_time)) / @as(f64, @floatFromInt(vector_time));
    std.debug.print("  Field operations speedup: {d:.2}x\n", .{speedup});

    // A 2x speedup is the threshold this benchmark treats as healthy.
    if (speedup >= 2.0) {
        std.debug.print("  ✅ SIMD field operations working correctly\n", .{});
    } else {
        std.debug.print("  ⚠️  SIMD field operations may need optimization\n", .{});
    }
}

/// Micro-benchmark comparing `generateChain` (one state per call) with
/// `generateChainsBatch` (four states per call) from the `simd_winternitz`
/// module; prints the measured speedup to stderr.
///
/// The scalar loop makes `iterations` calls and the batch loop makes
/// `iterations / 4` calls with four states each, so both loops process the
/// same number of chain states.
fn testSimdWinternitzOperations() void {
    std.debug.print("\n🔗 SIMD Winternitz Operations Test:\n", .{});

    const simd_winternitz = @import("simd_winternitz");
    const ots = simd_winternitz.simd_winternitz_ots;
    const iterations = 1000;

    // Scalar baseline: one chain per call.
    const start_scalar = std.time.nanoTimestamp();
    for (0..iterations) |_| {
        var state = ots.ChainState{ 1, 2, 3, 4, 5, 6, 7, 8 };
        ots.generateChain(&state, 8);
        // The final state is never read otherwise; without this barrier an
        // optimizing build is free to drop the chain computation being timed.
        std.mem.doNotOptimizeAway(&state);
    }
    // nanoTimestamp is a wall clock: the delta can be 0 (coarse timer) or
    // negative (clock adjustment). Clamp to 1 ns so the ratio below stays
    // finite and positive.
    const scalar_time = @max(std.time.nanoTimestamp() - start_scalar, 1);

    // Batched variant: four chains per call.
    const start_vector = std.time.nanoTimestamp();
    for (0..iterations / 4) |_| {
        var states: [4]ots.ChainState = .{
            ots.ChainState{ 1, 2, 3, 4, 5, 6, 7, 8 },
            ots.ChainState{ 9, 10, 11, 12, 13, 14, 15, 16 },
            ots.ChainState{ 17, 18, 19, 20, 21, 22, 23, 24 },
            ots.ChainState{ 25, 26, 27, 28, 29, 30, 31, 32 },
        };
        ots.generateChainsBatch(&states, 8);
        std.mem.doNotOptimizeAway(&states);
    }
    const vector_time = @max(std.time.nanoTimestamp() - start_vector, 1);

    const speedup = @as(f64, @floatFromInt(scalar_time)) / @as(f64, @floatFromInt(vector_time));
    std.debug.print("  Winternitz operations speedup: {d:.2}x\n", .{speedup});

    // A 2x speedup is the threshold this benchmark treats as healthy.
    if (speedup >= 2.0) {
        std.debug.print("  ✅ SIMD Winternitz operations working correctly\n", .{});
    } else {
        std.debug.print("  ⚠️  SIMD Winternitz operations may need optimization\n", .{});
    }
}

/// Micro-benchmark comparing `permutation` with `permutationVec4` from the
/// `simd_poseidon2` module; prints the measured speedup to stderr.
///
/// NOTE(review): the vector loop makes a quarter as many calls as the scalar
/// loop, and each call receives 16 values (4 x Vec4), the same count as one
/// scalar state. Confirm that one `permutationVec4` call really does four
/// permutations' worth of work; otherwise the reported speedup is overstated.
fn testSimdPoseidon2Operations() void {
    std.debug.print("\n🌊 SIMD Poseidon2 Operations Test:\n", .{});

    const simd_poseidon = @import("simd_poseidon2");
    const poseidon = simd_poseidon.simd_poseidon2;
    const iterations = 1000;

    // Scalar baseline: one 16-element state per call.
    const start_scalar = std.time.nanoTimestamp();
    for (0..iterations) |_| {
        var state = poseidon.state{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 };
        poseidon.permutation(&state);
        // The permuted state is never read otherwise; without this barrier an
        // optimizing build is free to drop the permutation being timed.
        std.mem.doNotOptimizeAway(&state);
    }
    // nanoTimestamp is a wall clock: the delta can be 0 (coarse timer) or
    // negative (clock adjustment). Clamp to 1 ns so the ratio below stays
    // finite and positive.
    const scalar_time = @max(std.time.nanoTimestamp() - start_scalar, 1);

    // Vectorized variant operating on four Vec4 values per call.
    const start_vector = std.time.nanoTimestamp();
    for (0..iterations / 4) |_| {
        var states: [4]poseidon.Vec4 = .{
            poseidon.Vec4{ 1, 2, 3, 4 },
            poseidon.Vec4{ 5, 6, 7, 8 },
            poseidon.Vec4{ 9, 10, 11, 12 },
            poseidon.Vec4{ 13, 14, 15, 16 },
        };
        poseidon.permutationVec4(&states);
        std.mem.doNotOptimizeAway(&states);
    }
    const vector_time = @max(std.time.nanoTimestamp() - start_vector, 1);

    const speedup = @as(f64, @floatFromInt(scalar_time)) / @as(f64, @floatFromInt(vector_time));
    std.debug.print("  Poseidon2 operations speedup: {d:.2}x\n", .{speedup});

    // A 2x speedup is the threshold this benchmark treats as healthy.
    if (speedup >= 2.0) {
        std.debug.print("  ✅ SIMD Poseidon2 operations working correctly\n", .{});
    } else {
        std.debug.print("  ⚠️  SIMD Poseidon2 operations may need optimization\n", .{});
    }
}
29 changes: 19 additions & 10 deletions src/simd_signature.zig
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ pub const SimdHashSignature = struct {
secret_key: Winternitz.PrivateKey,
public_key: Winternitz.PublicKey,

pub fn deinit(self: *KeyPair) void {
self.secret_key.deinit();
self.public_key.deinit();
pub fn deinit(self: *KeyPair, allocator: std.mem.Allocator) void {
self.secret_key.deinit(allocator);
self.public_key.deinit(allocator);
}
};

Expand All @@ -59,14 +59,23 @@ pub const SimdHashSignature = struct {
};

// Generate key pair with SIMD optimizations
pub fn generateKeyPair(self: *SimdHashSignature, _: std.mem.Allocator, seed: []const u8) !KeyPair {
// Generate Winternitz private key
const secret_key = try Winternitz.generatePrivateKey(seed);
pub fn generateKeyPair(self: *SimdHashSignature, allocator: std.mem.Allocator, seed: []const u8) !KeyPair {
if (seed.len != 32) return error.InvalidSeedLength;

// Generate Winternitz public key with SIMD
const public_key = try Winternitz.generatePublicKey(secret_key);
// Scale the number of chains based on the lifetime
// For demonstration purposes, we'll scale the chains based on tree height
// In a real implementation, this would be more sophisticated
const base_chains = 64;
const scale_factor = @as(u32, 1) << @intCast(@max(0, @as(i32, @intCast(self.tree_height)) - 10));
const scaled_chains = base_chains * scale_factor;

// Create modified parameters with scaled chain count
var scaled_params = self.params;
scaled_params.num_chains = scaled_chains;

const secret_key = try Winternitz.generatePrivateKey(allocator, scaled_params, seed);
const public_key = try Winternitz.generatePublicKey(allocator, secret_key);

_ = self; // Suppress unused parameter warning
return KeyPair{
.secret_key = secret_key,
.public_key = public_key,
Expand Down Expand Up @@ -152,7 +161,7 @@ pub const SimdHashSignature = struct {
// Sign message with SIMD optimizations
pub fn sign(self: *SimdHashSignature, allocator: std.mem.Allocator, message: []const u8, keypair: KeyPair, _: u32) !Signature {
// Generate Winternitz signature
const winternitz_sig = try Winternitz.sign(message, keypair.secret_key);
const winternitz_sig = try Winternitz.sign(allocator, message, keypair.secret_key);

// Generate Merkle path (simplified for this example)
const merkle_path = try allocator.alloc(u32, self.tree_height * 8);
Expand Down
Loading
Loading