Skip to content

Commit 4fddb21

Browse files
committed
rebased and changed according to comments
1 parent 5ff8766 commit 4fddb21

File tree

2 files changed

+94
-58
lines changed

2 files changed

+94
-58
lines changed

.DS_Store

10 KB
Binary file not shown.

kzg/src/msm/strauss.rs

Lines changed: 94 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,15 @@ where
2727
TG1Affine: G1Affine<TG1, TG1Fp>,
2828
TG1ProjAddAffine: G1ProjAddAffine<TG1, TG1Fp, TG1Affine>,
2929
{
30-
/// precomputed per-chunk tables (each chunk table is a Vec of projective points
31-
/// holding sums for each mask in 0..(1<<chunk_len)). Storing projective
32-
/// points avoids repeated affine->projective conversions during the inner loop.
33-
chunk_tables: Vec<Vec<TG1>>,
34-
30+
/// Precomputed per-chunk tables stored as AFFINE points.
31+
/// Saves 33% memory (2 field elements vs 3 for projective).
32+
/// Projective + Affine addition is also 1 multiplication faster
33+
chunk_tables: Vec<Vec<TG1Affine>>, // Projective -> Affine change
3534
numpoints: usize,
3635

3736
batch_numpoints: usize,
38-
batch_points: Vec<Vec<TG1>>,
37+
batch_points: Vec<Vec<TG1Affine>>, // Projective -> Affine change
38+
batch_chunk_tables: Vec<Vec<Vec<TG1Affine>>>, // recomputed tables per row
3939

4040
g1_marker: PhantomData<TG1>,
4141
g1_fp_marker: PhantomData<TG1Fp>,
@@ -76,26 +76,35 @@ where
7676
let end = core::cmp::min(start + chunk_size, n);
7777
let chunk_len = end - start;
7878
let table_size = (1usize << chunk_len) - 1;
79-
80-
let chunk: Vec<TG1> = cache.table[offset..offset + table_size]
81-
.iter()
82-
.map(|affine| affine.to_proj())
83-
.collect();
79+
80+
// Store directly as affine
81+
let chunk: Vec<TG1Affine> = cache.table[offset..offset + table_size]
82+
.to_vec();
8483
chunk_tables.push(chunk);
8584
offset += table_size;
8685
}
8786

88-
let batch_points = cache
89-
.batch_table
87+
// Store batch points as affine too
88+
let batch_points = cache.batch_table.clone();
89+
90+
// Rebuild batch_chunk_tables from batch_points
91+
let batch_chunk_tables: Vec<Vec<Vec<TG1Affine>>> = batch_points
9092
.iter()
91-
.map(|row| row.iter().map(|affine| affine.to_proj()).collect())
92-
.collect::<Vec<_>>();
93+
.map(|point_row| {
94+
let proj_row: Vec<TG1> = point_row
95+
.iter()
96+
.map(|affine| affine.to_proj())
97+
.collect();
98+
Self::build_chunk_tables(&proj_row, chunk_size)
99+
})
100+
.collect();
93101

94102
Self {
95103
chunk_tables,
96104
numpoints: cache.numpoints,
97105
batch_numpoints: cache.batch_numpoints,
98106
batch_points,
107+
batch_chunk_tables,
99108
g1_marker: PhantomData,
100109
g1_fp_marker: PhantomData,
101110
fr_marker: PhantomData,
@@ -112,30 +121,23 @@ where
112121
fn try_write_cache(
113122
points: &[TG1],
114123
matrix: &[Vec<TG1>],
115-
chunk_tables: &[Vec<TG1>],
124+
chunk_tables: &[Vec<TG1Affine>], // now takes affine
116125
numpoints: usize,
117-
batch_points: &[Vec<TG1>],
126+
batch_points: &[Vec<TG1Affine>], // now takes affine
118127
batch_numpoints: usize,
119128
contenthash: Option<[u8; 32]>,
120129
) -> Result<(), String> {
121130
#[cfg(feature = "diskcache")]
122131
{
123-
// Flatten chunk_tables to a single vector and convert to affine
132+
// Flatten chunk_tables
124133
let table_affine: Vec<TG1Affine> = chunk_tables
125134
.iter()
126135
.flat_map(|chunk| chunk.iter())
127-
.map(|proj| TG1Affine::into_affine(proj))
136+
.cloned()
128137
.collect();
129138

130-
// Convert batch_points to affine
131-
let batch_table_affine: Vec<Vec<TG1Affine>> = batch_points
132-
.iter()
133-
.map(|row| {
134-
row.iter()
135-
.map(|proj| TG1Affine::into_affine(proj))
136-
.collect()
137-
})
138-
.collect();
139+
// Batch points are already affine, just clone
140+
let batch_table_affine = batch_points.to_vec();
139141

140142
DiskCache::<TG1, TG1Fp, TG1Affine>::save(
141143
"strauss",
@@ -174,6 +176,7 @@ where
174176
numpoints: 0,
175177
batch_numpoints: 0,
176178
batch_points: Vec::new(),
179+
batch_chunk_tables: Vec::new(),
177180
g1_marker: PhantomData,
178181
g1_fp_marker: PhantomData,
179182
fr_marker: PhantomData,
@@ -183,7 +186,7 @@ where
183186
return Ok(Some(table));
184187
}
185188

186-
// Build chunk tables directly from projective points
189+
// Build chunk tables as affine
187190
let chunk_tables = Self::build_chunk_tables(points, strauss_chunk_size);
188191

189192
Self::try_write_cache(points, matrix, &chunk_tables, n, &[], 0, contenthash)?;
@@ -193,6 +196,7 @@ where
193196
numpoints: n,
194197
batch_numpoints: 0,
195198
batch_points: Vec::new(),
199+
batch_chunk_tables: Vec::new(),
196200
g1_marker: PhantomData,
197201
g1_fp_marker: PhantomData,
198202
fr_marker: PhantomData,
@@ -202,7 +206,7 @@ where
202206
return Ok(Some(table));
203207
}
204208

205-
// Batch case: store projective points directly
209+
// Batch case: convert projective input to affine for storage
206210
let batch_numpoints = matrix[0].len();
207211
let mut batch_points = Vec::new();
208212

@@ -211,9 +215,29 @@ where
211215
.map_err(|_| "Strauss precomputation table is too large".to_owned())?;
212216

213217
for row in matrix {
214-
batch_points.push(row.to_vec());
218+
// Convert projective to affine for storage
219+
let affine_row: Vec<TG1Affine> = row
220+
.iter()
221+
.map(|proj| TG1Affine::into_affine(proj))
222+
.collect();
223+
batch_points.push(affine_row);
215224
}
216225

226+
// precompute once during table creation
227+
let batch_chunk_tables: Vec<Vec<Vec<TG1Affine>>> = batch_points
228+
.iter()
229+
.map(|point_row| {
230+
// Convert affine back to projective for table building
231+
let proj_row: Vec<TG1> = point_row
232+
.iter()
233+
.map(|affine| affine.to_proj())
234+
.collect();
235+
// Build chunk tables for this row
236+
Self::build_chunk_tables(&proj_row, strauss_chunk_size)
237+
})
238+
.collect();
239+
240+
217241
// We still need to build the single-point-set table for the main points
218242
// (though it may not be used for batch operations)
219243
let n = points.len();
@@ -238,6 +262,7 @@ where
238262
numpoints: n,
239263
batch_numpoints,
240264
batch_points,
265+
batch_chunk_tables,
241266
g1_marker: PhantomData,
242267
g1_fp_marker: PhantomData,
243268
fr_marker: PhantomData,
@@ -247,10 +272,10 @@ where
247272
Ok(Some(table))
248273
}
249274

250-
/// Helper to build chunk tables from projective points
251-
fn build_chunk_tables(points: &[TG1], chunk_size: usize) -> Vec<Vec<TG1>> {
275+
/// Build chunk tables - returns AFFINE for storage efficiency
276+
fn build_chunk_tables(points: &[TG1], chunk_size: usize) -> Vec<Vec<TG1Affine>> {
252277
let n = points.len();
253-
let mut chunk_tables: Vec<Vec<TG1>> = Vec::new();
278+
let mut chunk_tables: Vec<Vec<TG1Affine>> = Vec::new();
254279

255280
let num_chunks = n.div_ceil(chunk_size);
256281

@@ -263,6 +288,7 @@ where
263288
let table_size = (1usize << chunk_len) - 1;
264289

265290
// Build incremental table in projective space using the lowest-bit trick.
291+
// faster additions in projective
266292
let mut table_proj: Vec<TG1> = Vec::with_capacity(table_size);
267293

268294
for mask in 1..=table_size {
@@ -277,7 +303,14 @@ where
277303
}
278304
}
279305

280-
chunk_tables.push(table_proj);
306+
// Convert to affine ONCE for storage
307+
let table_affine: Vec<TG1Affine> = table_proj
308+
.iter()
309+
.map(|proj| TG1Affine::into_affine(proj))
310+
.collect();
311+
312+
313+
chunk_tables.push(table_affine);
281314
}
282315

283316
chunk_tables
@@ -289,7 +322,7 @@ where
289322
}
290323

291324
/// Core multiplication logic using provided tables
292-
fn multiply_with_tables(scalars: &[TFr], chunk_tables: &[Vec<TG1>]) -> TG1 {
325+
fn multiply_with_tables(scalars: &[TFr], chunk_tables: &[Vec<TG1Affine>]) -> TG1 {
293326
let n = scalars.len();
294327
if n == 0 || chunk_tables.is_empty() {
295328
return TG1::zero();
@@ -324,32 +357,35 @@ where
324357
let mut pt_idx = 0usize;
325358
for table in chunk_tables.iter() {
326359
let table_size = table.len();
327-
if table_size == 0 {
328-
continue;
329-
}
330360
// Derive chunk_len from table size: table_size = 2^chunk_len - 1
331361
// So 2^chunk_len = table_size + 1, thus log2(table_size + 1)
332362
let chunk_len =
333363
(usize::BITS - 1) as usize - (table_size + 1).leading_zeros() as usize;
334364

335-
// Build mask for this bit across chunk scalars
336-
let mut mask = 0usize;
337-
for i in 0..chunk_len {
365+
// Only process this chunk if we have scalars for it
366+
// This handles the case where tables were built for more points than we're using
367+
if pt_idx >= scalar_values.len() {
368+
break;
369+
}
370+
371+
// Build table_index for this bit across chunk scalars
372+
let mut table_index = 0usize;
373+
let actual_chunk_len = core::cmp::min(chunk_len, scalar_values.len() - pt_idx);
374+
375+
for i in 0..actual_chunk_len {
338376
let scalar_idx = pt_idx + i;
339-
if scalar_idx >= scalar_values.len() {
340-
break;
341-
}
342377

343378
let s = &scalar_values[scalar_idx];
344379
// Extract single bit at position 'bit' from scalar
345380
if (get_wval_limb(s, bit, 1) & 1) != 0 {
346-
mask |= 1 << i;
381+
table_index |= 1 << i;
347382
}
348383
}
349384

350-
if mask != 0 {
351-
let tab_proj = &table[mask - 1];
352-
accumulator.add_or_dbl_assign(tab_proj);
385+
if table_index != 0 {
386+
// Mixed addition - Projective + Affine (should be faster than Proj + Proj)
387+
let affine_pt = &table[table_index - 1];
388+
TG1ProjAddAffine::add_or_double_assign_affine(&mut accumulator, affine_pt);
353389
}
354390

355391
pt_idx += chunk_len;
@@ -360,31 +396,31 @@ where
360396
}
361397

362398
pub fn multiply_batch(&self, scalars: &[Vec<TFr>]) -> Vec<TG1> {
363-
if self.batch_points.is_empty() {
399+
if self.batch_chunk_tables.is_empty() {
364400
// Fall back to sequential calls using main chunk_tables
365401
scalars
366402
.iter()
367403
.map(|s| self.multiply_sequential(s))
368404
.collect()
369405
} else {
370-
// Use batch_points: build temporary tables for each row
406+
// Use precomputed batch_chunk_tables
371407
assert!(
372-
scalars.len() == self.batch_points.len(),
373-
"Scalars length {} != batch_points length {}",
408+
scalars.len() == self.batch_chunk_tables.len(),
409+
"Scalars length {} != batch_chunk_tables length {}",
374410
scalars.len(),
375-
self.batch_points.len()
411+
self.batch_chunk_tables.len()
376412
);
377413

378414
let strauss_chunk_size: usize = get_window_size();
379415

380416
scalars
381417
.iter()
382-
.zip(self.batch_points.iter())
383-
.map(|(scalar_row, point_row)| {
384-
let chunk_tables = Self::build_chunk_tables(point_row, strauss_chunk_size);
385-
Self::multiply_with_tables(scalar_row, &chunk_tables)
418+
.zip(self.batch_chunk_tables.iter())
419+
.map(|(scalar_row, chunk_tables)| {
420+
Self::multiply_with_tables(scalar_row, chunk_tables)
386421
})
387422
.collect()
388423
}
389424
}
390425
}
426+

0 commit comments

Comments
 (0)