Skip to content

Commit f7cb817

Browse files
authored
Align trait Gadget interface with spec (#1398)
The `Gadget` interface of [7.3.2][1] uses the verb `eval` instead of `call`. Using that term in the implementation makes it easier to follow. [1]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-18#section-7.3.2
1 parent 4cb3884 commit f7cb817

File tree

4 files changed

+40
-40
lines changed

4 files changed

+40
-40
lines changed

src/flp.rs

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ pub trait Flp: Sized + Eq + Clone + Debug {
318318
let gadget_poly_len = gadget_poly_len(gadget.degree(), m);
319319
let start = proof_len + gadget.arity();
320320
let end = start + gadget_poly_len.next_power_of_two();
321-
gadget.call_poly(&mut proof[start..end], &f)?;
321+
gadget.eval_poly(&mut proof[start..end], &f)?;
322322
proof_len += gadget.arity() + gadget_poly_len;
323323
}
324324

@@ -503,7 +503,7 @@ pub trait Flp: Sized + Eq + Clone + Debug {
503503
for gadget in gadgets.iter_mut() {
504504
let next_len = 1 + gadget.arity();
505505

506-
let e = gadget.call(&verifier[verifier_len..verifier_len + next_len - 1])?;
506+
let e = gadget.eval(&verifier[verifier_len..verifier_len + next_len - 1])?;
507507
if e != verifier[verifier_len + next_len - 1] {
508508
return Ok(false);
509509
}
@@ -608,10 +608,10 @@ where
608608
/// A gadget, a non-affine arithmetic circuit that is called when evaluating a validity circuit.
609609
pub trait Gadget<F: NttFriendlyFieldElement>: Debug {
610610
/// Evaluates the gadget on input `inp` and returns the output.
611-
fn call(&mut self, inp: &[F]) -> Result<F, FlpError>;
611+
fn eval(&mut self, inp: &[F]) -> Result<F, FlpError>;
612612

613613
/// Evaluate the gadget on input of a sequence of polynomials. The output is written to `outp`.
614-
fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError>;
614+
fn eval_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError>;
615615

616616
/// Returns the arity of the gadget. This is the length of `inp` passed to `call` or
617617
/// `call_poly`.
@@ -661,16 +661,16 @@ impl<F: NttFriendlyFieldElement> ProveShimGadget<F> {
661661
}
662662

663663
impl<F: NttFriendlyFieldElement> Gadget<F> for ProveShimGadget<F> {
664-
fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
664+
fn eval(&mut self, inp: &[F]) -> Result<F, FlpError> {
665665
for (wire_poly_vals, inp_val) in self.f_vals[..inp.len()].iter_mut().zip(inp.iter()) {
666666
wire_poly_vals[self.ct] = *inp_val;
667667
}
668668
self.ct += 1;
669-
self.inner.call(inp)
669+
self.inner.eval(inp)
670670
}
671671

672-
fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
673-
self.inner.call_poly(outp, inp)
672+
fn eval_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
673+
self.inner.eval_poly(outp, inp)
674674
}
675675

676676
fn arity(&self) -> usize {
@@ -751,7 +751,7 @@ impl<F: NttFriendlyFieldElement> QueryShimGadget<F> {
751751
}
752752

753753
impl<F: NttFriendlyFieldElement> Gadget<F> for QueryShimGadget<F> {
754-
fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
754+
fn eval(&mut self, inp: &[F]) -> Result<F, FlpError> {
755755
for (wire_poly_vals, inp_val) in self.f_vals[..inp.len()].iter_mut().zip(inp.iter()) {
756756
wire_poly_vals[self.ct] = *inp_val;
757757
}
@@ -760,7 +760,7 @@ impl<F: NttFriendlyFieldElement> Gadget<F> for QueryShimGadget<F> {
760760
Ok(outp)
761761
}
762762

763-
fn call_poly(&mut self, _outp: &mut [F], _inp: &[Vec<F>]) -> Result<(), FlpError> {
763+
fn eval_poly(&mut self, _outp: &mut [F], _inp: &[Vec<F>]) -> Result<(), FlpError> {
764764
panic!("no-op");
765765
}
766766

@@ -1078,13 +1078,13 @@ mod tests {
10781078

10791079
// Check that `data[0]^3 == data[1]`.
10801080
let mut inp = [input[0], input[0]];
1081-
inp[0] = g[0].call(&inp)?;
1082-
inp[0] = g[0].call(&inp)?;
1081+
inp[0] = g[0].eval(&inp)?;
1082+
inp[0] = g[0].eval(&inp)?;
10831083
let x3_diff = inp[0] - input[1];
10841084
res += r * x3_diff;
10851085

10861086
// Check that `data[0]` is in the correct range.
1087-
let x_checked = g[1].call(&[input[0]])?;
1087+
let x_checked = g[1].eval(&[input[0]])?;
10881088
res += (r * r) * x_checked;
10891089

10901090
Ok(vec![res])
@@ -1218,10 +1218,10 @@ mod tests {
12181218
// use of multiple gadgets, each of which is called an arbitrary number of times.
12191219
let mut res = F::zero();
12201220
for _ in 0..self.num_gadget_calls[0] {
1221-
res += g[0].call(&[input[0]])?;
1221+
res += g[0].eval(&[input[0]])?;
12221222
}
12231223
for _ in 0..self.num_gadget_calls[1] {
1224-
res += g[1].call(&[input[0]])?;
1224+
res += g[1].eval(&[input[0]])?;
12251225
}
12261226
Ok(vec![res])
12271227
}

src/flp/gadgets.rs

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,12 @@ impl<F: NttFriendlyFieldElement> Mul<F> {
7575
}
7676

7777
impl<F: NttFriendlyFieldElement> Gadget<F> for Mul<F> {
78-
fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
78+
fn eval(&mut self, inp: &[F]) -> Result<F, FlpError> {
7979
gadget_call_check(self, inp.len())?;
8080
Ok(inp[0] * inp[1])
8181
}
8282

83-
fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
83+
fn eval_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
8484
gadget_call_poly_check(self, outp, inp)?;
8585
if inp[0].len() >= NTT_THRESHOLD {
8686
self.call_poly_ntt(outp, inp)
@@ -184,12 +184,12 @@ impl<F: NttFriendlyFieldElement> PolyEval<F> {
184184
}
185185

186186
impl<F: NttFriendlyFieldElement> Gadget<F> for PolyEval<F> {
187-
fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
187+
fn eval(&mut self, inp: &[F]) -> Result<F, FlpError> {
188188
gadget_call_check(self, inp.len())?;
189189
Ok(poly_eval(&self.poly, inp[0]))
190190
}
191191

192-
fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
192+
fn eval_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
193193
gadget_call_poly_check(self, outp, inp)?;
194194

195195
for item in outp.iter_mut() {
@@ -249,16 +249,16 @@ impl<F: NttFriendlyFieldElement, G: 'static + Gadget<F>> ParallelSumGadget<F, G>
249249
}
250250

251251
impl<F: NttFriendlyFieldElement, G: 'static + Gadget<F>> Gadget<F> for ParallelSum<F, G> {
252-
fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
252+
fn eval(&mut self, inp: &[F]) -> Result<F, FlpError> {
253253
gadget_call_check(self, inp.len())?;
254254
let mut outp = F::zero();
255255
for chunk in inp.chunks(self.inner.arity()) {
256-
outp += self.inner.call(chunk)?;
256+
outp += self.inner.eval(chunk)?;
257257
}
258258
Ok(outp)
259259
}
260260

261-
fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
261+
fn eval_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
262262
gadget_call_poly_check(self, outp, inp)?;
263263

264264
for x in outp.iter_mut() {
@@ -268,7 +268,7 @@ impl<F: NttFriendlyFieldElement, G: 'static + Gadget<F>> Gadget<F> for ParallelS
268268
let mut partial_outp = vec![F::zero(); outp.len()];
269269

270270
for chunk in inp.chunks(self.inner.arity()) {
271-
self.inner.call_poly(&mut partial_outp, chunk)?;
271+
self.inner.eval_poly(&mut partial_outp, chunk)?;
272272
for i in 0..outp.len() {
273273
outp[i] += partial_outp[i]
274274
}
@@ -349,11 +349,11 @@ where
349349
F: NttFriendlyFieldElement + Sync + Send,
350350
G: 'static + Gadget<F> + Clone + Sync + Send,
351351
{
352-
fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
353-
self.serial_sum.call(inp)
352+
fn eval(&mut self, inp: &[F]) -> Result<F, FlpError> {
353+
self.serial_sum.eval(inp)
354354
}
355355

356-
fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
356+
fn eval_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
357357
gadget_call_poly_check(self, outp, inp)?;
358358

359359
// Create a copy of the inner gadget and two working buffers on each thread. Evaluate the
@@ -369,7 +369,7 @@ where
369369
|mut state, chunk| {
370370
state
371371
.inner
372-
.call_poly(&mut state.partial_output, chunk)
372+
.eval_poly(&mut state.partial_output, chunk)
373373
.unwrap();
374374
for (sum_elem, output_elem) in state
375375
.partial_sum
@@ -524,8 +524,8 @@ mod tests {
524524

525525
// Test that both gadgets evaluate to the same value when run on scalar inputs.
526526
let inp = TestField::random_vector(arity);
527-
let result = g.call(&inp).unwrap();
528-
let result_serial = g_serial.call(&inp).unwrap();
527+
let result = g.eval(&inp).unwrap();
528+
let result_serial = g_serial.eval(&inp).unwrap();
529529
assert_eq!(result, result_serial);
530530

531531
// Test that both gadgets evaluate to the same value when run on polynomial inputs.
@@ -542,9 +542,9 @@ mod tests {
542542
.take(arity)
543543
.collect();
544544

545-
g.call_poly(&mut poly_outp, &poly_inp).unwrap();
545+
g.eval_poly(&mut poly_outp, &poly_inp).unwrap();
546546
g_serial
547-
.call_poly(&mut poly_outp_serial, &poly_inp)
547+
.eval_poly(&mut poly_outp_serial, &poly_inp)
548548
.unwrap();
549549
assert_eq!(poly_outp, poly_outp_serial);
550550
}
@@ -567,13 +567,13 @@ mod tests {
567567
inp[i] = poly_eval(&wire_polys[i], r);
568568
}
569569

570-
g.call_poly(&mut gadget_poly, &wire_polys).unwrap();
570+
g.eval_poly(&mut gadget_poly, &wire_polys).unwrap();
571571
let got = poly_eval(&gadget_poly, r);
572-
let want = g.call(&inp).unwrap();
572+
let want = g.eval(&inp).unwrap();
573573
assert_eq!(got, want);
574574

575575
// Repeat the call to make sure that the gadget's memory is reset properly between calls.
576-
g.call_poly(&mut gadget_poly, &wire_polys).unwrap();
576+
g.eval_poly(&mut gadget_poly, &wire_polys).unwrap();
577577
let got = poly_eval(&gadget_poly, r);
578578
assert_eq!(got, want);
579579
}

src/flp/types.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ impl<F: NttFriendlyFieldElement> Flp for Count<F> {
6868
_num_shares: usize,
6969
) -> Result<Vec<F>, FlpError> {
7070
self.valid_call_check(input, joint_rand)?;
71-
let out = g[0].call(&[input[0], input[0]])? - input[0];
71+
let out = g[0].eval(&[input[0], input[0]])? - input[0];
7272
Ok(vec![out])
7373
}
7474

@@ -210,7 +210,7 @@ impl<F: NttFriendlyFieldElement> Flp for Sum<F> {
210210
let gadget = &mut g[0];
211211
let mut output = vec![F::zero(); input.len()];
212212
for (bit, output_elem) in input.iter().zip(output[..input.len()].iter_mut()) {
213-
*output_elem = gadget.call(slice::from_ref(bit))?;
213+
*output_elem = gadget.eval(slice::from_ref(bit))?;
214214
}
215215

216216
Ok(output)
@@ -1071,7 +1071,7 @@ pub(crate) fn parallel_sum_range_checks<F: NttFriendlyFieldElement>(
10711071
// accessed again before returning.
10721072
}
10731073

1074-
output += gadget.call(&padded_chunk)?;
1074+
output += gadget.eval(&padded_chunk)?;
10751075
}
10761076

10771077
Ok(output)
@@ -1626,7 +1626,7 @@ mod tests {
16261626
_num_shares: usize,
16271627
) -> Result<Vec<Self::Field>, FlpError> {
16281628
self.valid_call_check(input, joint_rand)?;
1629-
let check = gadgets[0].call(input)?;
1629+
let check = gadgets[0].eval(input)?;
16301630
Ok(vec![check])
16311631
}
16321632

src/flp/types/fixedpoint_l2.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -477,13 +477,13 @@ where
477477
for chunk in decoded_entries?.chunks(self.gadget1_chunk_length) {
478478
let d = chunk.len();
479479
if d == self.gadget1_chunk_length {
480-
outp += g[1].call(chunk)?;
480+
outp += g[1].eval(chunk)?;
481481
} else {
482482
// If the chunk is smaller than the chunk length, extend
483483
// chunk with zeros.
484484
let mut padded_chunk: Vec<_> = chunk.to_owned();
485485
padded_chunk.resize(self.gadget1_chunk_length, zero_enc_share);
486-
outp += g[1].call(&padded_chunk)?;
486+
outp += g[1].eval(&padded_chunk)?;
487487
}
488488
}
489489

0 commit comments

Comments
 (0)