Skip to content

Commit 6d313f8

Browse files
authored
Use Bernstein-Yang to implement (Boxed)Uint::inv_mod(_odd) (#501)
1 parent b04f2f1 commit 6d313f8

File tree

10 files changed

+41
-267
lines changed

10 files changed

+41
-267
lines changed

benches/uint.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use criterion::{black_box, criterion_group, criterion_main, BatchSize, Criterion};
2-
use crypto_bigint::{Limb, NonZero, Random, Reciprocal, Uint, U128, U2048, U256};
2+
use crypto_bigint::{Limb, NonZero, Odd, Random, Reciprocal, Uint, U128, U2048, U256};
33
use rand_core::OsRng;
44

55
fn bench_division(c: &mut Criterion) {
@@ -170,7 +170,7 @@ fn bench_inv_mod(c: &mut Criterion) {
170170
group.bench_function("inv_odd_mod, U256", |b| {
171171
b.iter_batched(
172172
|| {
173-
let m = U256::random(&mut OsRng) | U256::ONE;
173+
let m = Odd::<U256>::random(&mut OsRng);
174174
loop {
175175
let x = U256::random(&mut OsRng);
176176
let inv_x = x.inv_odd_mod(&m);
@@ -187,7 +187,7 @@ fn bench_inv_mod(c: &mut Criterion) {
187187
group.bench_function("inv_mod, U256, odd modulus", |b| {
188188
b.iter_batched(
189189
|| {
190-
let m = U256::random(&mut OsRng) | U256::ONE;
190+
let m = Odd::<U256>::random(&mut OsRng);
191191
loop {
192192
let x = U256::random(&mut OsRng);
193193
let inv_x = x.inv_odd_mod(&m);

src/uint/add.rs

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,18 +31,6 @@ impl<const LIMBS: usize> Uint<LIMBS> {
3131
pub const fn wrapping_add(&self, rhs: &Self) -> Self {
3232
self.adc(rhs, Limb::ZERO).0
3333
}
34-
35-
/// Perform wrapping addition, returning the truthy value as the second element of the tuple
36-
/// if an overflow has occurred.
37-
pub(crate) const fn conditional_wrapping_add(
38-
&self,
39-
rhs: &Self,
40-
choice: ConstChoice,
41-
) -> (Self, ConstChoice) {
42-
let actual_rhs = Uint::select(&Uint::ZERO, rhs, choice);
43-
let (sum, carry) = self.adc(&actual_rhs, Limb::ZERO);
44-
(sum, ConstChoice::from_word_lsb(carry.0))
45-
}
4634
}
4735

4836
impl<const LIMBS: usize> Add for Uint<LIMBS> {

src/uint/boxed/inv_mod.rs

Lines changed: 3 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,16 @@
11
//! [`BoxedUint`] modular inverse (i.e. reciprocal) operations.
22
33
use crate::{
4-
modular::BoxedBernsteinYangInverter, BoxedUint, ConstantTimeSelect, Integer, Odd,
4+
modular::BoxedBernsteinYangInverter, BoxedUint, ConstantTimeSelect, Integer, Inverter, Odd,
55
PrecomputeInverter, PrecomputeInverterWithAdjuster,
66
};
77
use subtle::{Choice, ConstantTimeEq, ConstantTimeLess, CtOption};
88

99
impl BoxedUint {
1010
/// Computes the multiplicative inverse of `self` mod `modulus`.
1111
/// Returns `None` if an inverse does not exist.
12-
pub fn inv_mod(&self, modulus: &Self) -> CtOption<Self> {
13-
debug_assert_eq!(self.bits_precision(), modulus.bits_precision());
14-
15-
// Decompose `modulus = s * 2^k` where `s` is odd
16-
let k = modulus.trailing_zeros();
17-
let s = modulus >> k;
18-
19-
// Decompose `self` into RNS with moduli `2^k` and `s` and calculate the inverses.
20-
// Using the fact that `(z^{-1} mod (m1 * m2)) mod m1 == z^{-1} mod m1`
21-
let (a, a_is_some) = self.inv_odd_mod(&s);
22-
let (b, b_is_some) = self.inv_mod2k(k);
23-
24-
// Restore from RNS:
25-
// self^{-1} = a mod s = b mod 2^k
26-
// => self^{-1} = a + s * ((b - a) * s^(-1) mod 2^k)
27-
// (essentially one step of the Garner's algorithm for recovery from RNS).
28-
29-
let (m_odd_inv, _is_some) = s.inv_mod2k(k); // `s` is odd, so this always exists
30-
31-
// This part is mod 2^k
32-
let mask = (Self::one() << k).wrapping_sub(&Self::one());
33-
let t = (b.wrapping_sub(&a).wrapping_mul(&m_odd_inv)).bitand(&mask);
34-
35-
// Will not overflow since `a <= s - 1`, `t <= 2^k - 1`,
36-
// so `a + s * t <= s * 2^k - 1 == modulus - 1`.
37-
let result = a.wrapping_add(&s.wrapping_mul(&t));
38-
CtOption::new(result, a_is_some & b_is_some)
12+
pub fn inv_mod(&self, modulus: &Odd<Self>) -> CtOption<Self> {
13+
modulus.precompute_inverter().invert(self)
3914
}
4015

4116
/// Computes 1/`self` mod `2^k`.
@@ -69,80 +44,6 @@ impl BoxedUint {
6944

7045
(x, is_some)
7146
}
72-
73-
/// Computes the multiplicative inverse of `self` mod `modulus`, where `modulus` is odd.
74-
/// Returns `None` if an inverse does not exist.
75-
pub(crate) fn inv_odd_mod(&self, modulus: &Self) -> (Self, Choice) {
76-
self.inv_odd_mod_bounded(modulus, self.bits_precision(), modulus.bits_precision())
77-
}
78-
79-
/// Computes the multiplicative inverse of `self` mod `modulus`, where `modulus` is odd.
80-
/// In other words `self^-1 mod modulus`.
81-
///
82-
/// `bits` and `modulus_bits` are the bounds on the bit size
83-
/// of `self` and `modulus`, respectively.
84-
///
85-
/// (the inversion speed will be proportional to `bits + modulus_bits`).
86-
/// The second element of the tuple is the truthy value
87-
/// if `modulus` is odd and an inverse exists, otherwise it is a falsy value.
88-
///
89-
/// **Note:** variable time in `bits` and `modulus_bits`.
90-
///
91-
/// The algorithm is the same as in GMP 6.2.1's `mpn_sec_invert`.
92-
fn inv_odd_mod_bounded(&self, modulus: &Self, bits: u32, modulus_bits: u32) -> (Self, Choice) {
93-
debug_assert_eq!(self.bits_precision(), modulus.bits_precision());
94-
95-
let bits_precision = self.bits_precision();
96-
97-
let mut a = self.clone();
98-
let mut u = Self::one_with_precision(bits_precision);
99-
let mut v = Self::zero_with_precision(bits_precision);
100-
let mut b = modulus.clone();
101-
102-
// `bit_size` can be anything >= `self.bits()` + `modulus.bits()`, setting to the minimum.
103-
let bit_size = bits + modulus_bits;
104-
105-
let m1hp = modulus
106-
.shr1()
107-
.wrapping_add(&Self::one_with_precision(bits_precision));
108-
109-
let modulus_is_odd = modulus.is_odd();
110-
111-
for _ in 0..bit_size {
112-
// A sanity check that `b` stays odd. Only matters if `modulus` was odd to begin with,
113-
// otherwise this whole thing produces nonsense anyway.
114-
debug_assert!(bool::from(!modulus_is_odd | b.is_odd()));
115-
116-
let self_odd = a.is_odd();
117-
118-
// Set `self -= b` if `self` is odd.
119-
let swap = a.conditional_sbb_assign(&b, self_odd);
120-
// Set `b += self` if `swap` is true.
121-
b = Self::ct_select(&b, &b.wrapping_add(&a), swap);
122-
// Negate `self` if `swap` is true.
123-
a = a.conditional_wrapping_neg(swap);
124-
125-
let mut new_u = u.clone();
126-
let mut new_v = v.clone();
127-
Self::ct_swap(&mut new_u, &mut new_v, swap);
128-
let cy = new_u.conditional_sbb_assign(&new_v, self_odd);
129-
let cyy = new_u.conditional_adc_assign(modulus, cy);
130-
debug_assert!(bool::from(cy.ct_eq(&cyy)));
131-
132-
let (new_a, carry) = a.shr1_with_carry();
133-
debug_assert!(bool::from(!modulus_is_odd | !carry));
134-
let (mut new_u, cy) = new_u.shr1_with_carry();
135-
let cy = new_u.conditional_adc_assign(&m1hp, cy);
136-
debug_assert!(bool::from(!modulus_is_odd | !cy));
137-
138-
a = new_a;
139-
u = new_u;
140-
v = new_v;
141-
}
142-
143-
debug_assert!(bool::from(!modulus_is_odd | a.is_zero()));
144-
(v, b.is_one() & modulus_is_odd)
145-
}
14647
}
14748

14849
/// Precompute a Bernstein-Yang inverter using `self` as the modulus.

src/uint/boxed/neg.rs

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,8 @@
11
//! [`BoxedUint`] negation operations.
22
3-
use crate::{BoxedUint, ConstantTimeSelect, Limb, WideWord, Word, WrappingNeg};
4-
use subtle::Choice;
3+
use crate::{BoxedUint, Limb, WideWord, Word, WrappingNeg};
54

65
impl BoxedUint {
7-
/// Negates based on `choice` by wrapping the integer.
8-
pub(crate) fn conditional_wrapping_neg(&self, choice: Choice) -> BoxedUint {
9-
Self::ct_select(self, &self.wrapping_neg(), choice)
10-
}
11-
126
/// Perform wrapping negation.
137
pub fn wrapping_neg(&self) -> Self {
148
let mut ret = vec![Limb::ZERO; self.nlimbs()];

src/uint/cmp.rs

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,6 @@ impl<const LIMBS: usize> Uint<LIMBS> {
2222
Uint { limbs }
2323
}
2424

25-
#[inline]
26-
pub(crate) const fn swap(a: &Self, b: &Self, c: ConstChoice) -> (Self, Self) {
27-
let new_a = Self::select(a, b, c);
28-
let new_b = Self::select(b, a, c);
29-
30-
(new_a, new_b)
31-
}
32-
3325
/// Returns the truthy value if `self`!=0 or the falsy value otherwise.
3426
#[inline]
3527
pub(crate) const fn is_nonzero(&self) -> ConstChoice {

src/uint/inv_mod.rs

Lines changed: 21 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use super::Uint;
2-
use crate::{ConstChoice, ConstCtOption};
2+
use crate::modular::BernsteinYangInverter;
3+
use crate::{ConstChoice, ConstCtOption, Odd, PrecomputeInverter};
34

45
impl<const LIMBS: usize> Uint<LIMBS> {
56
/// Computes 1/`self` mod `2^k`.
@@ -79,96 +80,33 @@ impl<const LIMBS: usize> Uint<LIMBS> {
7980
ConstCtOption::new(x, is_some)
8081
}
8182

82-
/// Computes the multiplicative inverse of `self` mod `modulus`, where `modulus` is odd.
83-
/// In other words `self^-1 mod modulus`.
84-
/// `bits` and `modulus_bits` are the bounds on the bit size
85-
/// of `self` and `modulus`, respectively
86-
/// (the inversion speed will be proportional to `bits + modulus_bits`).
87-
/// The second element of the tuple is the truthy value
88-
/// if `modulus` is odd and an inverse exists, otherwise it is a falsy value.
89-
///
90-
/// **Note:** variable time in `bits` and `modulus_bits`.
91-
///
92-
/// The algorithm is the same as in GMP 6.2.1's `mpn_sec_invert`.
93-
pub const fn inv_odd_mod_bounded(
94-
&self,
95-
modulus: &Self,
96-
bits: u32,
97-
modulus_bits: u32,
98-
) -> ConstCtOption<Self> {
99-
let mut a = *self;
100-
101-
let mut u = Uint::ONE;
102-
let mut v = Uint::ZERO;
103-
104-
let mut b = *modulus;
105-
106-
// `bit_size` can be anything >= `self.bits()` + `modulus.bits()`, setting to the minimum.
107-
let bit_size = bits + modulus_bits;
108-
109-
let m1hp = modulus.shr1().wrapping_add(&Uint::ONE);
110-
111-
let modulus_is_odd = modulus.is_odd();
112-
113-
let mut i = 0;
114-
while i < bit_size {
115-
// A sanity check that `b` stays odd. Only matters if `modulus` was odd to begin with,
116-
// otherwise this whole thing produces nonsense anyway.
117-
debug_assert!(modulus_is_odd.not().or(b.is_odd()).is_true_vartime());
118-
119-
let self_odd = a.is_odd();
120-
121-
// Set `self -= b` if `self` is odd.
122-
let (new_a, swap) = a.conditional_wrapping_sub(&b, self_odd);
123-
// Set `b += self` if `swap` is true.
124-
b = Uint::select(&b, &b.wrapping_add(&new_a), swap);
125-
// Negate `self` if `swap` is true.
126-
a = new_a.conditional_wrapping_neg(swap);
127-
128-
let (new_u, new_v) = Uint::swap(&u, &v, swap);
129-
let (new_u, cy) = new_u.conditional_wrapping_sub(&new_v, self_odd);
130-
let (new_u, cyy) = new_u.conditional_wrapping_add(modulus, cy);
131-
debug_assert!(cy.is_true_vartime() == cyy.is_true_vartime());
132-
133-
let (new_a, carry) = a.shr1_with_carry();
134-
debug_assert!(modulus_is_odd.not().or(carry.not()).is_true_vartime());
135-
let (new_u, cy) = new_u.shr1_with_carry();
136-
let (new_u, cy) = new_u.conditional_wrapping_add(&m1hp, cy);
137-
debug_assert!(modulus_is_odd.not().or(cy.not()).is_true_vartime());
138-
139-
a = new_a;
140-
u = new_u;
141-
v = new_v;
142-
143-
i += 1;
144-
}
145-
146-
debug_assert!(modulus_is_odd
147-
.not()
148-
.or(a.is_nonzero().not())
149-
.is_true_vartime());
150-
151-
ConstCtOption::new(v, Uint::eq(&b, &Uint::ONE).and(modulus_is_odd))
152-
}
153-
15483
/// Computes the multiplicative inverse of `self` mod `modulus`, where `modulus` is odd.
15584
/// Returns `(inverse, ConstChoice::TRUE)` if an inverse exists,
15685
/// otherwise `(undefined, ConstChoice::FALSE)`.
157-
pub const fn inv_odd_mod(&self, modulus: &Self) -> ConstCtOption<Self> {
158-
self.inv_odd_mod_bounded(modulus, Uint::<LIMBS>::BITS, Uint::<LIMBS>::BITS)
86+
pub const fn inv_odd_mod<const UNSAT_LIMBS: usize>(
87+
&self,
88+
modulus: &Odd<Self>,
89+
) -> ConstCtOption<Self>
90+
where
91+
Odd<Self>: PrecomputeInverter<Inverter = BernsteinYangInverter<LIMBS, UNSAT_LIMBS>>,
92+
{
93+
BernsteinYangInverter::<LIMBS, UNSAT_LIMBS>::new(modulus, &Uint::ONE).inv(self)
15994
}
16095

16196
/// Computes the multiplicative inverse of `self` mod `modulus`.
16297
/// Returns `(inverse, ConstChoice::TRUE)` if an inverse exists,
16398
/// otherwise `(undefined, ConstChoice::FALSE)`.
164-
pub const fn inv_mod(&self, modulus: &Self) -> ConstCtOption<Self> {
99+
pub const fn inv_mod<const UNSAT_LIMBS: usize>(&self, modulus: &Self) -> ConstCtOption<Self>
100+
where
101+
Odd<Self>: PrecomputeInverter<Inverter = BernsteinYangInverter<LIMBS, UNSAT_LIMBS>>,
102+
{
165103
// Decompose `modulus = s * 2^k` where `s` is odd
166104
let k = modulus.trailing_zeros();
167105
let s = modulus.overflowing_shr(k).unwrap_or(Self::ZERO);
168106

169107
// Decompose `self` into RNS with moduli `2^k` and `s` and calculate the inverses.
170108
// Using the fact that `(z^{-1} mod (m1 * m2)) mod m1 == z^{-1} mod m1`
171-
let maybe_a = self.inv_odd_mod(&s);
109+
let maybe_a = self.inv_odd_mod(&Odd(s));
172110
let maybe_b = self.inv_mod2k(k);
173111
let is_some = maybe_a.is_some().and(maybe_b.is_some());
174112

@@ -262,7 +200,9 @@ mod tests {
262200
"37BFE27A9AC9EEA2969B357ABC5C0EE214BE16A7D4C58FC620D5B5A20AFF001A",
263201
"D198D3155E5799DC4EA76652D64983A7E130B5EACEBAC768D28D589C36EC749C",
264202
"558D0B64E37CD0775C0D0104AE7D98BA23C815185DD43CD8B16292FD94156767"
265-
]);
203+
])
204+
.to_odd()
205+
.unwrap();
266206
let expected = U1024::from_be_hex(concat![
267207
"B03623284B0EBABCABD5C5881893320281460C0A8E7BF4BFDCFFCBCCBF436A55",
268208
"D364235C8171E46C7D21AAD0680676E57274A8FDA6D12768EF961CACDD2DAE57",
@@ -273,10 +213,6 @@ mod tests {
273213
let res = a.inv_odd_mod(&m).unwrap();
274214
assert_eq!(res, expected);
275215

276-
// Check that trying to pass an even modulus results in `None`
277-
let res = a.inv_odd_mod(&(m.wrapping_add(&U1024::ONE)));
278-
assert!(res.is_none().is_true_vartime());
279-
280216
// Even though it is less efficient, it still works
281217
let res = a.inv_mod(&m).unwrap();
282218
assert_eq!(res, expected);
@@ -291,7 +227,7 @@ mod tests {
291227
let p2 =
292228
U256::from_be_hex("00000000000000000000000000000000ffffffffffffffffffffffffffffff53");
293229

294-
let m = p1.wrapping_mul(&p2);
230+
let m = p1.wrapping_mul(&p2).to_odd().unwrap();
295231

296232
// `m` is a multiple of `p1`, so no inverse exists
297233
let res = p1.inv_odd_mod(&m);
@@ -323,36 +259,10 @@ mod tests {
323259
assert_eq!(res, expected);
324260
}
325261

326-
#[test]
327-
fn test_invert_bounded() {
328-
let a = U1024::from_be_hex(concat![
329-
"0000000000000000000000000000000000000000000000000000000000000000",
330-
"347A412B065B75A351EA9719E2430D2477B11CC9CF9C1AD6EDEE26CB15F463F8",
331-
"BCC72EF87EA30288E95A48AA792226CEC959DCB0672D8F9D80A54CBBEA85CAD8",
332-
"382EC224DEB2F5784E62D0CC2F81C2E6AD14EBABE646D6764B30C32B87688985"
333-
]);
334-
let m = U1024::from_be_hex(concat![
335-
"0000000000000000000000000000000000000000000000000000000000000000",
336-
"0000000000000000000000000000000000000000000000000000000000000000",
337-
"D198D3155E5799DC4EA76652D64983A7E130B5EACEBAC768D28D589C36EC749C",
338-
"558D0B64E37CD0775C0D0104AE7D98BA23C815185DD43CD8B16292FD94156767"
339-
]);
340-
341-
let res = a.inv_odd_mod_bounded(&m, 768, 512).unwrap();
342-
343-
let expected = U1024::from_be_hex(concat![
344-
"0000000000000000000000000000000000000000000000000000000000000000",
345-
"0000000000000000000000000000000000000000000000000000000000000000",
346-
"0DCC94E2FE509E6EBBA0825645A38E73EF85D5927C79C1AD8FFE7C8DF9A822FA",
347-
"09EB396A21B1EF05CBE51E1A8EF284EF01EBDD36A9A4EA17039D8EEFDD934768"
348-
]);
349-
assert_eq!(res, expected);
350-
}
351-
352262
#[test]
353263
fn test_invert_small() {
354264
let a = U64::from(3u64);
355-
let m = U64::from(13u64);
265+
let m = U64::from(13u64).to_odd().unwrap();
356266

357267
let res = a.inv_odd_mod(&m).unwrap();
358268
assert_eq!(U64::from(9u64), res);
@@ -361,7 +271,7 @@ mod tests {
361271
#[test]
362272
fn test_no_inverse_small() {
363273
let a = U64::from(14u64);
364-
let m = U64::from(49u64);
274+
let m = U64::from(49u64).to_odd().unwrap();
365275

366276
let res = a.inv_odd_mod(&m);
367277
assert!(res.is_none().is_true_vartime());

0 commit comments

Comments
 (0)