Skip to content

Commit 02ce5db

Browse files
authored
Polynomial and NTT interfaces for Rhizomes (#1399)
Some preliminary work ahead of #1394. - `mod polynomial`: rename functions to make it clear which basis inputs and outputs are in - add `polynomial::poly_mul_lagrange, extend_values_to_power_of_2`, `double_evaluations` - `test_poly_eval_lagrange_batched` now requires all input polynomials to have equal length, per spec. Further assertions will follow in #1394. - add `ntt::ntt_set_s` and various convenience function to `mod ntt` Part of #1394
1 parent f7cb817 commit 02ce5db

File tree

5 files changed

+367
-88
lines changed

5 files changed

+367
-88
lines changed

src/benchmarked.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ pub fn benchmarked_gadget_mul_call_poly_ntt<F: NttFriendlyFieldElement>(
1616
outp: &mut [F],
1717
inp: &[Vec<F>],
1818
) -> Result<(), FlpError> {
19-
g.call_poly_ntt(outp, inp)
19+
g.eval_poly_ntt(outp, inp)
2020
}
2121

2222
/// Sets `outp` to `inp[0] * inp[1]`, where `inp[0]` and `inp[1]` are polynomials. This function
@@ -26,5 +26,5 @@ pub fn benchmarked_gadget_mul_call_poly_direct<F: NttFriendlyFieldElement>(
2626
outp: &mut [F],
2727
inp: &[Vec<F>],
2828
) -> Result<(), FlpError> {
29-
g.call_poly_direct(outp, inp)
29+
g.eval_poly_direct(outp, inp)
3030
}

src/flp.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ use crate::dp::DifferentialPrivacyStrategy;
5151
use crate::field::{FieldElement, FieldElementWithInteger, FieldError, NttFriendlyFieldElement};
5252
use crate::fp::log2;
5353
use crate::ntt::{ntt, ntt_inv_finish, NttError};
54-
use crate::polynomial::{nth_root_powers, poly_eval, poly_eval_batched};
54+
use crate::polynomial::{nth_root_powers, poly_eval_lagrange_batched, poly_eval_monomial};
5555
use std::any::Any;
5656
use std::convert::TryFrom;
5757
use std::fmt::Debug;
@@ -471,7 +471,7 @@ pub trait Flp: Sized + Eq + Clone + Debug {
471471
// This avoids using NTTs to convert them to the monomial basis.
472472
let roots = nth_root_powers(m);
473473
let polynomials = &gadget.f_vals[..gadget.arity()];
474-
let mut evals = poly_eval_batched(polynomials, &roots, *query_rand_val);
474+
let mut evals = poly_eval_lagrange_batched(polynomials, &roots, *query_rand_val);
475475
verifier.append(&mut evals);
476476

477477
// Add the value of the gadget polynomial evaluated at the query randomness value.
@@ -613,12 +613,12 @@ pub trait Gadget<F: NttFriendlyFieldElement>: Debug {
613613
/// Evaluate the gadget on input of a sequence of polynomials. The output is written to `outp`.
614614
fn eval_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError>;
615615

616-
/// Returns the arity of the gadget. This is the length of `inp` passed to `call` or
617-
/// `call_poly`.
616+
/// Returns the arity of the gadget. This is the length of `inp` passed to `eval` or
617+
/// `eval_poly`.
618618
fn arity(&self) -> usize;
619619

620620
/// Returns the circuit's arithmetic degree. This determines the minimum length the `outp`
621-
/// buffer passed to `call_poly`.
621+
/// buffer passed to `eval_poly`.
622622
fn degree(&self) -> usize;
623623

624624
/// Returns the number of times the gadget is expected to be called.
@@ -737,7 +737,7 @@ impl<F: NttFriendlyFieldElement> QueryShimGadget<F> {
737737
let step = (1 << (log2(p as u128) - log2(m as u128))) as usize;
738738

739739
// Evaluate the gadget polynomial `p` at query randomness `r`.
740-
let p_at_r = poly_eval(&proof_data[gadget_arity..], r);
740+
let p_at_r = poly_eval_monomial(&proof_data[gadget_arity..], r);
741741

742742
Ok(Self {
743743
inner,
@@ -1169,7 +1169,7 @@ mod tests {
11691169
}
11701170

11711171
// In https://github.com/divviup/libprio-rs/issues/254 an out-of-bounds bug was reported that
1172-
// gets triggered when the size of the buffer passed to `gadget.call_poly()` is larger than
1172+
// gets triggered when the size of the buffer passed to `gadget.eval_poly()` is larger than
11731173
// needed for computing the gadget polynomial.
11741174
#[test]
11751175
fn issue254() {

src/flp/gadgets.rs

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use crate::field::add_vector;
77
use crate::field::NttFriendlyFieldElement;
88
use crate::flp::{gadget_poly_len, wire_poly_len, FlpError, Gadget};
99
use crate::ntt::{ntt, ntt_inv_finish};
10-
use crate::polynomial::{poly_deg, poly_eval, poly_mul};
10+
use crate::polynomial::{poly_deg, poly_eval_monomial, poly_mul_monomial};
1111

1212
#[cfg(feature = "multithreaded")]
1313
use rayon::prelude::*;
@@ -46,18 +46,18 @@ impl<F: NttFriendlyFieldElement> Mul<F> {
4646
}
4747

4848
/// Multiply input polynomials directly.
49-
pub(crate) fn call_poly_direct(
49+
pub(crate) fn eval_poly_direct(
5050
&mut self,
5151
outp: &mut [F],
5252
inp: &[Vec<F>],
5353
) -> Result<(), FlpError> {
54-
let v = poly_mul(&inp[0], &inp[1]);
54+
let v = poly_mul_monomial(&inp[0], &inp[1]);
5555
outp[..v.len()].clone_from_slice(&v);
5656
Ok(())
5757
}
5858

5959
/// Multiply input polynomials using NTT.
60-
pub(crate) fn call_poly_ntt(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
60+
pub(crate) fn eval_poly_ntt(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
6161
let n = self.n;
6262
let mut buf = vec![F::zero(); n];
6363

@@ -83,9 +83,9 @@ impl<F: NttFriendlyFieldElement> Gadget<F> for Mul<F> {
8383
fn eval_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
8484
gadget_call_poly_check(self, outp, inp)?;
8585
if inp[0].len() >= NTT_THRESHOLD {
86-
self.call_poly_ntt(outp, inp)
86+
self.eval_poly_ntt(outp, inp)
8787
} else {
88-
self.call_poly_direct(outp, inp)
88+
self.eval_poly_direct(outp, inp)
8989
}
9090
}
9191

@@ -137,7 +137,7 @@ impl<F: NttFriendlyFieldElement> PolyEval<F> {
137137

138138
impl<F: NttFriendlyFieldElement> PolyEval<F> {
139139
/// Multiply input polynomials directly.
140-
fn call_poly_direct(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
140+
fn eval_poly_direct(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
141141
outp[0] = self.poly[0];
142142
let mut x = inp[0].to_vec();
143143
for i in 1..self.poly.len() {
@@ -146,14 +146,14 @@ impl<F: NttFriendlyFieldElement> PolyEval<F> {
146146
}
147147

148148
if i < self.poly.len() - 1 {
149-
x = poly_mul(&x, &inp[0]);
149+
x = poly_mul_monomial(&x, &inp[0]);
150150
}
151151
}
152152
Ok(())
153153
}
154154

155155
/// Multiply input polynomials using NTT.
156-
fn call_poly_ntt(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
156+
fn eval_poly_ntt(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
157157
let n = self.n;
158158
let inp = &inp[0];
159159

@@ -186,7 +186,7 @@ impl<F: NttFriendlyFieldElement> PolyEval<F> {
186186
impl<F: NttFriendlyFieldElement> Gadget<F> for PolyEval<F> {
187187
fn eval(&mut self, inp: &[F]) -> Result<F, FlpError> {
188188
gadget_call_check(self, inp.len())?;
189-
Ok(poly_eval(&self.poly, inp[0]))
189+
Ok(poly_eval_monomial(&self.poly, inp[0]))
190190
}
191191

192192
fn eval_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
@@ -197,9 +197,9 @@ impl<F: NttFriendlyFieldElement> Gadget<F> for PolyEval<F> {
197197
}
198198

199199
if inp[0].len() >= NTT_THRESHOLD {
200-
self.call_poly_ntt(outp, inp)
200+
self.eval_poly_ntt(outp, inp)
201201
} else {
202-
self.call_poly_direct(outp, inp)
202+
self.eval_poly_direct(outp, inp)
203203
}
204204
}
205205

@@ -322,7 +322,7 @@ where
322322
struct ParallelSumFoldState<F, G> {
323323
/// Inner gadget.
324324
inner: G,
325-
/// Output buffer for `call_poly()`.
325+
/// Output buffer for `eval_poly()`.
326326
partial_output: Vec<F>,
327327
/// Sum accumulator.
328328
partial_sum: Vec<F>,
@@ -405,7 +405,7 @@ where
405405
}
406406
}
407407

408-
/// Check that the input parameters of g.call() are well-formed.
408+
/// Check that the input parameters of g.eval() are well-formed.
409409
fn gadget_call_check<F: NttFriendlyFieldElement, G: Gadget<F>>(
410410
gadget: &G,
411411
in_len: usize,
@@ -425,23 +425,23 @@ fn gadget_call_check<F: NttFriendlyFieldElement, G: Gadget<F>>(
425425
Ok(())
426426
}
427427

428-
/// Check that the input parameters of g.call_poly() are well-formed.
429-
fn gadget_call_poly_check<F: NttFriendlyFieldElement, G: Gadget<F>>(
428+
/// Check that the input parameters of g.eval_poly() are well-formed.
429+
fn gadget_call_poly_check<F: NttFriendlyFieldElement, G: Gadget<F>, P: AsRef<[F]>>(
430430
gadget: &G,
431431
outp: &[F],
432-
inp: &[Vec<F>],
432+
inp: &[P],
433433
) -> Result<(), FlpError> {
434434
gadget_call_check(gadget, inp.len())?;
435435

436436
for i in 1..inp.len() {
437-
if inp[i].len() != inp[0].len() {
437+
if inp[i].as_ref().len() != inp[0].as_ref().len() {
438438
return Err(FlpError::Gadget(
439439
"gadget called on wire polynomials with different lengths".to_string(),
440440
));
441441
}
442442
}
443443

444-
let expected = gadget_poly_len(gadget.degree(), inp[0].len()).next_power_of_two();
444+
let expected = gadget_poly_len(gadget.degree(), inp[0].as_ref().len()).next_power_of_two();
445445
if outp.len() != expected {
446446
return Err(FlpError::Gadget(format!(
447447
"incorrect output length: got {}; want {}",
@@ -550,8 +550,8 @@ mod tests {
550550
}
551551
}
552552

553-
/// Test that calling g.call_poly() and evaluating the output at a given point is equivalent
554-
/// to evaluating each of the inputs at the same point and applying g.call() on the results.
553+
/// Test that calling g.eval_poly() and evaluating the output at a given point is equivalent
554+
/// to evaluating each of the inputs at the same point and applying g.eval() on the results.
555555
fn gadget_test<F: NttFriendlyFieldElement, G: Gadget<F>>(g: &mut G, num_calls: usize) {
556556
let wire_poly_len = (1 + num_calls).next_power_of_two();
557557
let mut prng = Prng::new();
@@ -564,17 +564,17 @@ mod tests {
564564
for out in wire_polys[i].iter_mut().take(wire_poly_len) {
565565
*out = prng.get();
566566
}
567-
inp[i] = poly_eval(&wire_polys[i], r);
567+
inp[i] = poly_eval_monomial(&wire_polys[i], r);
568568
}
569569

570570
g.eval_poly(&mut gadget_poly, &wire_polys).unwrap();
571-
let got = poly_eval(&gadget_poly, r);
571+
let got = poly_eval_monomial(&gadget_poly, r);
572572
let want = g.eval(&inp).unwrap();
573573
assert_eq!(got, want);
574574

575575
// Repeat the call to make sure that the gadget's memory is reset properly between calls.
576576
g.eval_poly(&mut gadget_poly, &wire_polys).unwrap();
577-
let got = poly_eval(&gadget_poly, r);
577+
let got = poly_eval_monomial(&gadget_poly, r);
578578
assert_eq!(got, want);
579579
}
580580
}

0 commit comments

Comments
 (0)