Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/multilinear/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ mod sumcheck;
pub use provers::{
blendy::{BlendyProver, BlendyProverConfig},
space::{SpaceProver, SpaceProverConfig},
time::{TimeProver, TimeProverConfig},
time::{ReduceMode, TimeProver, TimeProverConfig},
};
pub use sumcheck::Sumcheck;
21 changes: 17 additions & 4 deletions src/multilinear/provers/space/config.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use ark_ff::Field;

use crate::{prover::ProverConfig, streams::Stream};
use crate::{
prover::{BatchProverConfig, ProverConfig},
streams::Stream,
};

pub struct SpaceProverConfig<F, S>
where
Expand All @@ -9,7 +12,7 @@ where
{
pub num_variables: usize,
pub claim: F,
pub stream: S,
pub streams: Vec<S>,
}

impl<F, S> SpaceProverConfig<F, S>
Expand All @@ -21,7 +24,7 @@ where
Self {
claim,
num_variables,
stream,
streams: vec![stream],
}
}
}
Expand All @@ -31,7 +34,17 @@ impl<F: Field, S: Stream<F>> ProverConfig<F, S> for SpaceProverConfig<F, S> {
Self {
claim,
num_variables,
stream,
streams: vec![stream],
}
}
}

impl<F: Field, S: Stream<F>> BatchProverConfig<F, S> for SpaceProverConfig<F, S> {
fn default(claim: F, num_variables: usize, streams: Vec<S>) -> Self {
Self {
claim,
num_variables,
streams,
}
}
}
14 changes: 8 additions & 6 deletions src/multilinear/provers/space/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
pub struct SpaceProver<F: Field, S: Stream<F>> {
pub claim: F,
pub current_round: usize,
pub evaluation_stream: S,
pub evaluation_streams: Vec<S>,
pub num_variables: usize,
pub verifier_messages: Vec<F>,
pub verifier_message_hats: Vec<F>,
Expand Down Expand Up @@ -49,11 +49,13 @@ impl<F: Field, S: Stream<F>> SpaceProver<F, S> {
// Check if the bit at the position specified by the bitmask is set
let is_set: bool = (evaluation_index & bitmask) != 0;

// Use match to accumulate the appropriate value based on whether the bit is set or not
let inner_sum = self.evaluation_stream.evaluation(evaluation_index) * lag_poly;
match is_set {
false => sum_0 += inner_sum,
true => sum_1 += inner_sum,
for stream in &self.evaluation_streams {
// Use match to accumulate the appropriate value based on whether the bit is set or not
let inner_sum = stream.evaluation(evaluation_index) * lag_poly;
match is_set {
false => sum_0 += inner_sum,
true => sum_1 += inner_sum,
}
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/multilinear/provers/space/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ impl<F: Field, S: Stream<F>> Prover<F> for SpaceProver<F, S> {
fn new(prover_config: Self::ProverConfig) -> Self {
Self {
claim: prover_config.claim,
evaluation_stream: prover_config.stream,
evaluation_streams: prover_config.streams,
verifier_messages: Vec::<F>::with_capacity(prover_config.num_variables),
verifier_message_hats: Vec::<F>::with_capacity(prover_config.num_variables),
current_round: 0,
Expand Down
28 changes: 23 additions & 5 deletions src/multilinear/provers/time/config.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
use ark_ff::Field;

use crate::{prover::ProverConfig, streams::Stream};
use crate::{
multilinear::provers::time::reductions::ReduceMode,
prover::{BatchProverConfig, ProverConfig},
streams::Stream,
};

pub struct TimeProverConfig<F, S>
where
Expand All @@ -9,19 +13,21 @@ where
{
pub num_variables: usize,
pub claim: F,
pub stream: S,
pub streams: Vec<S>,
pub reduce_mode: ReduceMode,
}

impl<F, S> TimeProverConfig<F, S>
where
F: Field,
S: Stream<F>,
{
pub fn new(claim: F, num_variables: usize, stream: S) -> Self {
pub fn new(claim: F, num_variables: usize, stream: S, reduce_mode: ReduceMode) -> Self {
Self {
claim,
num_variables,
stream,
streams: vec![stream],
reduce_mode,
}
}
}
Expand All @@ -31,7 +37,19 @@ impl<F: Field, S: Stream<F>> ProverConfig<F, S> for TimeProverConfig<F, S> {
Self {
claim,
num_variables,
stream,
streams: vec![stream],
reduce_mode: ReduceMode::Pairwise,
}
}
}

impl<F: Field, S: Stream<F>> BatchProverConfig<F, S> for TimeProverConfig<F, S> {
fn default(claim: F, num_variables: usize, streams: Vec<S>) -> Self {
Self {
claim,
num_variables,
streams,
reduce_mode: ReduceMode::Pairwise,
}
}
}
154 changes: 3 additions & 151 deletions src/multilinear/provers/time/core.rs
Original file line number Diff line number Diff line change
@@ -1,167 +1,19 @@
use crate::multilinear::provers::time::reductions::ReduceMode;
use ark_ff::Field;
use ark_std::vec::Vec;

#[cfg(feature = "parallel")]
use ark_std::cfg_into_iter;
#[cfg(feature = "parallel")]
use rayon::iter::{
IndexedParallelIterator, IntoParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
};

use crate::streams::Stream;

pub struct TimeProver<F: Field, S: Stream<F>> {
pub claim: F,
pub current_round: usize,
pub evaluations: Option<Vec<F>>,
pub evaluation_stream: S, // TODO (z-tech): this can be released after the first call to vsbw_reduce_evaluations
pub evaluation_streams: Vec<S>, // TODO (z-tech): this can be released after the first call to vsbw_reduce_evaluations
pub num_variables: usize,
pub reduce_mode: ReduceMode,
}

impl<F: Field, S: Stream<F>> TimeProver<F, S> {
fn num_free_variables(&self) -> usize {
self.num_variables - self.current_round
}
pub fn vsbw_evaluate(&self) -> (F, F) {
// Calculate the bitmask for the number of free variables
let bitmask: usize = 1 << (self.num_free_variables() - 1);

// Determine the length of evaluations to iterate through
let evaluations_len = match &self.evaluations {
Some(evaluations) => evaluations.len(),
None => 2usize.pow(self.evaluation_stream.num_variables() as u32),
};

#[cfg(feature = "parallel")]
let (sum_0, sum_1) = cfg_into_iter!(0..evaluations_len)
.map(|i| {
// Get the point evaluation
let val = if let Some(evals) = &self.evaluations {
evals[i]
} else {
self.evaluation_stream.evaluation(i)
};

// Route value into the proper bucket
if (i & bitmask) == 0 {
(val, F::zero()) // contributes to sum_0
} else {
(F::zero(), val) // contributes to sum_1
}
})
// Combine partial (sum0, sum1) pairs from each worker/thread.
.reduce(
|| (F::zero(), F::zero()),
|(a0, a1), (b0, b1)| (a0 + b0, a1 + b1),
);

// Initialize accumulators for sum_0 and sum_1
#[cfg(not(feature = "parallel"))]
let mut sum_0 = F::ZERO;
#[cfg(not(feature = "parallel"))]
let mut sum_1 = F::ZERO;
#[cfg(not(feature = "parallel"))]
{
// Iterate through evaluations
for i in 0..evaluations_len {
// Check if the bit at the position specified by the bitmask is set
let is_set: bool = (i & bitmask) != 0;

// Get the point evaluation for the current index
let point_evaluation = match &self.evaluations {
Some(evaluations) => evaluations[i],
None => self.evaluation_stream.evaluation(i),
};

// Accumulate the value based on whether the bit is set or not
match is_set {
false => sum_0 += point_evaluation,
true => sum_1 += point_evaluation,
}
}
}

// Return the accumulated sums
(sum_0, sum_1)
}

pub fn vsbw_reduce_evaluations(&mut self, verifier_message: F, verifier_message_hat: F) {
// Clone or initialize the evaluations vector
#[cfg(feature = "parallel")]
let is_first_go = self.evaluations.is_some();
let mut evaluations = match &self.evaluations {
Some(evaluations) => evaluations.clone(),
None => vec![
F::ZERO;
2usize.pow(self.evaluation_stream.num_variables().try_into().unwrap()) / 2
],
};

// Determine the length of evaluations to iterate through
let evaluations_len = match &self.evaluations {
Some(evaluations) => evaluations.len() / 2,
None => evaluations.len(),
};

// Calculate what bit needs to be set to index the second half of the last round's evaluations
let setbit: usize = 1 << self.num_free_variables();

#[cfg(feature = "parallel")]
{
// We'll write to the first half only.
let dest = &mut evaluations[..evaluations_len];

if is_first_go {
// Read from the old immutable source (borrow, no extra clone).
let src = self.evaluations.as_ref().unwrap();
dest.par_iter_mut()
.enumerate()
.for_each(|(i0, slot): (usize, &mut F)| {
let i1 = i0 | setbit;
let v0 = src[i0];
let v1 = src[i1];
*slot = v0 * verifier_message_hat + v1 * verifier_message;
});
} else {
// Stream-only: compute both endpoints from the stream.
let stream = &self.evaluation_stream;
dest.par_iter_mut()
.enumerate()
.for_each(|(i0, slot): (usize, &mut F)| {
let i1 = i0 | setbit;
let v0 = stream.evaluation(i0);
let v1 = stream.evaluation(i1);
*slot = v0 * verifier_message_hat + v1 * verifier_message;
});
}
}

// Iterate through pairs of evaluations
#[cfg(not(feature = "parallel"))]
for i0 in 0..evaluations_len {
let i1 = i0 | setbit;

// Get point evaluations for indices i0 and i1
let point_evaluation_i0 = match &self.evaluations {
None => self.evaluation_stream.evaluation(i0),
Some(evaluations) => evaluations[i0],
};
let point_evaluation_i1 = match &self.evaluations {
None => self.evaluation_stream.evaluation(i1),
Some(evaluations) => evaluations[i1],
};

// Update the i0-th evaluation based on the reduction operation
evaluations[i0] =
point_evaluation_i0 * verifier_message_hat + point_evaluation_i1 * verifier_message;
}

// Truncate the evaluations vector to the correct length
evaluations.truncate(evaluations_len);

// Update the internal state with the new evaluations vector
self.evaluations = Some(evaluations.clone());
}
pub fn total_rounds(&self) -> usize {
self.num_variables
}
Expand Down
2 changes: 2 additions & 0 deletions src/multilinear/provers/time/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
mod config;
mod core;
mod prover;
mod reductions;

pub use config::TimeProverConfig;
pub use core::TimeProver;
pub use reductions::ReduceMode;
Loading
Loading