diff --git a/src/benchmarked.rs b/src/benchmarked.rs index 36e90f1b..8d6e0501 100644 --- a/src/benchmarked.rs +++ b/src/benchmarked.rs @@ -16,7 +16,7 @@ pub fn benchmarked_gadget_mul_call_poly_ntt( outp: &mut [F], inp: &[Vec], ) -> Result<(), FlpError> { - g.call_poly_ntt(outp, inp) + g.eval_poly_ntt(outp, inp) } /// Sets `outp` to `inp[0] * inp[1]`, where `inp[0]` and `inp[1]` are polynomials. This function @@ -26,5 +26,5 @@ pub fn benchmarked_gadget_mul_call_poly_direct( outp: &mut [F], inp: &[Vec], ) -> Result<(), FlpError> { - g.call_poly_direct(outp, inp) + g.eval_poly_direct(outp, inp) } diff --git a/src/flp.rs b/src/flp.rs index d4214b0a..56f1b185 100644 --- a/src/flp.rs +++ b/src/flp.rs @@ -51,7 +51,7 @@ use crate::dp::DifferentialPrivacyStrategy; use crate::field::{FieldElement, FieldElementWithInteger, FieldError, NttFriendlyFieldElement}; use crate::fp::log2; use crate::ntt::{ntt, ntt_inv_finish, NttError}; -use crate::polynomial::{nth_root_powers, poly_eval, poly_eval_batched}; +use crate::polynomial::{nth_root_powers, poly_eval_lagrange_batched, poly_eval_monomial}; use std::any::Any; use std::convert::TryFrom; use std::fmt::Debug; @@ -471,7 +471,7 @@ pub trait Flp: Sized + Eq + Clone + Debug { // This avoids using NTTs to convert them to the monomial basis. let roots = nth_root_powers(m); let polynomials = &gadget.f_vals[..gadget.arity()]; - let mut evals = poly_eval_batched(polynomials, &roots, *query_rand_val); + let mut evals = poly_eval_lagrange_batched(polynomials, &roots, *query_rand_val); verifier.append(&mut evals); // Add the value of the gadget polynomial evaluated at the query randomness value. @@ -613,12 +613,12 @@ pub trait Gadget: Debug { /// Evaluate the gadget on input of a sequence of polynomials. The output is written to `outp`. fn eval_poly(&mut self, outp: &mut [F], inp: &[Vec]) -> Result<(), FlpError>; - /// Returns the arity of the gadget. This is the length of `inp` passed to `call` or - /// `call_poly`. + /// Returns the arity of the gadget. This is the length of `inp` passed to `eval` or + /// `eval_poly`. fn arity(&self) -> usize; /// Returns the circuit's arithmetic degree. This determines the minimum length the `outp` - /// buffer passed to `call_poly`. + /// buffer passed to `eval_poly`. fn degree(&self) -> usize; /// Returns the number of times the gadget is expected to be called. @@ -737,7 +737,7 @@ impl QueryShimGadget { let step = (1 << (log2(p as u128) - log2(m as u128))) as usize; // Evaluate the gadget polynomial `p` at query randomness `r`. - let p_at_r = poly_eval(&proof_data[gadget_arity..], r); + let p_at_r = poly_eval_monomial(&proof_data[gadget_arity..], r); Ok(Self { inner, @@ -1169,7 +1169,7 @@ mod tests { } // In https://github.com/divviup/libprio-rs/issues/254 an out-of-bounds bug was reported that - // gets triggered when the size of the buffer passed to `gadget.call_poly()` is larger than + // gets triggered when the size of the buffer passed to `gadget.eval_poly()` is larger than // needed for computing the gadget polynomial. #[test] fn issue254() { diff --git a/src/flp/gadgets.rs b/src/flp/gadgets.rs index 870d2901..1b201566 100644 --- a/src/flp/gadgets.rs +++ b/src/flp/gadgets.rs @@ -7,7 +7,7 @@ use crate::field::add_vector; use crate::field::NttFriendlyFieldElement; use crate::flp::{gadget_poly_len, wire_poly_len, FlpError, Gadget}; use crate::ntt::{ntt, ntt_inv_finish}; -use crate::polynomial::{poly_deg, poly_eval, poly_mul}; +use crate::polynomial::{poly_deg, poly_eval_monomial, poly_mul_monomial}; #[cfg(feature = "multithreaded")] use rayon::prelude::*; @@ -46,18 +46,18 @@ impl Mul { } /// Multiply input polynomials directly. - pub(crate) fn call_poly_direct( + pub(crate) fn eval_poly_direct( &mut self, outp: &mut [F], inp: &[Vec], ) -> Result<(), FlpError> { - let v = poly_mul(&inp[0], &inp[1]); + let v = poly_mul_monomial(&inp[0], &inp[1]); outp[..v.len()].clone_from_slice(&v); Ok(()) } /// Multiply input polynomials using NTT. - pub(crate) fn call_poly_ntt(&mut self, outp: &mut [F], inp: &[Vec]) -> Result<(), FlpError> { + pub(crate) fn eval_poly_ntt(&mut self, outp: &mut [F], inp: &[Vec]) -> Result<(), FlpError> { let n = self.n; let mut buf = vec![F::zero(); n]; @@ -83,9 +83,9 @@ impl Gadget for Mul { fn eval_poly(&mut self, outp: &mut [F], inp: &[Vec]) -> Result<(), FlpError> { gadget_call_poly_check(self, outp, inp)?; if inp[0].len() >= NTT_THRESHOLD { - self.call_poly_ntt(outp, inp) + self.eval_poly_ntt(outp, inp) } else { - self.call_poly_direct(outp, inp) + self.eval_poly_direct(outp, inp) } } @@ -137,7 +137,7 @@ impl PolyEval { impl PolyEval { /// Multiply input polynomials directly. - fn call_poly_direct(&mut self, outp: &mut [F], inp: &[Vec]) -> Result<(), FlpError> { + fn eval_poly_direct(&mut self, outp: &mut [F], inp: &[Vec]) -> Result<(), FlpError> { outp[0] = self.poly[0]; let mut x = inp[0].to_vec(); for i in 1..self.poly.len() { @@ -146,14 +146,14 @@ impl PolyEval { } if i < self.poly.len() - 1 { - x = poly_mul(&x, &inp[0]); + x = poly_mul_monomial(&x, &inp[0]); } } Ok(()) } /// Multiply input polynomials using NTT. - fn call_poly_ntt(&mut self, outp: &mut [F], inp: &[Vec]) -> Result<(), FlpError> { + fn eval_poly_ntt(&mut self, outp: &mut [F], inp: &[Vec]) -> Result<(), FlpError> { let n = self.n; let inp = &inp[0]; @@ -186,7 +186,7 @@ impl PolyEval { impl Gadget for PolyEval { fn eval(&mut self, inp: &[F]) -> Result { gadget_call_check(self, inp.len())?; - Ok(poly_eval(&self.poly, inp[0])) + Ok(poly_eval_monomial(&self.poly, inp[0])) } fn eval_poly(&mut self, outp: &mut [F], inp: &[Vec]) -> Result<(), FlpError> { @@ -197,9 +197,9 @@ impl Gadget for PolyEval { } if inp[0].len() >= NTT_THRESHOLD { - self.call_poly_ntt(outp, inp) + self.eval_poly_ntt(outp, inp) } else { - self.call_poly_direct(outp, inp) + self.eval_poly_direct(outp, inp) } } @@ -322,7 +322,7 @@ where struct ParallelSumFoldState { /// Inner gadget. inner: G, - /// Output buffer for `call_poly()`. + /// Output buffer for `eval_poly()`. partial_output: Vec, /// Sum accumulator. partial_sum: Vec, @@ -405,7 +405,7 @@ where } } -/// Check that the input parameters of g.call() are well-formed. +/// Check that the input parameters of g.eval() are well-formed. fn gadget_call_check>( gadget: &G, in_len: usize, @@ -425,23 +425,23 @@ fn gadget_call_check>( Ok(()) } -/// Check that the input parameters of g.call_poly() are well-formed. -fn gadget_call_poly_check>( +/// Check that the input parameters of g.eval_poly() are well-formed. +fn gadget_call_poly_check, P: AsRef<[F]>>( gadget: &G, outp: &[F], - inp: &[Vec], + inp: &[P], ) -> Result<(), FlpError> { gadget_call_check(gadget, inp.len())?; for i in 1..inp.len() { - if inp[i].len() != inp[0].len() { + if inp[i].as_ref().len() != inp[0].as_ref().len() { return Err(FlpError::Gadget( "gadget called on wire polynomials with different lengths".to_string(), )); } } - let expected = gadget_poly_len(gadget.degree(), inp[0].len()).next_power_of_two(); + let expected = gadget_poly_len(gadget.degree(), inp[0].as_ref().len()).next_power_of_two(); if outp.len() != expected { return Err(FlpError::Gadget(format!( "incorrect output length: got {}; want {}", @@ -550,8 +550,8 @@ mod tests { } } - /// Test that calling g.call_poly() and evaluating the output at a given point is equivalent - /// to evaluating each of the inputs at the same point and applying g.call() on the results. + /// Test that calling g.eval_poly() and evaluating the output at a given point is equivalent + /// to evaluating each of the inputs at the same point and applying g.eval() on the results. fn gadget_test>(g: &mut G, num_calls: usize) { let wire_poly_len = (1 + num_calls).next_power_of_two(); let mut prng = Prng::new(); @@ -564,17 +564,17 @@ mod tests { for out in wire_polys[i].iter_mut().take(wire_poly_len) { *out = prng.get(); } - inp[i] = poly_eval(&wire_polys[i], r); + inp[i] = poly_eval_monomial(&wire_polys[i], r); } g.eval_poly(&mut gadget_poly, &wire_polys).unwrap(); - let got = poly_eval(&gadget_poly, r); + let got = poly_eval_monomial(&gadget_poly, r); let want = g.eval(&inp).unwrap(); assert_eq!(got, want); // Repeat the call to make sure that the gadget's memory is reset properly between calls. g.eval_poly(&mut gadget_poly, &wire_polys).unwrap(); - let got = poly_eval(&gadget_poly, r); + let got = poly_eval_monomial(&gadget_poly, r); assert_eq!(got, want); } } diff --git a/src/ntt.rs b/src/ntt.rs index b955bee0..9b66caaf 100644 --- a/src/ntt.rs +++ b/src/ntt.rs @@ -22,16 +22,53 @@ pub enum NttError { SizeInvalid, } -/// Sets `outp` to the NTT of `inp`. +/// Sets `outp` to the NTT of `inp`, converting a polynomial in the monomial basis to the Lagrange +/// basis. /// /// Interpreting the input as the coefficients of a polynomial, the output is equal to the input -/// evaluated at points `p^0, p^1, ... p^(size-1)`, where `p` is the `2^size`-th principal root of +/// evaluated at points `p^0, p^1, ... p^(size-1)`, where `p` is the `size`-th principal root of /// unity. +/// +/// This corresponds to the `Field.ntt` interface of [6.1.2][1], with `set_s = false`, and uses +/// Algorithm 4 of [Faz25][2]. +/// +/// [1]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-18#section-6.1.2 +/// [2]: https://eprint.iacr.org/2025/1727.pdf +pub(crate) fn ntt( + outp: &mut [F], + inp: &[F], + size: usize, +) -> Result<(), NttError> { + ntt_internal(outp, inp, size, false) +} + +/// Sets `outp` to the NTT of `inp`. +/// +/// Interpreting the input as the coefficients of a polynomial, the output is equal to the input +/// evaluated at points `s * p^0, s * p^1, ... s * p^(size-1)`, where `p` is the size-th principal +/// root of unity and `s` is a (2 * size)-th root of unity. +/// +/// This corresponds to the `Field.ntt` interface of [6.1.2][1], with `set_s = true` and uses +/// Algorithm 4 of [Faz25][2]. +/// +/// [1]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-18#section-6.1.2 +/// [2]: https://eprint.iacr.org/2025/1727.pdf +// TODO(#1394): make available outside of tests +#[cfg(test)] +pub(crate) fn ntt_set_s( + outp: &mut [F], + inp: &[F], + size: usize, +) -> Result<(), NttError> { + ntt_internal(outp, inp, size, true) +} + #[allow(clippy::many_single_char_names)] -pub fn ntt( +fn ntt_internal( outp: &mut [F], inp: &[F], size: usize, + set_s: bool, ) -> Result<(), NttError> { let d = usize::try_from(log2(size as u128)).map_err(|_| NttError::SizeTooLarge)?; @@ -39,7 +76,7 @@ pub fn ntt( return Err(NttError::OutputTooSmall); } - if size > 1 << MAX_ROOTS { + if (set_s && size > 1 << (MAX_ROOTS - 1)) || size > 1 << MAX_ROOTS { return Err(NttError::SizeTooLarge); } @@ -58,7 +95,13 @@ pub fn ntt( let mut w: F; for l in 1..d + 1 { - w = F::one(); + w = if set_s { + // Unwrap safety: we ensure above that size is small enough to ensure that we have all + // the roots we need. + F::root(l + 1).unwrap() + } else { + F::one() + }; let r = F::root(l).unwrap(); let y = 1 << (l - 1); let chunk = (size / y) >> 1; @@ -67,7 +110,7 @@ pub fn ntt( for j in 0..chunk { let x = j << l; let u = outp[x]; - let v = outp[x + y]; + let v = w * outp[x + y]; outp[x] = u + v; outp[x + y] = u - v; } @@ -87,7 +130,21 @@ pub fn ntt( Ok(()) } -/// Sets `outp` to the inverse of the DFT of `inp`. +/// Does the same thing as [`ntt`], but returns the output. +// TODO(#1394): make available outside of tests +#[cfg(test)] +pub(crate) fn get_ntt( + input: &[F], + size: usize, +) -> Result, NttError> { + let mut output = vec![F::zero(); size]; + ntt(&mut output, input, size)?; + + Ok(output) +} + +/// Sets `outp` to the inverse of the NTT of `inp`. +// TODO(#1394): make available outside of tests and config experimental #[cfg(any(test, all(feature = "crypto-dependencies", feature = "experimental")))] pub(crate) fn ntt_inv( outp: &mut [F], @@ -100,6 +157,19 @@ pub(crate) fn ntt_inv( Ok(()) } +/// Does the same thing as [`ntt_inv`], but returns the output. +// TODO(#1394): make available outside of tests +#[cfg(test)] +pub(crate) fn get_ntt_inv( + inp: &[F], + size: usize, +) -> Result, NttError> { + let mut output = vec![F::zero(); size]; + ntt_inv(&mut output, inp, size)?; + + Ok(output) +} + /// An intermediate step in the computation of the inverse DFT. Exposing this function allows us to /// amortize the cost the modular inverse across multiple inverse DFT operations. pub(crate) fn ntt_inv_finish(outp: &mut [F], size: usize, size_inv: F) { @@ -121,26 +191,51 @@ fn bitrev(d: usize, x: usize) -> usize { #[cfg(test)] mod tests { use super::*; - use crate::field::{ - split_vector, Field128, Field64, FieldElement, FieldElementWithInteger, FieldPrio2, + use crate::{ + field::{ + split_vector, Field128, Field64, FieldElement, FieldElementWithInteger, FieldPrio2, + }, + polynomial::poly_eval_monomial, }; fn ntt_then_inv_test() -> Result<(), NttError> { let test_sizes = [1, 2, 4, 8, 16, 256, 1024, 2048]; for size in test_sizes.iter() { - let mut tmp = vec![F::zero(); *size]; - let mut got = vec![F::zero(); *size]; let want = F::random_vector(*size); - ntt(&mut tmp, &want, want.len())?; - ntt_inv(&mut got, &tmp, tmp.len())?; + let tmp = get_ntt(&want, *size)?; + let got = get_ntt_inv(&tmp, tmp.len())?; assert_eq!(got, want); } Ok(()) } + fn test_ntt_set_s() { + for log_n in 0..8 { + let n = 1 << log_n; + let nth_root = F::root(log_n).unwrap(); + let two_nth_root = F::root(log_n + 1).unwrap(); + + // Random polynomial in monomial basis + let p_monomial = F::random_vector(n); + + // Evaluate the polynomial at the powers of an n-th root of unity multiplied by a 2n-th + // root of unity + let monomial_evaluations = (0..n) + .map(|power| F::Integer::try_from(power).unwrap()) + .map(|power| poly_eval_monomial(&p_monomial, two_nth_root * nth_root.pow(power))) + .collect::>(); + + let mut ntt = vec![F::zero(); n]; + ntt_set_s(&mut ntt, &p_monomial, n).unwrap(); + + // Monomial evaluations should match NTT with set_s = true + assert_eq!(monomial_evaluations, ntt); + } + } + #[test] fn test_priov2_field32() { ntt_then_inv_test::().expect("unexpected error"); @@ -149,11 +244,13 @@ mod tests { #[test] fn test_field64() { ntt_then_inv_test::().expect("unexpected error"); + test_ntt_set_s::(); } #[test] fn test_field128() { ntt_then_inv_test::().expect("unexpected error"); + test_ntt_set_s::(); } // This test demonstrates a consequence of \[BBG+19, Fact 4.4\]: interpolating a polynomial @@ -187,8 +284,7 @@ mod tests { } } - let mut want = vec![Field64::zero(); len]; - ntt_inv(&mut want, &x, len).unwrap(); + let want = get_ntt_inv(&x, len).unwrap(); assert_eq!(got, want); } @@ -197,8 +293,7 @@ mod tests { fn test_ntt_interpolation() { let count = 128; let points = Field128::random_vector(count); - let mut poly = vec![Field128::zero(); count]; - ntt(&mut poly, &points, count).unwrap(); + let poly = get_ntt(&points, count).unwrap(); let principal_root = Field128::root(7).unwrap(); // log_2(128); for (power, poly_coeff) in poly.iter().enumerate() { let expected = points diff --git a/src/polynomial.rs b/src/polynomial.rs index 8146c410..fba3fda1 100644 --- a/src/polynomial.rs +++ b/src/polynomial.rs @@ -3,8 +3,8 @@ //! Functions for polynomial interpolation and evaluation -#[cfg(all(feature = "crypto-dependencies", feature = "experimental"))] -use crate::ntt::{ntt, ntt_inv_finish}; +#[cfg(test)] +use crate::ntt::{ntt_inv, ntt_set_s, NttError}; use crate::{ field::{FieldElement, NttFriendlyFieldElement}, fp::log2, @@ -12,8 +12,8 @@ use crate::{ use std::convert::TryFrom; -/// Evaluate a polynomial using Horner's method. -pub fn poly_eval(poly: &[F], eval_at: F) -> F { +/// Evaluate a polynomial in the monomial basis using Horner's method. +pub fn poly_eval_monomial(poly: &[F], eval_at: F) -> F { if poly.is_empty() { return F::zero(); } @@ -36,8 +36,8 @@ pub fn poly_deg(p: &[F]) -> usize { d.saturating_sub(1) } -/// Multiplies polynomials `p` and `q` and returns the result. -pub fn poly_mul(p: &[F], q: &[F]) -> Vec { +/// Multiplies polynomials `p` and `q`, given in the monomial basis. +pub fn poly_mul_monomial(p: &[F], q: &[F]) -> Vec { let p_size = poly_deg(p) + 1; let q_size = poly_deg(q) + 1; let mut out = vec![F::zero(); p_size + q_size]; @@ -50,17 +50,21 @@ pub fn poly_mul(p: &[F], q: &[F]) -> Vec { out } -#[cfg(all(feature = "crypto-dependencies", feature = "experimental"))] +/// Interpolate a polynomial from the provided points and evaluate it, using `tmp_coeffs` as scratch +/// space for the NTT. +#[cfg(any(test, all(feature = "crypto-dependencies", feature = "experimental")))] #[inline] pub fn poly_interpret_eval( points: &[F], eval_at: F, tmp_coeffs: &mut [F], ) -> F { + use crate::ntt::{ntt, ntt_inv_finish}; + let size_inv = F::from(F::Integer::try_from(points.len()).unwrap()).inv(); ntt(tmp_coeffs, points, points.len()).unwrap(); ntt_inv_finish(tmp_coeffs, points.len(), size_inv); - poly_eval(&tmp_coeffs[..points.len()], eval_at) + poly_eval_monomial(&tmp_coeffs[..points.len()], eval_at) } /// Returns the element `1/n` on `F`, where `n` must be a power of two. @@ -77,19 +81,62 @@ fn inv_pow2(n: usize) -> F { x } -/// Evaluates multiple polynomials given in the Lagrange basis. +/// Multiplies polynomials `p` and `q`, given in the Lagrange basis. The polynomials must have the +/// same length, which must be a power of 2. For input polynomials of length `k` and degree `k - 1`, +/// the output polynomial will have length `2k` and degree `2k - 2`, containing one excess +/// coordinate. This is necessary for compatibility with NTT algorithms. /// -/// This is Algorithm 7 of rhizomes paper. -/// . -pub(crate) fn poly_eval_batched( - polynomials: &[Vec], +/// Implements `Lagrange.poly_mul` of [6.1.3.2][2], using the polynomial multiplication technique of +/// [Faz25 section 3.3][1]. +/// +/// [1]: https://eprint.iacr.org/2025/1727 +/// [2]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-18#section-6.1.3.2 +// TODO(#1394): make available outside of tests +#[cfg(test)] +pub(crate) fn poly_mul_lagrange( + p: &[F], + q: &[F], +) -> Result, NttError> { + assert_eq!(p.len(), q.len()); + assert!(p.len().is_power_of_two()); + + let mut p_doubled = double_evaluations(p)?; + + for (p_element, q_element) in p_doubled.iter_mut().zip(double_evaluations(q)?) { + *p_element *= q_element; + } + + Ok(p_doubled) +} + +/// Evaluates multiple polynomials given in the Lagrange basis. All the polynomials must have the +/// same length `n`, and `roots` must contain the same number of powers of the primitive `n`-th root +/// of unity (which you can compute using [`nth_root_powers`]). +/// +/// Implements `Lagrange.poly_eval_batched` of [6.1.3.2][2], using algorithm 7 from [Faz25][1]. +/// `Lagrange.poly_eval` can be realized by passing a slice containing a single polynomial. +/// +/// [1]: https://eprint.iacr.org/2025/1727 +/// [2]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-18#section-6.1.3.2 +pub(crate) fn poly_eval_lagrange_batched>( + polynomials: &[P], roots: &[F], x: F, ) -> Vec { + let poly_len = polynomials[0].as_ref().len(); + assert!( + polynomials.iter().all(|p| p.as_ref().len() == poly_len), + "polynomials must be of equal length" + ); + // TODO(#1394): enable these assertions once `Flp.query` is fixed to match + // https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-18#section-7.3.4 + // assert!(polynomials[0].as_ref().len().is_power_of_two()); + // assert_eq!(roots.len(), poly_len, "incorrect number of roots provided"); + let mut l = F::one(); let mut u: Vec = polynomials .iter() - .map(|p| p.first().copied().unwrap_or_else(F::zero)) + .map(|p| p.as_ref().first().copied().unwrap_or_else(F::zero)) .collect(); let mut d = roots[0] - x; for (i, wn_i) in (1..).zip(&roots[1..]) { @@ -98,7 +145,7 @@ pub(crate) fn poly_eval_batched( let t = l * *wn_i; for (u_j, poly) in u.iter_mut().zip(polynomials) { *u_j *= d; - if let Some(yi) = poly.get(i) { + if let Some(yi) = poly.as_ref().get(i) { *u_j += t * *yi; } } @@ -154,26 +201,106 @@ pub(crate) fn nth_root_powers(n: usize) -> Vec { roots } -/// Returns a polynomial that evaluates to `0` if the input is in range `[start, end)`. Otherwise, -/// the output is not `0`. +/// Appends evaluations of the polynomial to the provided slice until it is full. The length of the +/// slice must be a power of 2. The slice must contain `num_values` evaluations of the polynomial. +/// The remaining values in the slice are overwritten. +/// +/// Corresponds to `extend_values_to_power_of_2` of [6.1.3.2][1]. +/// +/// [1]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-18#section-6.1.3.2 +// TODO(#1394): make available outside of tests +#[cfg(test)] +pub(crate) fn extend_values_to_power_of_2( + polynomial: &mut [F], + num_values: usize, +) { + let desired_values = polynomial.len(); + assert!(desired_values.is_power_of_two()); + assert!(num_values <= polynomial.len()); + + let root_powers: Vec = nth_root_powers(desired_values); + + let mut w = vec![F::zero(); desired_values]; + for i in 0..num_values { + w[i] = (0..num_values) + .filter(|j| i != *j) + .fold(F::one(), |acc, j| acc * (root_powers[i] - root_powers[j])); + } + + for k in num_values..desired_values { + for i in 0..k { + w[i] *= root_powers[i] - root_powers[k]; + } + + let mut y_numerator = F::zero(); + let mut y_denominator = F::one(); + for (i, value) in polynomial[..k].iter().enumerate() { + y_numerator = y_numerator * w[i] + y_denominator * *value; + y_denominator *= w[i]; + } + + w[k] = (0..k).fold(F::one(), |acc, j| acc * (root_powers[k] - root_powers[j])); + polynomial[k] = -w[k] * y_numerator * y_denominator.inv(); + } +} + +/// Compute `2n` evaluations of the polynomial interpolated from `evaluations`, which consists of +/// `n` Lagrange basis evaluations. `n` must be a power of 2. +/// +/// Corresponds to `double_evaluations` of [6.1.3.2][1]. +/// +/// [1]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-18#section-6.1.3.2 +// TODO(#1394): make available outside of tests +#[cfg(test)] +pub(crate) fn double_evaluations( + evaluations: &[F], +) -> Result, NttError> { + assert!(evaluations.len().is_power_of_two()); + let mut output = vec![F::zero(); evaluations.len() * 2]; + + // Do inverse NTT into the front half of output, then forward NTT into the back half to get odd + // indices of output + let (front, back) = output.split_at_mut(evaluations.len()); + ntt_inv(front, evaluations, evaluations.len())?; + ntt_set_s(back, front, evaluations.len())?; + + // Interleave the input (even indices) with the back half of output (odd indices), into output. + // This is safe to do because any element of pre-overwrite output can only contribute to a + // smaller index post-overwrite, and thus overwriting doesn't destroy any information we need. + for output_position in 0..output.len() { + output[output_position] = if output_position % 2 == 0 { + evaluations[output_position / 2] + } else { + output[evaluations.len() + output_position / 2] + }; + } + + Ok(output) +} + +/// Returns a polynomial in the monomial basis that evaluates to `0` if the input is in range +/// `[start, end)`. Otherwise, the output is not `0`. pub(crate) fn poly_range_check(start: usize, end: usize) -> Vec { let mut p = vec![F::one()]; let mut q = [F::zero(), F::one()]; for i in start..end { q[0] = -F::from(F::Integer::try_from(i).unwrap()); - p = poly_mul(&p, &q); + p = poly_mul_monomial(&p, &q); } p } #[cfg(test)] mod tests { - #[cfg(all(feature = "crypto-dependencies", feature = "experimental"))] - use crate::polynomial::{poly_eval_batched, poly_interpret_eval}; use crate::{ field::{Field64, FieldElement, FieldPrio2, NttFriendlyFieldElement}, fp::log2, - polynomial::{nth_root_powers, poly_deg, poly_eval, poly_mul, poly_range_check}, + ntt::get_ntt, + polynomial::{ + double_evaluations, extend_values_to_power_of_2, nth_root_powers, poly_deg, + poly_eval_lagrange_batched, poly_eval_monomial, poly_interpret_eval, poly_mul_lagrange, + poly_mul_monomial, poly_range_check, + }, }; use std::convert::TryFrom; @@ -184,10 +311,10 @@ mod tests { poly[1] = 1.into(); poly[2] = 5.into(); // 5*3^2 + 3 + 2 = 50 - assert_eq!(poly_eval(&poly[..3], 3.into()), 50); + assert_eq!(poly_eval_monomial(&poly[..3], 3.into()), 50); poly[3] = 4.into(); // 4*3^3 + 5*3^2 + 3 + 2 = 158 - assert_eq!(poly_eval(&poly[..4], 3.into()), 158); + assert_eq!(poly_eval_monomial(&poly[..4], 3.into()), 158); } #[test] @@ -224,7 +351,7 @@ mod tests { Field64::from(u64::try_from(15).unwrap()), ]; - let got = poly_mul(&p, &q); + let got = poly_mul_monomial(&p, &q); assert_eq!(&got, &want); } @@ -237,18 +364,18 @@ mod tests { // Check each number in the range. for i in start..end { let x = Field64::from(i as u64); - let y = poly_eval(&p, x); + let y = poly_eval_monomial(&p, x); assert_eq!(y, Field64::zero(), "range check failed for {i}"); } // Check the number below the range. let x = Field64::from((start - 1) as u64); - let y = poly_eval(&p, x); + let y = poly_eval_monomial(&p, x); assert_ne!(y, Field64::zero()); // Check a number above the range. let x = Field64::from(end as u64); - let y = poly_eval(&p, x); + let y = poly_eval_monomial(&p, x); assert_ne!(y, Field64::zero()); } @@ -279,21 +406,17 @@ mod tests { } } - #[cfg(all(feature = "crypto-dependencies", feature = "experimental"))] #[test] - fn test_poly_eval_batched() { + fn test_poly_eval_lagrange_batched() { // Single polynomial with constant terms - test_poly_eval_batched_with_lengths(&[1]); + test_poly_eval_lagrange_batched_with_lengths(&[1]); // Constant terms - test_poly_eval_batched_with_lengths(&[1, 1]); + test_poly_eval_lagrange_batched_with_lengths(&[1, 1]); // Powers of two - test_poly_eval_batched_with_lengths(&[1, 2, 4, 16, 64]); - // arbitrary - test_poly_eval_batched_with_lengths(&[1, 6, 3, 9]); + test_poly_eval_lagrange_batched_with_lengths(&[64, 64, 64, 64, 64]); } - #[cfg(all(feature = "crypto-dependencies", feature = "experimental"))] - fn test_poly_eval_batched_with_lengths(lengths: &[usize]) { + fn test_poly_eval_lagrange_batched_with_lengths(lengths: &[usize]) { let sizes = lengths .iter() .map(|s| s.next_power_of_two()) @@ -318,8 +441,69 @@ mod tests { }) .collect::>(); - // Simultaneouly evaluates several polynomials directly in the Lagrange basis (batched). - let got = poly_eval_batched(&polynomials, &roots, x); + // Simultaneously evaluates several polynomials directly in the Lagrange basis (batched). + let got = poly_eval_lagrange_batched(&polynomials, &roots, x); assert_eq!(got, want, "sizes: {sizes:?} x: {x} P: {polynomials:?}"); } + + #[test] + fn test_poly_mul_lagrange() { + for log_n in 0..8 { + let n = 1 << log_n; + + let p_monomial = Field64::random_vector(n); + let q_monomial = Field64::random_vector(n); + + let p_lagrange = get_ntt(&p_monomial, n).unwrap(); + let q_lagrange = get_ntt(&q_monomial, n).unwrap(); + + let product_lagrange = poly_mul_lagrange(&p_lagrange, &q_lagrange).unwrap(); + let product_monomial = poly_mul_monomial(&p_monomial, &q_monomial); + let product_monomial_ntt = get_ntt(&product_monomial, 2 * n).unwrap(); + assert_eq!(product_lagrange, product_monomial_ntt); + } + } + + #[test] + fn test_extend_values_to_power_of_2() { + for log_n in 0..7 { + let n = 1 << log_n; + for k in 0..n + 1 { + // Random monomial polynomial of degree k - 1 + let mut p_monomial = Field64::random_vector(k); + p_monomial.extend_from_slice(&vec![Field64::zero(); n - k]); + + // Convert to Lagrange basis + let p_lagrange = get_ntt(&p_monomial, n).unwrap(); + + // Truncate to k values + let mut p_lagrange_truncated = p_lagrange.clone(); + for element in p_lagrange_truncated.iter_mut().skip(k) { + *element = Field64::zero(); + } + + // Recover the n Lagrange basis values + extend_values_to_power_of_2(&mut p_lagrange_truncated, k); + + assert_eq!(p_lagrange_truncated, p_lagrange, "log_n = {log_n} k = {k}"); + } + } + } + + #[test] + fn test_double_evaluations() { + for log_n in 0..8 { + let n = 1 << log_n; + // Random monomial polynomial + let p_monomial = Field64::random_vector(n); + + // Convert to Lagrange basis + let p_lagrange = get_ntt(&p_monomial, n).unwrap(); + + assert_eq!( + double_evaluations(&p_lagrange).unwrap(), + get_ntt(&p_monomial, 2 * n).unwrap() + ); + } + } }