From 558753da6c84e0e11e98524ef3e93aaad9db7fd8 Mon Sep 17 00:00:00 2001 From: Oleg Andreev Date: Thu, 26 Apr 2018 11:15:39 -0700 Subject: [PATCH 1/3] Implement fast sum of powers for any n --- src/util.rs | 126 +++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 90 insertions(+), 36 deletions(-) diff --git a/src/util.rs b/src/util.rs index d86bb7f1..c1fca3fd 100644 --- a/src/util.rs +++ b/src/util.rs @@ -88,8 +88,8 @@ impl Poly2 { } /// Raises `x` to the power `n` using binary exponentiation, -/// with (1 to 2)*lg(n) scalar multiplications. -/// TODO: a consttime version of this would be awfully similar to a Montgomery ladder. +/// with `(1 to 2)*lg(n)` scalar multiplications. +/// TODO: a consttime version of this would be similar to a Montgomery ladder. pub fn scalar_exp_vartime(x: &Scalar, mut n: u64) -> Scalar { let mut result = Scalar::one(); let mut aux = *x; // x, x^2, x^4, x^8, ... @@ -99,38 +99,85 @@ pub fn scalar_exp_vartime(x: &Scalar, mut n: u64) -> Scalar { result = result * aux; } n = n >> 1; - aux = aux * aux; // FIXME: one unnecessary mult at the last step here! + if n > 0 { + aux = aux * aux; + } } result } -/// Takes the sum of all the powers of `x`, up to `n` -/// If `n` is a power of 2, it uses the efficient algorithm with `2*lg n` multiplications and additions. -/// If `n` is not a power of 2, it uses the slow algorithm with `n` multiplications and additions. -/// In the Bulletproofs case, all calls to `sum_of_powers` should have `n` as a power of 2. -pub fn sum_of_powers(x: &Scalar, n: usize) -> Scalar { - if !n.is_power_of_two() { - return sum_of_powers_slow(x, n); - } - if n == 0 || n == 1 { - return Scalar::from(n as u64); - } - let mut m = n; - let mut result = Scalar::one() + x; - let mut factor = *x; - while m > 2 { - factor = factor * factor; - result = result + factor * result; - m = m / 2; +/// Computes the sum of all the powers of \\(x\\) \\(S(n) = (x^0 + \dots + x^{n-1})\\) +/// using \\(O(\lg n)\\) multiplications and additions. Length \\(n\\) is not considered secret +/// and algorithm is fastest when \\(n\\) is the power of two. +/// +/// ### Algorithm overview +/// +/// First, let \\(n\\) be a power of two. +/// Then, we can divide the polynomial in two halves like so: +/// \\[ +/// \begin{aligned} +/// S(n) &= (1+\dots+x^{n-1}) \\\\ +/// &= (1+\dots+x^{n/2-1}) + x^{n/2} (1+\dots+x^{n/2-1}) \\\\ +/// &= s + x^{n/2} s. +/// \end{aligned} +/// \\] +/// We can divide each \\(s\\) in half until we arrive to a degree-1 polynomial \\((1+x\cdot 1)\\). +/// Recursively, the total sum can be defined as: +/// \\[ +/// \begin{aligned} +/// S(0) &= 0 \\\\ +/// S(n) &= s_{\lg n} \\\\ +/// s_0 &= 1 \\\\ +/// s_i &= s_{i-1} + x^{2^{i-1}} s_{i-1} +/// \end{aligned} +/// \\] +/// This representation allows us to square \\(x\\) only \\(\lg n\\) times. +/// +/// Lets apply this to \\(n\\) which is not a power of two (\\(2^{k-1} < n < 2^k\\)) which can be represented in binary using +/// bits \\(b_i\\) in \\(\\{0,1\\}\\): +/// \\[ +/// n = b_0 2^0 + \dots + b_{k-1} 2^{k-1} +/// \\] +/// If we scan the bits of \\(n\\) from low to high (\\(i \in [0,k)\\)), +/// we can conditionally (if \\(b_i = 1\\)) add to a resulting scalar +/// an intermediate polynomial with \\(2^i\\) terms using the above algorithm, +/// provided we offset the polynomial by \\(x^{n_i}\\), the next power of \\(x\\) +/// for the existing sum, where \\(n_i = \sum_{j=0}^{i-1} b_j 2^j\\). +/// +/// The full algorithm becomes: +/// \\[ +/// \begin{aligned} +/// S(0) &= 0 \\\\ +/// S(1) &= 1 \\\\ +/// S(i) &= S(i-1) + x^{n_i} s_i b_i\\\\ +/// &= S(i-1) + x^{n_{i-1}} (x^{2^{i-1}})^{b_{i-1}} s_i b_i +/// \end{aligned} +/// \\] +pub fn sum_of_powers(x: &Scalar, mut n: usize) -> Scalar { + let mut result = Scalar::zero(); + let mut f = Scalar::one(); // power of x to offset subsequent polynomials based on lower bits of n. + let mut s = Scalar::one(); // power-of-two polynomial: 1, 1+x, 1+x+x^2+x^3, ... + let mut p = *x; // x, x^2, x^4, ..., x^{2^i} + while n > 0 { + // take a bit from n + let bit = n & 1; + n = n >> 1; + + if bit == 1 { + // bits of `n` are not secret, so it's okay to be vartime because of `n` value. + result += f * s; + if n > 0 { // avoid multiplication if no bits left + f = f * p; + } + } + if n > 0 { // avoid multiplication if no bits left + s = s + p * s; + p = p * p; + } } result } -// takes the sum of all of the powers of x, up to n -fn sum_of_powers_slow(x: &Scalar, n: usize) -> Scalar { - exp_iter(*x).take(n).sum() -} - /// Given `data` with `len >= 32`, return the first 32 bytes. pub fn read32(data: &[u8]) -> [u8; 32] { let mut buf32 = [0u8; 32]; @@ -196,9 +243,14 @@ mod tests { ); } + // takes the sum of all of the powers of x, up to n + fn sum_of_powers_slow(x: &Scalar, n: usize) -> Scalar { + exp_iter(*x).take(n).fold(Scalar::zero(), |acc, x| acc + x) + } + #[test] - fn test_sum_of_powers() { - let x = Scalar::from(10u64); + fn test_sum_of_powers_pow2() { + let x = Scalar::from(1337133713371337u64); assert_eq!(sum_of_powers_slow(&x, 0), sum_of_powers(&x, 0)); assert_eq!(sum_of_powers_slow(&x, 1), sum_of_powers(&x, 1)); assert_eq!(sum_of_powers_slow(&x, 2), sum_of_powers(&x, 2)); @@ -210,14 +262,16 @@ mod tests { } #[test] - fn test_sum_of_powers_slow() { + fn test_sum_of_powers_non_pow2() { let x = Scalar::from(10u64); - assert_eq!(sum_of_powers_slow(&x, 0), Scalar::zero()); - assert_eq!(sum_of_powers_slow(&x, 1), Scalar::one()); - assert_eq!(sum_of_powers_slow(&x, 2), Scalar::from(11u64)); - assert_eq!(sum_of_powers_slow(&x, 3), Scalar::from(111u64)); - assert_eq!(sum_of_powers_slow(&x, 4), Scalar::from(1111u64)); - assert_eq!(sum_of_powers_slow(&x, 5), Scalar::from(11111u64)); - assert_eq!(sum_of_powers_slow(&x, 6), Scalar::from(111111u64)); + assert_eq!(sum_of_powers(&x, 0), Scalar::zero()); + assert_eq!(sum_of_powers(&x, 1), Scalar::one()); + assert_eq!(sum_of_powers(&x, 2), Scalar::from(11u64)); + assert_eq!(sum_of_powers(&x, 3), Scalar::from(111u64)); + assert_eq!(sum_of_powers(&x, 4), Scalar::from(1111u64)); + assert_eq!(sum_of_powers(&x, 5), Scalar::from(11111u64)); + assert_eq!(sum_of_powers(&x, 6), Scalar::from(111111u64)); + assert_eq!(sum_of_powers(&x, 7), Scalar::from(1111111u64)); + assert_eq!(sum_of_powers(&x, 8), Scalar::from(11111111u64)); } } From 6f30386bb8d2cdea477152cefc725dfd56bac3eb Mon Sep 17 00:00:00 2001 From: Oleg Andreev Date: Tue, 1 May 2018 10:47:17 -0700 Subject: [PATCH 2/3] wip --- src/util.rs | 35 ++++++++++++++++++++++++++--------- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/src/util.rs b/src/util.rs index c1fca3fd..d952f2c3 100644 --- a/src/util.rs +++ b/src/util.rs @@ -107,10 +107,10 @@ pub fn scalar_exp_vartime(x: &Scalar, mut n: u64) -> Scalar { } /// Computes the sum of all the powers of \\(x\\) \\(S(n) = (x^0 + \dots + x^{n-1})\\) -/// using \\(O(\lg n)\\) multiplications and additions. Length \\(n\\) is not considered secret -/// and algorithm is fastest when \\(n\\) is the power of two. +/// using \\(O(\lg n)\\) multiplications. Length \\(n\\) is not considered secret +/// and algorithm is fastest when \\(n\\) is the power of two (\\(2\lg n + 1\\) multiplications). /// -/// ### Algorithm overview +/// ### Algorithm description /// /// First, let \\(n\\) be a power of two. /// Then, we can divide the polynomial in two halves like so: @@ -131,9 +131,26 @@ pub fn scalar_exp_vartime(x: &Scalar, mut n: u64) -> Scalar { /// s_i &= s_{i-1} + x^{2^{i-1}} s_{i-1} /// \end{aligned} /// \\] -/// This representation allows us to square \\(x\\) only \\(\lg n\\) times. +/// This representation allows us to do only \\(2 \cdot \lg n\\) multiplications: +/// squaring \\(x\\) and multiplying it by \\(s_{i-1}\\) at each iteration. /// -/// Lets apply this to \\(n\\) which is not a power of two (\\(2^{k-1} < n < 2^k\\)) which can be represented in binary using +/// Lets apply this to \\(n\\) which is not a power of two. The intuition behind the generalized +/// algorithm is to combine all intermediate power-of-two-degree polynomials corresponding to the +/// bits of \\(n\\) that are equal to 1. +/// +/// 1. Represent \\(n\\) in binary. +/// 2. For each bit which is set (from the lowest to the highest): +/// 1. Compute a corresponding power-of-two-degree polynomial using the above algorithm. +/// Since we can reuse all intermediate polynomials, this adds no overhead to computing +/// a polynomial for the highest bit. +/// 2. Multiply the polynomial by the next power of \\(x\\), relative to the degree of the +/// already computed result. This effectively _offsets_ the polynomial to a correct range of +/// powers, so it can be added directly with the rest. +/// The next power of \\(x\\) is computed along all the intermediate polynomials, +/// by multiplying it by power-of-two power of \\(x\\) computed in step 2.1. +/// 3. Add to the result. +/// +/// (\\(2^{k-1} < n < 2^k\\)) which can be represented in binary using /// bits \\(b_i\\) in \\(\\{0,1\\}\\): /// \\[ /// n = b_0 2^0 + \dots + b_{k-1} 2^{k-1} @@ -155,16 +172,16 @@ pub fn scalar_exp_vartime(x: &Scalar, mut n: u64) -> Scalar { /// \\] pub fn sum_of_powers(x: &Scalar, mut n: usize) -> Scalar { let mut result = Scalar::zero(); - let mut f = Scalar::one(); // power of x to offset subsequent polynomials based on lower bits of n. - let mut s = Scalar::one(); // power-of-two polynomial: 1, 1+x, 1+x+x^2+x^3, ... - let mut p = *x; // x, x^2, x^4, ..., x^{2^i} + let mut f = Scalar::one(); // next-power-of-x to offset subsequent polynomials based on preceding bits of n. + let mut s = Scalar::one(); // power-of-two polynomials: (1, 1+x, 1+x+x^2+x^3, 1+...+x^7, , 1+...+x^15, ...) + let mut p = *x; // power-of-two powers of x: (x, x^2, x^4, ..., x^{2^i}) while n > 0 { // take a bit from n let bit = n & 1; n = n >> 1; if bit == 1 { - // bits of `n` are not secret, so it's okay to be vartime because of `n` value. + // `n` is not secret, so it's okay to be vartime on bits of `n`. result += f * s; if n > 0 { // avoid multiplication if no bits left f = f * p; From c13e29b335fbfd4d9c634df9a5e9a6d6edcdc3ba Mon Sep 17 00:00:00 2001 From: Oleg Andreev Date: Mon, 15 Oct 2018 13:46:09 -0700 Subject: [PATCH 3/3] fmt --- src/util.rs | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/util.rs b/src/util.rs index d952f2c3..0d065c7a 100644 --- a/src/util.rs +++ b/src/util.rs @@ -135,7 +135,7 @@ pub fn scalar_exp_vartime(x: &Scalar, mut n: u64) -> Scalar { /// squaring \\(x\\) and multiplying it by \\(s_{i-1}\\) at each iteration. /// /// Lets apply this to \\(n\\) which is not a power of two. The intuition behind the generalized -/// algorithm is to combine all intermediate power-of-two-degree polynomials corresponding to the +/// algorithm is to combine all intermediate power-of-two-degree polynomials corresponding to the /// bits of \\(n\\) that are equal to 1. /// /// 1. Represent \\(n\\) in binary. @@ -179,15 +179,17 @@ pub fn sum_of_powers(x: &Scalar, mut n: usize) -> Scalar { // take a bit from n let bit = n & 1; n = n >> 1; - + if bit == 1 { // `n` is not secret, so it's okay to be vartime on bits of `n`. result += f * s; - if n > 0 { // avoid multiplication if no bits left + if n > 0 { + // avoid multiplication if no bits left f = f * p; } } - if n > 0 { // avoid multiplication if no bits left + if n > 0 { + // avoid multiplication if no bits left s = s + p * s; p = p * p; }