@@ -18,12 +18,14 @@ pub const koala_bear_simd = struct {
// Field constants for KoalaBear Montgomery arithmetic.
const modulus: u32 = 0x7f000001; // 2^31 - 2^24 + 1
const mont_r: u64 = 1 << 32; // Montgomery radix R = 2^32
// R^2 mod modulus; the product is formed in u128 because R * R = 2^64
// does not fit in a u64.
const r_square_mod_modulus: u64 = @intCast((@as(u128, mont_r) * @as(u128, mont_r)) % @as(u128, modulus));
// -modulus^-1 mod 2^32.
// montReduceSIMD forms T + q * modulus and shifts right by 32; that is only
// exact when modulus * modulus_prime == -1 (mod 2^32), so the low 32 bits of
// the sum cancel. 0x81000001 is the positive inverse (+modulus^-1 mod 2^32);
// the value required here is its negation, 2^32 - 0x81000001 = 0x7effffff.
// NOTE(review): any code that pairs this constant with a subtractive
// reduction (T - q * modulus) would need the positive inverse — confirm.
const modulus_prime: u32 = 0x7effffff;

comptime {
    // Fail the build if the constant ever stops cancelling the low word.
    if (modulus *% modulus_prime != 0xFFFFFFFF) {
        @compileError("modulus_prime must equal -modulus^-1 mod 2^32");
    }
}
2323
2424 // SIMD-optimized Montgomery reduction
2525 pub fn montReduceSIMD (mont_value : u64 ) FieldElem {
26- const tmp = mont_value + (((mont_value & 0xFFFFFFFF ) * modulus_prime ) & 0xFFFFFFFF ) * modulus ;
26+ const low = mont_value & 0xFFFFFFFF ;
27+ const q = (low *% modulus_prime ) & 0xFFFFFFFF ;
28+ const tmp = mont_value +% (@as (u64 , q ) *% @as (u64 , modulus ));
2729 const t = tmp >> 32 ;
2830 if (t >= modulus ) {
2931 return @intCast (t - modulus );
@@ -87,37 +89,37 @@ pub const koala_bear_simd = struct {
8789
// Vectorized addition with modular reduction (4 lanes).
// Writes the lane-wise wrapping sum of `a` and `b` to `out`, subtracting
// `modulus` once from every lane whose sum is >= `modulus`.
// NOTE(review): one conditional subtraction is only a full reduction when
// both inputs are already below `modulus` — confirm with callers.
pub fn addVec4(out: *Vec4, a: Vec4, b: Vec4) void {
    const mask: @Vector(4, u32) = @splat(modulus);
    const sum = a +% b;
    // Pick reduced or unreduced lanes with a single vector select instead of
    // a scalar per-lane loop. `-%` is required because the subtraction is
    // evaluated for every lane, including those where sum < modulus.
    out.* = @select(u32, sum >= mask, sum -% mask, sum);
}
99101
// Vectorized addition for 8 elements.
// Writes the lane-wise wrapping sum of `a` and `b` to `out`, subtracting
// `modulus` once from every lane whose sum is >= `modulus`.
// NOTE(review): one conditional subtraction is only a full reduction when
// both inputs are already below `modulus` — confirm with callers.
pub fn addVec8(out: *Vec8, a: Vec8, b: Vec8) void {
    const mask: @Vector(8, u32) = @splat(modulus);
    const sum = a +% b;
    // Pick reduced or unreduced lanes with a single vector select instead of
    // a scalar per-lane loop. `-%` is required because the subtraction is
    // evaluated for every lane, including those where sum < modulus.
    out.* = @select(u32, sum >= mask, sum -% mask, sum);
}
111113
// Vectorized addition for 16 elements.
// Writes the lane-wise wrapping sum of `a` and `b` to `out`, subtracting
// `modulus` once from every lane whose sum is >= `modulus`.
// NOTE(review): one conditional subtraction is only a full reduction when
// both inputs are already below `modulus` — confirm with callers.
pub fn addVec16(out: *Vec16, a: Vec16, b: Vec16) void {
    // @splat replaces the hand-written 16-element literal.
    const mask: @Vector(16, u32) = @splat(modulus);
    const sum = a +% b;
    // Pick reduced or unreduced lanes with a single vector select instead of
    // a scalar per-lane loop. `-%` is required because the subtraction is
    // evaluated for every lane, including those where sum < modulus.
    out.* = @select(u32, sum >= mask, sum -% mask, sum);
}
123125
@@ -237,11 +239,12 @@ pub const koala_bear_simd = struct {
237239 }
238240
// Adds two field elements: stores the wrapping sum of `a.value` and
// `b.value` in `out`, minus `modulus` when that sum is >= `modulus`.
pub fn add(out: *MontFieldElem, a: MontFieldElem, b: MontFieldElem) void {
    const sum = a.value +% b.value;
    // At most one subtraction is applied.
    const reduced = if (sum >= modulus) sum -% modulus else sum;
    out.* = .{ .value = reduced };
}
246249
247250 pub fn square (out : * MontFieldElem , a : MontFieldElem ) void {
0 commit comments