Skip to content

Commit 703c42e

Browse files
committed
Crude implementation of sqrt_vartime
1 parent 21f3a6b commit 703c42e

File tree

3 files changed

+124
-1
lines changed

3 files changed

+124
-1
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ rand_core = { version = "0.6", features = ["std"] }
3939
rand_chacha = "0.3"
4040

4141
[features]
42-
default = ["rand"]
42+
default = ["rand", "alloc"]
4343
alloc = ["serdect?/alloc"]
4444
std = ["alloc"]
4545

src/uint/boxed.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ mod mul_mod;
1515
mod neg;
1616
mod shl;
1717
mod shr;
18+
mod sqrt;
1819
mod sub;
1920
mod sub_mod;
2021

@@ -91,6 +92,11 @@ impl BoxedUint {
9192
.fold(Choice::from(1), |acc, limb| acc & limb.is_zero())
9293
}
9394

95+
/// Is this [`BoxedUint`] non-zero?
96+
pub fn is_nonzero(&self) -> Choice {
97+
!self.is_zero()
98+
}
99+
94100
/// Is this [`BoxedUint`] equal to one?
95101
pub fn is_one(&self) -> Choice {
96102
let mut iter = self.limbs.iter();

src/uint/boxed/sqrt.rs

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
//! [`Uint`] square root operations.
2+
use crate::{BoxedUint, NonZero};
3+
4+
impl BoxedUint {
5+
/// Computes √(`self`) in constant time.
6+
///
7+
/// Callers can check if `self` is a square by squaring the result
8+
pub const fn sqrt(&self) -> Self {
9+
todo!();
10+
}
11+
12+
/// Computes √(`self`)
13+
///
14+
/// Callers can check if `self` is a square by squaring the result
15+
pub fn sqrt_vartime(&self) -> Self {
16+
// Uses Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13
17+
18+
// The initial guess: `x_0 = 2^ceil(b/2)`, where `2^(b-1) <= self < b`.
19+
// Will not overflow since `b <= BITS`.
20+
let (mut x, _overflow) =
21+
Self::one_with_precision(self.bits_precision()).shl((self.bits() + 1) >> 1); // ≥ √(`self`)
22+
23+
// Stop right away if `x` is zero to avoid divizion by zero.
24+
// TODO: is it kay to use cmp instead of implementing cmp_vartime?
25+
while !x
26+
.cmp(&Self::zero_with_precision(self.bits_precision()))
27+
.is_eq()
28+
{
29+
// Calculate `x_{i+1} = floor((x_i + self / x_i) / 2)`
30+
31+
// TODO: is using wrapping_div instead of wrapping_div_vartime ok?
32+
let q = self.wrapping_div(&NonZero::<Self>::new(x.clone()).expect("division by 0"));
33+
let t = x.wrapping_add(&q);
34+
let next_x = t.shr1();
35+
36+
// If `next_x` is the same as `x` or greater, we reached convergence
37+
// (`x` is guaranteed to either go down or oscillate between
38+
// `sqrt(self)` and `sqrt(self) + 1`)
39+
if !x.cmp(&next_x).is_gt() {
40+
break;
41+
}
42+
43+
x = next_x;
44+
}
45+
46+
if self.is_nonzero().into() {
47+
x
48+
} else {
49+
Self::zero_with_precision(self.bits_precision())
50+
}
51+
}
52+
}
53+
54+
#[cfg(test)]
55+
mod tests {
56+
use super::*;
57+
use crate::Limb;
58+
59+
#[test]
60+
fn edge_vartime() {
61+
let zero = BoxedUint::zero_with_precision(256);
62+
let one = BoxedUint::one_with_precision(256);
63+
let max = !zero.clone();
64+
assert_eq!(zero.sqrt_vartime(), zero);
65+
assert_eq!(one.sqrt_vartime(), one);
66+
let mut half = zero;
67+
for i in 0..half.limbs.len() / 2 {
68+
half.limbs[i] = Limb::MAX;
69+
}
70+
assert_eq!(max.sqrt_vartime(), half);
71+
}
72+
73+
#[test]
74+
fn simple() {
75+
let tests = [
76+
(4u8, 2u8),
77+
(9, 3),
78+
(16, 4),
79+
(25, 5),
80+
(36, 6),
81+
(49, 7),
82+
(64, 8),
83+
(81, 9),
84+
(100, 10),
85+
(121, 11),
86+
(144, 12),
87+
(169, 13),
88+
];
89+
for (a, e) in &tests {
90+
let l = BoxedUint::from(*a);
91+
let r = BoxedUint::from(*e);
92+
// assert_eq!(l.sqrt(), r);
93+
assert_eq!(l.sqrt_vartime(), r);
94+
// assert_eq!(l.checked_sqrt().is_some().unwrap_u8(), 1u8);
95+
// assert_eq!(l.checked_sqrt_vartime().is_some().unwrap_u8(), 1u8);
96+
}
97+
}
98+
99+
#[test]
100+
fn nonsquares_vartime() {
101+
assert_eq!(BoxedUint::from(2u8).sqrt_vartime(), BoxedUint::from(1u8));
102+
// assert_eq!(
103+
// BoxedUint::from(2u8).checked_sqrt_vartime().is_some().unwrap_u8(),
104+
// 0
105+
// );
106+
assert_eq!(BoxedUint::from(3u8).sqrt_vartime(), BoxedUint::from(1u8));
107+
// assert_eq!(
108+
// BoxedUint::from(3u8).checked_sqrt_vartime().is_some().unwrap_u8(),
109+
// 0
110+
// );
111+
assert_eq!(BoxedUint::from(5u8).sqrt_vartime(), BoxedUint::from(2u8));
112+
assert_eq!(BoxedUint::from(6u8).sqrt_vartime(), BoxedUint::from(2u8));
113+
assert_eq!(BoxedUint::from(7u8).sqrt_vartime(), BoxedUint::from(2u8));
114+
assert_eq!(BoxedUint::from(8u8).sqrt_vartime(), BoxedUint::from(2u8));
115+
assert_eq!(BoxedUint::from(10u8).sqrt_vartime(), BoxedUint::from(3u8));
116+
}
117+
}

0 commit comments

Comments
 (0)