11use  super :: Uint ; 
2- use  crate :: { ConstChoice ,  ConstCtOption } ; 
2+ use  crate :: modular:: BernsteinYangInverter ; 
3+ use  crate :: { ConstChoice ,  ConstCtOption ,  Odd ,  PrecomputeInverter } ; 
34
45impl < const  LIMBS :  usize >  Uint < LIMBS >  { 
56    /// Computes 1/`self` mod `2^k`. 
@@ -79,96 +80,33 @@ impl<const LIMBS: usize> Uint<LIMBS> {
7980        ConstCtOption :: new ( x,  is_some) 
8081    } 
8182
82-     /// Computes the multiplicative inverse of `self` mod `modulus`, where `modulus` is odd. 
83- /// In other words `self^-1 mod modulus`. 
84- /// `bits` and `modulus_bits` are the bounds on the bit size 
85- /// of `self` and `modulus`, respectively 
86- /// (the inversion speed will be proportional to `bits + modulus_bits`). 
87- /// The second element of the tuple is the truthy value 
88- /// if `modulus` is odd and an inverse exists, otherwise it is a falsy value. 
89- /// 
90- /// **Note:** variable time in `bits` and `modulus_bits`. 
91- /// 
92- /// The algorithm is the same as in GMP 6.2.1's `mpn_sec_invert`. 
93- pub  const  fn  inv_odd_mod_bounded ( 
94-         & self , 
95-         modulus :  & Self , 
96-         bits :  u32 , 
97-         modulus_bits :  u32 , 
98-     )  -> ConstCtOption < Self >  { 
99-         let  mut  a = * self ; 
100- 
101-         let  mut  u = Uint :: ONE ; 
102-         let  mut  v = Uint :: ZERO ; 
103- 
104-         let  mut  b = * modulus; 
105- 
106-         // `bit_size` can be anything >= `self.bits()` + `modulus.bits()`, setting to the minimum. 
107-         let  bit_size = bits + modulus_bits; 
108- 
109-         let  m1hp = modulus. shr1 ( ) . wrapping_add ( & Uint :: ONE ) ; 
110- 
111-         let  modulus_is_odd = modulus. is_odd ( ) ; 
112- 
113-         let  mut  i = 0 ; 
114-         while  i < bit_size { 
115-             // A sanity check that `b` stays odd. Only matters if `modulus` was odd to begin with, 
116-             // otherwise this whole thing produces nonsense anyway. 
117-             debug_assert ! ( modulus_is_odd. not( ) . or( b. is_odd( ) ) . is_true_vartime( ) ) ; 
118- 
119-             let  self_odd = a. is_odd ( ) ; 
120- 
121-             // Set `self -= b` if `self` is odd. 
122-             let  ( new_a,  swap)  = a. conditional_wrapping_sub ( & b,  self_odd) ; 
123-             // Set `b += self` if `swap` is true. 
124-             b = Uint :: select ( & b,  & b. wrapping_add ( & new_a) ,  swap) ; 
125-             // Negate `self` if `swap` is true. 
126-             a = new_a. conditional_wrapping_neg ( swap) ; 
127- 
128-             let  ( new_u,  new_v)  = Uint :: swap ( & u,  & v,  swap) ; 
129-             let  ( new_u,  cy)  = new_u. conditional_wrapping_sub ( & new_v,  self_odd) ; 
130-             let  ( new_u,  cyy)  = new_u. conditional_wrapping_add ( modulus,  cy) ; 
131-             debug_assert ! ( cy. is_true_vartime( )  == cyy. is_true_vartime( ) ) ; 
132- 
133-             let  ( new_a,  carry)  = a. shr1_with_carry ( ) ; 
134-             debug_assert ! ( modulus_is_odd. not( ) . or( carry. not( ) ) . is_true_vartime( ) ) ; 
135-             let  ( new_u,  cy)  = new_u. shr1_with_carry ( ) ; 
136-             let  ( new_u,  cy)  = new_u. conditional_wrapping_add ( & m1hp,  cy) ; 
137-             debug_assert ! ( modulus_is_odd. not( ) . or( cy. not( ) ) . is_true_vartime( ) ) ; 
138- 
139-             a = new_a; 
140-             u = new_u; 
141-             v = new_v; 
142- 
143-             i += 1 ; 
144-         } 
145- 
146-         debug_assert ! ( modulus_is_odd
147-             . not( ) 
148-             . or( a. is_nonzero( ) . not( ) ) 
149-             . is_true_vartime( ) ) ; 
150- 
151-         ConstCtOption :: new ( v,  Uint :: eq ( & b,  & Uint :: ONE ) . and ( modulus_is_odd) ) 
152-     } 
153- 
15483    /// Computes the multiplicative inverse of `self` mod `modulus`, where `modulus` is odd. 
15584/// Returns `(inverse, ConstChoice::TRUE)` if an inverse exists, 
15685/// otherwise `(undefined, ConstChoice::FALSE)`. 
157- pub  const  fn  inv_odd_mod ( & self ,  modulus :  & Self )  -> ConstCtOption < Self >  { 
158-         self . inv_odd_mod_bounded ( modulus,  Uint :: < LIMBS > :: BITS ,  Uint :: < LIMBS > :: BITS ) 
86+ pub  const  fn  inv_odd_mod < const  UNSAT_LIMBS :  usize > ( 
87+         & self , 
88+         modulus :  & Odd < Self > , 
89+     )  -> ConstCtOption < Self > 
90+     where 
91+         Odd < Self > :  PrecomputeInverter < Inverter  = BernsteinYangInverter < LIMBS ,  UNSAT_LIMBS > > , 
92+     { 
93+         BernsteinYangInverter :: < LIMBS ,  UNSAT_LIMBS > :: new ( modulus,  & Uint :: ONE ) . inv ( self ) 
15994    } 
16095
16196    /// Computes the multiplicative inverse of `self` mod `modulus`. 
16297/// Returns `(inverse, ConstChoice::TRUE)` if an inverse exists, 
16398/// otherwise `(undefined, ConstChoice::FALSE)`. 
164- pub  const  fn  inv_mod ( & self ,  modulus :  & Self )  -> ConstCtOption < Self >  { 
99+ pub  const  fn  inv_mod < const  UNSAT_LIMBS :  usize > ( & self ,  modulus :  & Self )  -> ConstCtOption < Self > 
100+     where 
101+         Odd < Self > :  PrecomputeInverter < Inverter  = BernsteinYangInverter < LIMBS ,  UNSAT_LIMBS > > , 
102+     { 
165103        // Decompose `modulus = s * 2^k` where `s` is odd 
166104        let  k = modulus. trailing_zeros ( ) ; 
167105        let  s = modulus. overflowing_shr ( k) . unwrap_or ( Self :: ZERO ) ; 
168106
169107        // Decompose `self` into RNS with moduli `2^k` and `s` and calculate the inverses. 
170108        // Using the fact that `(z^{-1} mod (m1 * m2)) mod m1 == z^{-1} mod m1` 
171-         let  maybe_a = self . inv_odd_mod ( & s ) ; 
109+         let  maybe_a = self . inv_odd_mod ( & Odd ( s ) ) ; 
172110        let  maybe_b = self . inv_mod2k ( k) ; 
173111        let  is_some = maybe_a. is_some ( ) . and ( maybe_b. is_some ( ) ) ; 
174112
@@ -262,7 +200,9 @@ mod tests {
262200            "37BFE27A9AC9EEA2969B357ABC5C0EE214BE16A7D4C58FC620D5B5A20AFF001A" , 
263201            "D198D3155E5799DC4EA76652D64983A7E130B5EACEBAC768D28D589C36EC749C" , 
264202            "558D0B64E37CD0775C0D0104AE7D98BA23C815185DD43CD8B16292FD94156767" 
265-         ] ) ; 
203+         ] ) 
204+         . to_odd ( ) 
205+         . unwrap ( ) ; 
266206        let  expected = U1024 :: from_be_hex ( concat ! [ 
267207            "B03623284B0EBABCABD5C5881893320281460C0A8E7BF4BFDCFFCBCCBF436A55" , 
268208            "D364235C8171E46C7D21AAD0680676E57274A8FDA6D12768EF961CACDD2DAE57" , 
@@ -273,10 +213,6 @@ mod tests {
273213        let  res = a. inv_odd_mod ( & m) . unwrap ( ) ; 
274214        assert_eq ! ( res,  expected) ; 
275215
276-         // Check that trying to pass an even modulus results in `None` 
277-         let  res = a. inv_odd_mod ( & ( m. wrapping_add ( & U1024 :: ONE ) ) ) ; 
278-         assert ! ( res. is_none( ) . is_true_vartime( ) ) ; 
279- 
280216        // Even though it is less efficient, it still works 
281217        let  res = a. inv_mod ( & m) . unwrap ( ) ; 
282218        assert_eq ! ( res,  expected) ; 
@@ -291,7 +227,7 @@ mod tests {
291227        let  p2 =
292228            U256 :: from_be_hex ( "00000000000000000000000000000000ffffffffffffffffffffffffffffff53" ) ; 
293229
294-         let  m = p1. wrapping_mul ( & p2) ; 
230+         let  m = p1. wrapping_mul ( & p2) . to_odd ( ) . unwrap ( ) ; 
295231
296232        // `m` is a multiple of `p1`, so no inverse exists 
297233        let  res = p1. inv_odd_mod ( & m) ; 
@@ -323,36 +259,10 @@ mod tests {
323259        assert_eq ! ( res,  expected) ; 
324260    } 
325261
326-     #[ test]  
327-     fn  test_invert_bounded ( )  { 
328-         let  a = U1024 :: from_be_hex ( concat ! [ 
329-             "0000000000000000000000000000000000000000000000000000000000000000" , 
330-             "347A412B065B75A351EA9719E2430D2477B11CC9CF9C1AD6EDEE26CB15F463F8" , 
331-             "BCC72EF87EA30288E95A48AA792226CEC959DCB0672D8F9D80A54CBBEA85CAD8" , 
332-             "382EC224DEB2F5784E62D0CC2F81C2E6AD14EBABE646D6764B30C32B87688985" 
333-         ] ) ; 
334-         let  m = U1024 :: from_be_hex ( concat ! [ 
335-             "0000000000000000000000000000000000000000000000000000000000000000" , 
336-             "0000000000000000000000000000000000000000000000000000000000000000" , 
337-             "D198D3155E5799DC4EA76652D64983A7E130B5EACEBAC768D28D589C36EC749C" , 
338-             "558D0B64E37CD0775C0D0104AE7D98BA23C815185DD43CD8B16292FD94156767" 
339-         ] ) ; 
340- 
341-         let  res = a. inv_odd_mod_bounded ( & m,  768 ,  512 ) . unwrap ( ) ; 
342- 
343-         let  expected = U1024 :: from_be_hex ( concat ! [ 
344-             "0000000000000000000000000000000000000000000000000000000000000000" , 
345-             "0000000000000000000000000000000000000000000000000000000000000000" , 
346-             "0DCC94E2FE509E6EBBA0825645A38E73EF85D5927C79C1AD8FFE7C8DF9A822FA" , 
347-             "09EB396A21B1EF05CBE51E1A8EF284EF01EBDD36A9A4EA17039D8EEFDD934768" 
348-         ] ) ; 
349-         assert_eq ! ( res,  expected) ; 
350-     } 
351- 
352262    #[ test]  
353263    fn  test_invert_small ( )  { 
354264        let  a = U64 :: from ( 3u64 ) ; 
355-         let  m = U64 :: from ( 13u64 ) ; 
265+         let  m = U64 :: from ( 13u64 ) . to_odd ( ) . unwrap ( ) ; 
356266
357267        let  res = a. inv_odd_mod ( & m) . unwrap ( ) ; 
358268        assert_eq ! ( U64 :: from( 9u64 ) ,  res) ; 
@@ -361,7 +271,7 @@ mod tests {
361271    #[ test]  
362272    fn  test_no_inverse_small ( )  { 
363273        let  a = U64 :: from ( 14u64 ) ; 
364-         let  m = U64 :: from ( 49u64 ) ; 
274+         let  m = U64 :: from ( 49u64 ) . to_odd ( ) . unwrap ( ) ; 
365275
366276        let  res = a. inv_odd_mod ( & m) ; 
367277        assert ! ( res. is_none( ) . is_true_vartime( ) ) ; 
0 commit comments