Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/flp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ use crate::dp::DifferentialPrivacyStrategy;
use crate::field::{FieldElement, FieldElementWithInteger, FieldError, NttFriendlyFieldElement};
use crate::fp::log2;
use crate::ntt::{ntt, ntt_inv_finish, NttError};
use crate::polynomial::{nth_root_powers, poly_eval, poly_eval_batched};
use crate::polynomial::{nth_root_powers, poly_eval_lagrange_batched, poly_eval_monomial};
use std::any::Any;
use std::convert::TryFrom;
use std::fmt::Debug;
Expand Down Expand Up @@ -471,7 +471,7 @@ pub trait Flp: Sized + Eq + Clone + Debug {
// This avoids using NTTs to convert them to the monomial basis.
let roots = nth_root_powers(m);
let polynomials = &gadget.f_vals[..gadget.arity()];
let mut evals = poly_eval_batched(polynomials, &roots, *query_rand_val);
let mut evals = poly_eval_lagrange_batched(polynomials, &roots, *query_rand_val);
verifier.append(&mut evals);

// Add the value of the gadget polynomial evaluated at the query randomness value.
Expand Down Expand Up @@ -737,7 +737,7 @@ impl<F: NttFriendlyFieldElement> QueryShimGadget<F> {
let step = (1 << (log2(p as u128) - log2(m as u128))) as usize;

// Evaluate the gadget polynomial `p` at query randomness `r`.
let p_at_r = poly_eval(&proof_data[gadget_arity..], r);
let p_at_r = poly_eval_monomial(&proof_data[gadget_arity..], r);

Ok(Self {
inner,
Expand Down
22 changes: 11 additions & 11 deletions src/flp/gadgets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::field::add_vector;
use crate::field::NttFriendlyFieldElement;
use crate::flp::{gadget_poly_len, wire_poly_len, FlpError, Gadget};
use crate::ntt::{ntt, ntt_inv_finish};
use crate::polynomial::{poly_deg, poly_eval, poly_mul};
use crate::polynomial::{poly_deg, poly_eval_monomial, poly_mul_monomial};

#[cfg(feature = "multithreaded")]
use rayon::prelude::*;
Expand Down Expand Up @@ -51,7 +51,7 @@ impl<F: NttFriendlyFieldElement> Mul<F> {
outp: &mut [F],
inp: &[Vec<F>],
) -> Result<(), FlpError> {
let v = poly_mul(&inp[0], &inp[1]);
let v = poly_mul_monomial(&inp[0], &inp[1]);
outp[..v.len()].clone_from_slice(&v);
Ok(())
}
Expand Down Expand Up @@ -146,7 +146,7 @@ impl<F: NttFriendlyFieldElement> PolyEval<F> {
}

if i < self.poly.len() - 1 {
x = poly_mul(&x, &inp[0]);
x = poly_mul_monomial(&x, &inp[0]);
}
}
Ok(())
Expand Down Expand Up @@ -186,7 +186,7 @@ impl<F: NttFriendlyFieldElement> PolyEval<F> {
impl<F: NttFriendlyFieldElement> Gadget<F> for PolyEval<F> {
fn eval(&mut self, inp: &[F]) -> Result<F, FlpError> {
gadget_call_check(self, inp.len())?;
Ok(poly_eval(&self.poly, inp[0]))
Ok(poly_eval_monomial(&self.poly, inp[0]))
}

fn eval_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
Expand Down Expand Up @@ -426,22 +426,22 @@ fn gadget_call_check<F: NttFriendlyFieldElement, G: Gadget<F>>(
}

/// Check that the input parameters of g.call_poly() are well-formed.
fn gadget_call_poly_check<F: NttFriendlyFieldElement, G: Gadget<F>>(
fn gadget_call_poly_check<F: NttFriendlyFieldElement, G: Gadget<F>, P: AsRef<[F]>>(
gadget: &G,
outp: &[F],
inp: &[Vec<F>],
inp: &[P],
) -> Result<(), FlpError> {
gadget_call_check(gadget, inp.len())?;

for i in 1..inp.len() {
if inp[i].len() != inp[0].len() {
if inp[i].as_ref().len() != inp[0].as_ref().len() {
return Err(FlpError::Gadget(
"gadget called on wire polynomials with different lengths".to_string(),
));
}
}

let expected = gadget_poly_len(gadget.degree(), inp[0].len()).next_power_of_two();
let expected = gadget_poly_len(gadget.degree(), inp[0].as_ref().len()).next_power_of_two();
if outp.len() != expected {
return Err(FlpError::Gadget(format!(
"incorrect output length: got {}; want {}",
Expand Down Expand Up @@ -564,17 +564,17 @@ mod tests {
for out in wire_polys[i].iter_mut().take(wire_poly_len) {
*out = prng.get();
}
inp[i] = poly_eval(&wire_polys[i], r);
inp[i] = poly_eval_monomial(&wire_polys[i], r);
}

g.eval_poly(&mut gadget_poly, &wire_polys).unwrap();
let got = poly_eval(&gadget_poly, r);
let got = poly_eval_monomial(&gadget_poly, r);
let want = g.eval(&inp).unwrap();
assert_eq!(got, want);

// Repeat the call to make sure that the gadget's memory is reset properly between calls.
g.eval_poly(&mut gadget_poly, &wire_polys).unwrap();
let got = poly_eval(&gadget_poly, r);
let got = poly_eval_monomial(&gadget_poly, r);
assert_eq!(got, want);
}
}
123 changes: 108 additions & 15 deletions src/ntt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,53 @@ pub enum NttError {
SizeInvalid,
}

/// Sets `outp` to the NTT of `inp`.
/// Sets `outp` to the NTT of `inp`, converting a polynomial in the monomial basis to the Lagrange
/// basis.
///
/// Interpreting the input as the coefficients of a polynomial, the output is equal to the input
/// evaluated at points `p^0, p^1, ... p^(size-1)`, where `p` is the `2^size`-th principal root of
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This talks about the 2^size-th ... root, but size is already a power of two (e.g., 128), so would this be just the size-th root, like in the next comment down on line 48?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you're right. I also updated the comment on ntt_set_s to say p is a size-th root and s is a 2size-th root.

/// unity.
///
/// This corresponds to the `Field.ntt` interface of [6.1.2][1], with `set_s = false`, and uses
/// Algorithm 4 of [Faz25][2].
///
/// [1]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-18#section-6.1.2
/// [2]: https://eprint.iacr.org/2025/1727.pdf
pub(crate) fn ntt<F: NttFriendlyFieldElement>(
outp: &mut [F],
inp: &[F],
size: usize,
) -> Result<(), NttError> {
ntt_internal(outp, inp, size, false)
}

/// Sets `outp` to the NTT of `inp`.
///
/// Interpreting the input as the coefficients of a polynomial, the output is equal to the input
/// evaluated at points `s * p^0, s * p^1, ... s * p^(size-1)`, where `p` is the n-th principal
/// root of unity and `s` is a 2n-th root of unity.
///
/// This corresponds to the `Field.ntt` interface of [6.1.2][1], with `set_s = true` and uses
/// Algorithm 4 of [Faz25][2].
///
/// [1]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-18#section-6.1.2
/// [2]: https://eprint.iacr.org/2025/1727.pdf
// TODO(#1394): make available outside of tests
#[cfg(test)]
pub(crate) fn ntt_set_s<F: NttFriendlyFieldElement>(
outp: &mut [F],
inp: &[F],
size: usize,
) -> Result<(), NttError> {
ntt_internal(outp, inp, size, true)
}

#[allow(clippy::many_single_char_names)]
pub fn ntt<F: NttFriendlyFieldElement>(
fn ntt_internal<F: NttFriendlyFieldElement>(
outp: &mut [F],
inp: &[F],
size: usize,
set_s: bool,
) -> Result<(), NttError> {
let d = usize::try_from(log2(size as u128)).map_err(|_| NttError::SizeTooLarge)?;

Expand All @@ -58,7 +95,11 @@ pub fn ntt<F: NttFriendlyFieldElement>(

let mut w: F;
for l in 1..d + 1 {
w = F::one();
w = if set_s {
F::root(l + 1).unwrap()
} else {
F::one()
};
let r = F::root(l).unwrap();
let y = 1 << (l - 1);
let chunk = (size / y) >> 1;
Expand All @@ -67,7 +108,7 @@ pub fn ntt<F: NttFriendlyFieldElement>(
for j in 0..chunk {
let x = j << l;
let u = outp[x];
let v = outp[x + y];
let v = w * outp[x + y];
outp[x] = u + v;
outp[x + y] = u - v;
}
Expand All @@ -87,7 +128,21 @@ pub fn ntt<F: NttFriendlyFieldElement>(
Ok(())
}

/// Sets `outp` to the inverse of the DFT of `inp`.
/// Does the same thing as [`ntt`], but returns the output.
// TODO(#1394): make available outside of tests
#[cfg(test)]
pub(crate) fn get_ntt<F: NttFriendlyFieldElement>(
input: &[F],
size: usize,
) -> Result<Vec<F>, NttError> {
let mut output = vec![F::zero(); size];
ntt(&mut output, input, size)?;

Ok(output)
}

/// Sets `outp` to the inverse of the NTT of `inp`.
// TODO(#1394): make available outside of tests and config experimental
#[cfg(any(test, all(feature = "crypto-dependencies", feature = "experimental")))]
pub(crate) fn ntt_inv<F: NttFriendlyFieldElement>(
outp: &mut [F],
Expand All @@ -100,6 +155,19 @@ pub(crate) fn ntt_inv<F: NttFriendlyFieldElement>(
Ok(())
}

/// Does the same thing as [`ntt_inv`], but returns the output.
// TODO(#1394): make available outside of tests
#[cfg(test)]
pub(crate) fn get_ntt_inv<F: NttFriendlyFieldElement>(
inp: &[F],
size: usize,
) -> Result<Vec<F>, NttError> {
let mut output = vec![F::zero(); size];
ntt_inv(&mut output, inp, size)?;

Ok(output)
}

/// An intermediate step in the computation of the inverse DFT. Exposing this function allows us to
/// amortize the cost the modular inverse across multiple inverse DFT operations.
pub(crate) fn ntt_inv_finish<F: NttFriendlyFieldElement>(outp: &mut [F], size: usize, size_inv: F) {
Expand All @@ -121,26 +189,51 @@ fn bitrev(d: usize, x: usize) -> usize {
#[cfg(test)]
mod tests {
use super::*;
use crate::field::{
split_vector, Field128, Field64, FieldElement, FieldElementWithInteger, FieldPrio2,
use crate::{
field::{
split_vector, Field128, Field64, FieldElement, FieldElementWithInteger, FieldPrio2,
},
polynomial::poly_eval_monomial,
};

fn ntt_then_inv_test<F: NttFriendlyFieldElement>() -> Result<(), NttError> {
let test_sizes = [1, 2, 4, 8, 16, 256, 1024, 2048];

for size in test_sizes.iter() {
let mut tmp = vec![F::zero(); *size];
let mut got = vec![F::zero(); *size];
let want = F::random_vector(*size);

ntt(&mut tmp, &want, want.len())?;
ntt_inv(&mut got, &tmp, tmp.len())?;
let tmp = get_ntt(&want, *size)?;
let got = get_ntt_inv(&tmp, tmp.len())?;
assert_eq!(got, want);
}

Ok(())
}

fn test_ntt_set_s<F: NttFriendlyFieldElement>() {
for log_n in 0..8 {
let n = 1 << log_n;
let nth_root = F::root(log_n).unwrap();
let two_nth_root = F::root(log_n + 1).unwrap();

// Random polynomial in monomial basis
let p_monomial = F::random_vector(n);

// Evaluate the polynomial at the powers of an n-th root of unity multiplied by a 2n-th
// root of unity
let monomial_evaluations = (0..n)
.map(|power| F::Integer::try_from(power).unwrap())
.map(|power| poly_eval_monomial(&p_monomial, two_nth_root * nth_root.pow(power)))
.collect::<Vec<_>>();

let mut ntt = vec![F::zero(); n];
ntt_set_s(&mut ntt, &p_monomial, n).unwrap();

// Monomial evaluations should match NTT with set_s = true
assert_eq!(monomial_evaluations, ntt);
}
}

#[test]
fn test_priov2_field32() {
ntt_then_inv_test::<FieldPrio2>().expect("unexpected error");
Expand All @@ -149,11 +242,13 @@ mod tests {
#[test]
fn test_field64() {
ntt_then_inv_test::<Field64>().expect("unexpected error");
test_ntt_set_s::<Field64>();
}

#[test]
fn test_field128() {
ntt_then_inv_test::<Field128>().expect("unexpected error");
test_ntt_set_s::<Field128>();
}

// This test demonstrates a consequence of \[BBG+19, Fact 4.4\]: interpolating a polynomial
Expand Down Expand Up @@ -187,8 +282,7 @@ mod tests {
}
}

let mut want = vec![Field64::zero(); len];
ntt_inv(&mut want, &x, len).unwrap();
let want = get_ntt_inv(&x, len).unwrap();

assert_eq!(got, want);
}
Expand All @@ -197,8 +291,7 @@ mod tests {
fn test_ntt_interpolation() {
let count = 128;
let points = Field128::random_vector(count);
let mut poly = vec![Field128::zero(); count];
ntt(&mut poly, &points, count).unwrap();
let poly = get_ntt(&points, count).unwrap();
let principal_root = Field128::root(7).unwrap(); // log_2(128);
for (power, poly_coeff) in poly.iter().enumerate() {
let expected = points
Expand Down
Loading
Loading