Skip to content

Commit 53a5a65

Browse files
authored
Pairwise compression for TimeProver (#83)
* fix typo, chng comment * chkpt * clippy * support n streams multilinear TimeProver * refactor vsbw * chkpt * chkpt * fix tests * remove extension benches * cleanup no default features * lint
1 parent 1c59c61 commit 53a5a65

File tree

17 files changed

+393
-214
lines changed

17 files changed

+393
-214
lines changed

src/multilinear/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@ mod sumcheck;
44
pub use provers::{
55
blendy::{BlendyProver, BlendyProverConfig},
66
space::{SpaceProver, SpaceProverConfig},
7-
time::{TimeProver, TimeProverConfig},
7+
time::{ReduceMode, TimeProver, TimeProverConfig},
88
};
99
pub use sumcheck::Sumcheck;

src/multilinear/provers/space/config.rs

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
use ark_ff::Field;
22

3-
use crate::{prover::ProverConfig, streams::Stream};
3+
use crate::{
4+
prover::{BatchProverConfig, ProverConfig},
5+
streams::Stream,
6+
};
47

58
pub struct SpaceProverConfig<F, S>
69
where
@@ -9,7 +12,7 @@ where
912
{
1013
pub num_variables: usize,
1114
pub claim: F,
12-
pub stream: S,
15+
pub streams: Vec<S>,
1316
}
1417

1518
impl<F, S> SpaceProverConfig<F, S>
@@ -21,7 +24,7 @@ where
2124
Self {
2225
claim,
2326
num_variables,
24-
stream,
27+
streams: vec![stream],
2528
}
2629
}
2730
}
@@ -31,7 +34,17 @@ impl<F: Field, S: Stream<F>> ProverConfig<F, S> for SpaceProverConfig<F, S> {
3134
Self {
3235
claim,
3336
num_variables,
34-
stream,
37+
streams: vec![stream],
38+
}
39+
}
40+
}
41+
42+
impl<F: Field, S: Stream<F>> BatchProverConfig<F, S> for SpaceProverConfig<F, S> {
43+
fn default(claim: F, num_variables: usize, streams: Vec<S>) -> Self {
44+
Self {
45+
claim,
46+
num_variables,
47+
streams,
3548
}
3649
}
3750
}

src/multilinear/provers/space/core.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use crate::{
88
pub struct SpaceProver<F: Field, S: Stream<F>> {
99
pub claim: F,
1010
pub current_round: usize,
11-
pub evaluation_stream: S,
11+
pub evaluation_streams: Vec<S>,
1212
pub num_variables: usize,
1313
pub verifier_messages: Vec<F>,
1414
pub verifier_message_hats: Vec<F>,
@@ -49,11 +49,13 @@ impl<F: Field, S: Stream<F>> SpaceProver<F, S> {
4949
// Check if the bit at the position specified by the bitmask is set
5050
let is_set: bool = (evaluation_index & bitmask) != 0;
5151

52-
// Use match to accumulate the appropriate value based on whether the bit is set or not
53-
let inner_sum = self.evaluation_stream.evaluation(evaluation_index) * lag_poly;
54-
match is_set {
55-
false => sum_0 += inner_sum,
56-
true => sum_1 += inner_sum,
52+
for stream in &self.evaluation_streams {
53+
// Use match to accumulate the appropriate value based on whether the bit is set or not
54+
let inner_sum = stream.evaluation(evaluation_index) * lag_poly;
55+
match is_set {
56+
false => sum_0 += inner_sum,
57+
true => sum_1 += inner_sum,
58+
}
5759
}
5860
}
5961
}

src/multilinear/provers/space/prover.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ impl<F: Field, S: Stream<F>> Prover<F> for SpaceProver<F, S> {
1818
fn new(prover_config: Self::ProverConfig) -> Self {
1919
Self {
2020
claim: prover_config.claim,
21-
evaluation_stream: prover_config.stream,
21+
evaluation_streams: prover_config.streams,
2222
verifier_messages: Vec::<F>::with_capacity(prover_config.num_variables),
2323
verifier_message_hats: Vec::<F>::with_capacity(prover_config.num_variables),
2424
current_round: 0,
Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
use ark_ff::Field;
22

3-
use crate::{prover::ProverConfig, streams::Stream};
3+
use crate::{
4+
multilinear::provers::time::reductions::ReduceMode,
5+
prover::{BatchProverConfig, ProverConfig},
6+
streams::Stream,
7+
};
48

59
pub struct TimeProverConfig<F, S>
610
where
@@ -9,19 +13,21 @@ where
913
{
1014
pub num_variables: usize,
1115
pub claim: F,
12-
pub stream: S,
16+
pub streams: Vec<S>,
17+
pub reduce_mode: ReduceMode,
1318
}
1419

1520
impl<F, S> TimeProverConfig<F, S>
1621
where
1722
F: Field,
1823
S: Stream<F>,
1924
{
20-
pub fn new(claim: F, num_variables: usize, stream: S) -> Self {
25+
pub fn new(claim: F, num_variables: usize, stream: S, reduce_mode: ReduceMode) -> Self {
2126
Self {
2227
claim,
2328
num_variables,
24-
stream,
29+
streams: vec![stream],
30+
reduce_mode,
2531
}
2632
}
2733
}
@@ -31,7 +37,19 @@ impl<F: Field, S: Stream<F>> ProverConfig<F, S> for TimeProverConfig<F, S> {
3137
Self {
3238
claim,
3339
num_variables,
34-
stream,
40+
streams: vec![stream],
41+
reduce_mode: ReduceMode::Pairwise,
42+
}
43+
}
44+
}
45+
46+
impl<F: Field, S: Stream<F>> BatchProverConfig<F, S> for TimeProverConfig<F, S> {
47+
fn default(claim: F, num_variables: usize, streams: Vec<S>) -> Self {
48+
Self {
49+
claim,
50+
num_variables,
51+
streams,
52+
reduce_mode: ReduceMode::Pairwise,
3553
}
3654
}
3755
}

src/multilinear/provers/time/core.rs

Lines changed: 3 additions & 151 deletions
Original file line numberDiff line numberDiff line change
@@ -1,167 +1,19 @@
1+
use crate::multilinear::provers::time::reductions::ReduceMode;
12
use ark_ff::Field;
23
use ark_std::vec::Vec;
34

4-
#[cfg(feature = "parallel")]
5-
use ark_std::cfg_into_iter;
6-
#[cfg(feature = "parallel")]
7-
use rayon::iter::{
8-
IndexedParallelIterator, IntoParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
9-
};
10-
115
use crate::streams::Stream;
126

137
pub struct TimeProver<F: Field, S: Stream<F>> {
148
pub claim: F,
159
pub current_round: usize,
1610
pub evaluations: Option<Vec<F>>,
17-
pub evaluation_stream: S, // TODO (z-tech): this can be released after the first call to vsbw_reduce_evaluations
11+
pub evaluation_streams: Vec<S>, // TODO (z-tech): this can be released after the first call to vsbw_reduce_evaluations
1812
pub num_variables: usize,
13+
pub reduce_mode: ReduceMode,
1914
}
2015

2116
impl<F: Field, S: Stream<F>> TimeProver<F, S> {
22-
fn num_free_variables(&self) -> usize {
23-
self.num_variables - self.current_round
24-
}
25-
pub fn vsbw_evaluate(&self) -> (F, F) {
26-
// Calculate the bitmask for the number of free variables
27-
let bitmask: usize = 1 << (self.num_free_variables() - 1);
28-
29-
// Determine the length of evaluations to iterate through
30-
let evaluations_len = match &self.evaluations {
31-
Some(evaluations) => evaluations.len(),
32-
None => 2usize.pow(self.evaluation_stream.num_variables() as u32),
33-
};
34-
35-
#[cfg(feature = "parallel")]
36-
let (sum_0, sum_1) = cfg_into_iter!(0..evaluations_len)
37-
.map(|i| {
38-
// Get the point evaluation
39-
let val = if let Some(evals) = &self.evaluations {
40-
evals[i]
41-
} else {
42-
self.evaluation_stream.evaluation(i)
43-
};
44-
45-
// Route value into the proper bucket
46-
if (i & bitmask) == 0 {
47-
(val, F::zero()) // contributes to sum_0
48-
} else {
49-
(F::zero(), val) // contributes to sum_1
50-
}
51-
})
52-
// Combine partial (sum0, sum1) pairs from each worker/thread.
53-
.reduce(
54-
|| (F::zero(), F::zero()),
55-
|(a0, a1), (b0, b1)| (a0 + b0, a1 + b1),
56-
);
57-
58-
// Initialize accumulators for sum_0 and sum_1
59-
#[cfg(not(feature = "parallel"))]
60-
let mut sum_0 = F::ZERO;
61-
#[cfg(not(feature = "parallel"))]
62-
let mut sum_1 = F::ZERO;
63-
#[cfg(not(feature = "parallel"))]
64-
{
65-
// Iterate through evaluations
66-
for i in 0..evaluations_len {
67-
// Check if the bit at the position specified by the bitmask is set
68-
let is_set: bool = (i & bitmask) != 0;
69-
70-
// Get the point evaluation for the current index
71-
let point_evaluation = match &self.evaluations {
72-
Some(evaluations) => evaluations[i],
73-
None => self.evaluation_stream.evaluation(i),
74-
};
75-
76-
// Accumulate the value based on whether the bit is set or not
77-
match is_set {
78-
false => sum_0 += point_evaluation,
79-
true => sum_1 += point_evaluation,
80-
}
81-
}
82-
}
83-
84-
// Return the accumulated sums
85-
(sum_0, sum_1)
86-
}
87-
88-
pub fn vsbw_reduce_evaluations(&mut self, verifier_message: F, verifier_message_hat: F) {
89-
// Clone or initialize the evaluations vector
90-
#[cfg(feature = "parallel")]
91-
let is_first_go = self.evaluations.is_some();
92-
let mut evaluations = match &self.evaluations {
93-
Some(evaluations) => evaluations.clone(),
94-
None => vec![
95-
F::ZERO;
96-
2usize.pow(self.evaluation_stream.num_variables().try_into().unwrap()) / 2
97-
],
98-
};
99-
100-
// Determine the length of evaluations to iterate through
101-
let evaluations_len = match &self.evaluations {
102-
Some(evaluations) => evaluations.len() / 2,
103-
None => evaluations.len(),
104-
};
105-
106-
// Calculate what bit needs to be set to index the second half of the last round's evaluations
107-
let setbit: usize = 1 << self.num_free_variables();
108-
109-
#[cfg(feature = "parallel")]
110-
{
111-
// We'll write to the first half only.
112-
let dest = &mut evaluations[..evaluations_len];
113-
114-
if is_first_go {
115-
// Read from the old immutable source (borrow, no extra clone).
116-
let src = self.evaluations.as_ref().unwrap();
117-
dest.par_iter_mut()
118-
.enumerate()
119-
.for_each(|(i0, slot): (usize, &mut F)| {
120-
let i1 = i0 | setbit;
121-
let v0 = src[i0];
122-
let v1 = src[i1];
123-
*slot = v0 * verifier_message_hat + v1 * verifier_message;
124-
});
125-
} else {
126-
// Stream-only: compute both endpoints from the stream.
127-
let stream = &self.evaluation_stream;
128-
dest.par_iter_mut()
129-
.enumerate()
130-
.for_each(|(i0, slot): (usize, &mut F)| {
131-
let i1 = i0 | setbit;
132-
let v0 = stream.evaluation(i0);
133-
let v1 = stream.evaluation(i1);
134-
*slot = v0 * verifier_message_hat + v1 * verifier_message;
135-
});
136-
}
137-
}
138-
139-
// Iterate through pairs of evaluations
140-
#[cfg(not(feature = "parallel"))]
141-
for i0 in 0..evaluations_len {
142-
let i1 = i0 | setbit;
143-
144-
// Get point evaluations for indices i0 and i1
145-
let point_evaluation_i0 = match &self.evaluations {
146-
None => self.evaluation_stream.evaluation(i0),
147-
Some(evaluations) => evaluations[i0],
148-
};
149-
let point_evaluation_i1 = match &self.evaluations {
150-
None => self.evaluation_stream.evaluation(i1),
151-
Some(evaluations) => evaluations[i1],
152-
};
153-
154-
// Update the i0-th evaluation based on the reduction operation
155-
evaluations[i0] =
156-
point_evaluation_i0 * verifier_message_hat + point_evaluation_i1 * verifier_message;
157-
}
158-
159-
// Truncate the evaluations vector to the correct length
160-
evaluations.truncate(evaluations_len);
161-
162-
// Update the internal state with the new evaluations vector
163-
self.evaluations = Some(evaluations.clone());
164-
}
16517
pub fn total_rounds(&self) -> usize {
16618
self.num_variables
16719
}
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
mod config;
22
mod core;
33
mod prover;
4+
mod reductions;
45

56
pub use config::TimeProverConfig;
67
pub use core::TimeProver;
8+
pub use reductions::ReduceMode;

0 commit comments

Comments
 (0)