Skip to content

Commit dec1a65

Browse files
authored
fix: overflow errors in simd modules (#30)
2 parents b78c949 + 6724708 commit dec1a65

File tree

2 files changed

+18
-13
lines changed

2 files changed

+18
-13
lines changed

src/poseidon2/fields/koalabear/simd_montgomery.zig

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@ pub const koala_bear_simd = struct {
1818
// Field constants (diff view: lines after an `NN-` marker are the pre-commit versions, lines after an `NN+` marker the post-commit versions)
1919
const modulus: u32 = 0x7f000001; // 2^31 - 2^24 + 1
2020
const mont_r: u64 = 1 << 32; // 2^32
21-
const r_square_mod_modulus: u64 = @intCast((@as(u128, mont_r) * @as(u128, mont_r)) % modulus);
22-
const modulus_prime: u32 = 0x7f000001; // -modulus^-1 mod 2^32
21+
const r_square_mod_modulus: u64 = @intCast((@as(u128, mont_r) * @as(u128, mont_r)) % @as(u128, modulus)); // (2^32)^2 mod modulus; the product 2^64 does not fit in u64, hence the u128 arithmetic
22+
const modulus_prime: u32 = 0x81000001; // modulus^-1 mod 2^32 (modulus *% 0x81000001 == 1). NOTE(review): this is the positive inverse, not -modulus^-1 (that would be 0x7effffff); a reduction that ADDS q * modulus, as montReduceSIMD does, needs the negated value to clear the low 32 bits - confirm which convention is intended
2323

2424
// SIMD-optimized Montgomery reduction
2525
pub fn montReduceSIMD(mont_value: u64) FieldElem {
26-
const tmp = mont_value + (((mont_value & 0xFFFFFFFF) * modulus_prime) & 0xFFFFFFFF) * modulus;
26+
const low = mont_value & 0xFFFFFFFF;
27+
const q = (low *% modulus_prime) & 0xFFFFFFFF;
28+
const tmp = mont_value +% (@as(u64, q) *% @as(u64, modulus));
2729
const t = tmp >> 32;
2830
if (t >= modulus) {
2931
return @intCast(t - modulus);
@@ -87,37 +89,37 @@ pub const koala_bear_simd = struct {
8789

8890
// Vectorized addition with modular reduction: writes the lane-wise sum of `a` and `b` to `out`, subtracting `modulus` once from every lane whose sum is >= modulus
8991
pub fn addVec4(out: *Vec4, a: Vec4, b: Vec4) void {
90-
const sum = a + b;
92+
const sum = a +% b; // wrapping add: an overflowing lane wraps instead of tripping the overflow check. NOTE(review): lanes that are both < modulus (< 2^31) cannot wrap a u32 - confirm callers pass reduced values
9193
const mask = @Vector(4, u32){ modulus, modulus, modulus, modulus }; // modulus broadcast to all 4 lanes
9294
const needs_reduction = sum >= mask; // per-lane flag: true where sum >= modulus
9395

9496
// Apply reduction element-wise
9597
for (0..4) |i| {
96-
out[i] = if (needs_reduction[i]) sum[i] - modulus else sum[i];
98+
out[i] = if (needs_reduction[i]) sum[i] -% modulus else sum[i]; // subtraction only runs when sum[i] >= modulus, so it cannot go below zero
9799
}
98100
}
99101

100102
// Vectorized addition for 8 elements: writes the lane-wise sum of `a` and `b` to `out`, subtracting `modulus` once from every lane whose sum is >= modulus
101103
pub fn addVec8(out: *Vec8, a: Vec8, b: Vec8) void {
102-
const sum = a + b;
104+
const sum = a +% b; // wrapping add: an overflowing lane wraps instead of tripping the overflow check. NOTE(review): lanes that are both < modulus (< 2^31) cannot wrap a u32 - confirm callers pass reduced values
103105
const mask = @Vector(8, u32){ modulus, modulus, modulus, modulus, modulus, modulus, modulus, modulus }; // modulus broadcast to all 8 lanes
104106
const needs_reduction = sum >= mask; // per-lane flag: true where sum >= modulus
105107

106108
// Apply reduction element-wise
107109
for (0..8) |i| {
108-
out[i] = if (needs_reduction[i]) sum[i] - modulus else sum[i];
110+
out[i] = if (needs_reduction[i]) sum[i] -% modulus else sum[i]; // subtraction only runs when sum[i] >= modulus, so it cannot go below zero
109111
}
110112
}
111113

112114
// Vectorized addition for 16 elements: writes the lane-wise sum of `a` and `b` to `out`, subtracting `modulus` once from every lane whose sum is >= modulus
113115
pub fn addVec16(out: *Vec16, a: Vec16, b: Vec16) void {
114-
const sum = a + b;
116+
const sum = a +% b; // wrapping add: an overflowing lane wraps instead of tripping the overflow check. NOTE(review): lanes that are both < modulus (< 2^31) cannot wrap a u32 - confirm callers pass reduced values
115117
const mask = @Vector(16, u32){ modulus, modulus, modulus, modulus, modulus, modulus, modulus, modulus, modulus, modulus, modulus, modulus, modulus, modulus, modulus, modulus }; // modulus broadcast to all 16 lanes
116118
const needs_reduction = sum >= mask; // per-lane flag: true where sum >= modulus
117119

118120
// Apply reduction element-wise
119121
for (0..16) |i| {
120-
out[i] = if (needs_reduction[i]) sum[i] - modulus else sum[i];
122+
out[i] = if (needs_reduction[i]) sum[i] -% modulus else sum[i]; // subtraction only runs when sum[i] >= modulus, so it cannot go below zero
121123
}
122124
}
123125

@@ -237,11 +239,12 @@ pub const koala_bear_simd = struct {
237239
}
238240

239241
pub fn add(out: *MontFieldElem, a: MontFieldElem, b: MontFieldElem) void { // stores a.value + b.value in out, minus modulus when that sum is >= modulus
240-
var tmp = a.value + b.value;
242+
const tmp = a.value +% b.value; // wrapping add, so an overflow wraps instead of tripping the overflow check. NOTE(review): the type of `.value` is not visible here; presumably both operands are < modulus so the sum fits - confirm
241243
if (tmp >= modulus) {
242-
tmp -= modulus;
244+
out.* = .{ .value = tmp -% modulus }; // tmp >= modulus on this branch, so the subtraction cannot go below zero
245+
} else {
246+
out.* = .{ .value = tmp }; // sum is already below modulus; store it unchanged
243247
}
244-
out.* = .{ .value = tmp };
245248
}
246249

247250
pub fn square(out: *MontFieldElem, a: MontFieldElem) void {

src/tweakable_hash.zig

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,9 @@ pub const TweakableHash = struct {
5151
defer allocator.free(tweaked_data);
5252

5353
std.mem.writeInt(u64, tweaked_data[0..8], tweak, .big);
54-
@memcpy(tweaked_data[8..], data);
54+
for (data, 0..) |byte, i| {
55+
tweaked_data[8 + i] = byte;
56+
}
5557

5658
return switch (self.hash_impl) {
5759
.poseidon2 => |*p| try p.hashBytes(allocator, tweaked_data),

0 commit comments

Comments
 (0)