|
| 1 | +use crate::multilinear::provers::time::reductions::ReduceMode; |
1 | 2 | use ark_ff::Field; |
2 | 3 | use ark_std::vec::Vec; |
3 | 4 |
|
4 | | -#[cfg(feature = "parallel")] |
5 | | -use ark_std::cfg_into_iter; |
6 | | -#[cfg(feature = "parallel")] |
7 | | -use rayon::iter::{ |
8 | | - IndexedParallelIterator, IntoParallelIterator, IntoParallelRefMutIterator, ParallelIterator, |
9 | | -}; |
10 | | - |
11 | 5 | use crate::streams::Stream; |
12 | 6 |
|
13 | 7 | pub struct TimeProver<F: Field, S: Stream<F>> { |
14 | 8 | pub claim: F, |
15 | 9 | pub current_round: usize, |
16 | 10 | pub evaluations: Option<Vec<F>>, |
17 | | - pub evaluation_stream: S, // TODO (z-tech): this can be released after the first call to vsbw_reduce_evaluations |
| 11 | + pub evaluation_streams: Vec<S>, // TODO (z-tech): this can be released after the first call to vsbw_reduce_evaluations |
18 | 12 | pub num_variables: usize, |
| 13 | + pub reduce_mode: ReduceMode, |
19 | 14 | } |
20 | 15 |
|
21 | 16 | impl<F: Field, S: Stream<F>> TimeProver<F, S> { |
22 | | - fn num_free_variables(&self) -> usize { |
23 | | - self.num_variables - self.current_round |
24 | | - } |
25 | | - pub fn vsbw_evaluate(&self) -> (F, F) { |
26 | | - // Calculate the bitmask for the number of free variables |
27 | | - let bitmask: usize = 1 << (self.num_free_variables() - 1); |
28 | | - |
29 | | - // Determine the length of evaluations to iterate through |
30 | | - let evaluations_len = match &self.evaluations { |
31 | | - Some(evaluations) => evaluations.len(), |
32 | | - None => 2usize.pow(self.evaluation_stream.num_variables() as u32), |
33 | | - }; |
34 | | - |
35 | | - #[cfg(feature = "parallel")] |
36 | | - let (sum_0, sum_1) = cfg_into_iter!(0..evaluations_len) |
37 | | - .map(|i| { |
38 | | - // Get the point evaluation |
39 | | - let val = if let Some(evals) = &self.evaluations { |
40 | | - evals[i] |
41 | | - } else { |
42 | | - self.evaluation_stream.evaluation(i) |
43 | | - }; |
44 | | - |
45 | | - // Route value into the proper bucket |
46 | | - if (i & bitmask) == 0 { |
47 | | - (val, F::zero()) // contributes to sum_0 |
48 | | - } else { |
49 | | - (F::zero(), val) // contributes to sum_1 |
50 | | - } |
51 | | - }) |
52 | | - // Combine partial (sum0, sum1) pairs from each worker/thread. |
53 | | - .reduce( |
54 | | - || (F::zero(), F::zero()), |
55 | | - |(a0, a1), (b0, b1)| (a0 + b0, a1 + b1), |
56 | | - ); |
57 | | - |
58 | | - // Initialize accumulators for sum_0 and sum_1 |
59 | | - #[cfg(not(feature = "parallel"))] |
60 | | - let mut sum_0 = F::ZERO; |
61 | | - #[cfg(not(feature = "parallel"))] |
62 | | - let mut sum_1 = F::ZERO; |
63 | | - #[cfg(not(feature = "parallel"))] |
64 | | - { |
65 | | - // Iterate through evaluations |
66 | | - for i in 0..evaluations_len { |
67 | | - // Check if the bit at the position specified by the bitmask is set |
68 | | - let is_set: bool = (i & bitmask) != 0; |
69 | | - |
70 | | - // Get the point evaluation for the current index |
71 | | - let point_evaluation = match &self.evaluations { |
72 | | - Some(evaluations) => evaluations[i], |
73 | | - None => self.evaluation_stream.evaluation(i), |
74 | | - }; |
75 | | - |
76 | | - // Accumulate the value based on whether the bit is set or not |
77 | | - match is_set { |
78 | | - false => sum_0 += point_evaluation, |
79 | | - true => sum_1 += point_evaluation, |
80 | | - } |
81 | | - } |
82 | | - } |
83 | | - |
84 | | - // Return the accumulated sums |
85 | | - (sum_0, sum_1) |
86 | | - } |
87 | | - |
88 | | - pub fn vsbw_reduce_evaluations(&mut self, verifier_message: F, verifier_message_hat: F) { |
89 | | - // Clone or initialize the evaluations vector |
90 | | - #[cfg(feature = "parallel")] |
91 | | - let is_first_go = self.evaluations.is_some(); |
92 | | - let mut evaluations = match &self.evaluations { |
93 | | - Some(evaluations) => evaluations.clone(), |
94 | | - None => vec![ |
95 | | - F::ZERO; |
96 | | - 2usize.pow(self.evaluation_stream.num_variables().try_into().unwrap()) / 2 |
97 | | - ], |
98 | | - }; |
99 | | - |
100 | | - // Determine the length of evaluations to iterate through |
101 | | - let evaluations_len = match &self.evaluations { |
102 | | - Some(evaluations) => evaluations.len() / 2, |
103 | | - None => evaluations.len(), |
104 | | - }; |
105 | | - |
106 | | - // Calculate what bit needs to be set to index the second half of the last round's evaluations |
107 | | - let setbit: usize = 1 << self.num_free_variables(); |
108 | | - |
109 | | - #[cfg(feature = "parallel")] |
110 | | - { |
111 | | - // We'll write to the first half only. |
112 | | - let dest = &mut evaluations[..evaluations_len]; |
113 | | - |
114 | | - if is_first_go { |
115 | | - // Read from the old immutable source (borrow, no extra clone). |
116 | | - let src = self.evaluations.as_ref().unwrap(); |
117 | | - dest.par_iter_mut() |
118 | | - .enumerate() |
119 | | - .for_each(|(i0, slot): (usize, &mut F)| { |
120 | | - let i1 = i0 | setbit; |
121 | | - let v0 = src[i0]; |
122 | | - let v1 = src[i1]; |
123 | | - *slot = v0 * verifier_message_hat + v1 * verifier_message; |
124 | | - }); |
125 | | - } else { |
126 | | - // Stream-only: compute both endpoints from the stream. |
127 | | - let stream = &self.evaluation_stream; |
128 | | - dest.par_iter_mut() |
129 | | - .enumerate() |
130 | | - .for_each(|(i0, slot): (usize, &mut F)| { |
131 | | - let i1 = i0 | setbit; |
132 | | - let v0 = stream.evaluation(i0); |
133 | | - let v1 = stream.evaluation(i1); |
134 | | - *slot = v0 * verifier_message_hat + v1 * verifier_message; |
135 | | - }); |
136 | | - } |
137 | | - } |
138 | | - |
139 | | - // Iterate through pairs of evaluations |
140 | | - #[cfg(not(feature = "parallel"))] |
141 | | - for i0 in 0..evaluations_len { |
142 | | - let i1 = i0 | setbit; |
143 | | - |
144 | | - // Get point evaluations for indices i0 and i1 |
145 | | - let point_evaluation_i0 = match &self.evaluations { |
146 | | - None => self.evaluation_stream.evaluation(i0), |
147 | | - Some(evaluations) => evaluations[i0], |
148 | | - }; |
149 | | - let point_evaluation_i1 = match &self.evaluations { |
150 | | - None => self.evaluation_stream.evaluation(i1), |
151 | | - Some(evaluations) => evaluations[i1], |
152 | | - }; |
153 | | - |
154 | | - // Update the i0-th evaluation based on the reduction operation |
155 | | - evaluations[i0] = |
156 | | - point_evaluation_i0 * verifier_message_hat + point_evaluation_i1 * verifier_message; |
157 | | - } |
158 | | - |
159 | | - // Truncate the evaluations vector to the correct length |
160 | | - evaluations.truncate(evaluations_len); |
161 | | - |
162 | | - // Update the internal state with the new evaluations vector |
163 | | - self.evaluations = Some(evaluations.clone()); |
164 | | - } |
165 | 17 | pub fn total_rounds(&self) -> usize { |
166 | 18 | self.num_variables |
167 | 19 | } |
|
0 commit comments