Skip to content

Commit 4bbe630

Browse files
jonathanpwangclaude
andcommitted
perf: store vars during preflight to avoid recomputation in trace fill
Previously, `run_field_expression` was called twice - once during preflight and again during trace filling - causing expensive `vars` computation to be duplicated. This change: - Extends FieldExpressionCoreRecordMut to store computed vars - Updates preflight to save vars in the record after computation - Updates trace filler to read precomputed vars from the record - Adds generate_subrow_with_precomputed_vars to FieldExpr - Extracts derive_flags_from_opcode helper for reuse The vars (intermediate field expression results) are now computed once during preflight and reused during trace generation, eliminating redundant BigUint arithmetic operations. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 389afd6 commit 4bbe630

File tree

3 files changed

+219
-39
lines changed

3 files changed

+219
-39
lines changed

crates/circuits/mod-builder/src/builder.rs

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,118 @@ impl FieldExpr {
616616
.collect()
617617
}
618618

619+
/// Generates a trace subrow using precomputed `vars` to avoid redundant computation.
620+
/// This is used during trace filling when vars have already been computed during preflight.
621+
pub fn generate_subrow_with_precomputed_vars<F: PrimeField64>(
622+
&self,
623+
range_checker: &VariableRangeCheckerChip,
624+
inputs: &[BigUint],
625+
flags: &[bool],
626+
vars: &[BigUint],
627+
sub_row: &mut [F],
628+
) {
629+
assert!(self.builder.is_finalized());
630+
assert_eq!(inputs.len(), self.num_input);
631+
assert_eq!(vars.len(), self.num_variables);
632+
assert_eq!(flags.len(), self.builder.num_flags);
633+
634+
let limb_bits = self.limb_bits;
635+
636+
// BigInt type is required for computing the quotient.
637+
let input_bigint = inputs
638+
.iter()
639+
.map(|x| BigInt::from_biguint(Sign::Plus, x.clone()))
640+
.collect::<Vec<BigInt>>();
641+
let vars_bigint: Vec<BigInt> = vars
642+
.iter()
643+
.map(|x| BigInt::from_biguint(Sign::Plus, x.clone()))
644+
.collect();
645+
646+
// OverflowInt type is required for computing the carries.
647+
let input_overflow = inputs
648+
.iter()
649+
.map(|x| OverflowInt::<isize>::from_biguint(x, self.limb_bits, Some(self.num_limbs)))
650+
.collect::<Vec<_>>();
651+
let vars_overflow: Vec<OverflowInt<isize>> = vars
652+
.iter()
653+
.map(|x| OverflowInt::<isize>::from_biguint(x, self.limb_bits, Some(self.num_limbs)))
654+
.collect();
655+
656+
// Note: in cases where the prime fits in less limbs than `num_limbs`, we use the smaller
657+
// number of limbs.
658+
let prime_overflow = OverflowInt::<isize>::from_biguint(&self.prime, self.limb_bits, None);
659+
660+
let constants: Vec<_> = self
661+
.constants
662+
.iter()
663+
.map(|(_, limbs)| {
664+
let limbs_isize: Vec<_> = limbs.iter().map(|i| *i as isize).collect();
665+
OverflowInt::from_canonical_unsigned_limbs(limbs_isize, self.limb_bits)
666+
})
667+
.collect();
668+
669+
let mut all_q = vec![];
670+
let mut all_carry = vec![];
671+
for i in 0..self.constraints.len() {
672+
// expr = q * p
673+
let expr_bigint =
674+
self.constraints[i].evaluate_bigint(&input_bigint, &vars_bigint, flags);
675+
let q = &expr_bigint / &self.prime_bigint;
676+
// If this is not true then the evaluated constraint is not divisible by p.
677+
debug_assert_eq!(expr_bigint, &q * &self.prime_bigint);
678+
let q_limbs = big_int_to_num_limbs(&q, limb_bits, self.q_limbs[i]);
679+
assert_eq!(q_limbs.len(), self.q_limbs[i]); // If this fails, the q_limbs estimate is wrong.
680+
for &q in q_limbs.iter() {
681+
range_checker.add_count((q + (1 << limb_bits)) as u32, limb_bits + 1);
682+
}
683+
let q_overflow = OverflowInt::from_canonical_signed_limbs(q_limbs.clone(), limb_bits);
684+
// compute carries of (expr - q * p)
685+
let expr = self.constraints[i].evaluate_overflow_isize(
686+
&input_overflow,
687+
&vars_overflow,
688+
&constants,
689+
flags,
690+
);
691+
let expr = expr - q_overflow * prime_overflow.clone();
692+
let carries = expr.calculate_carries(limb_bits);
693+
assert_eq!(carries.len(), self.carry_limbs[i]); // If this fails, the carry limbs estimate is wrong.
694+
let max_overflow_bits = expr.max_overflow_bits();
695+
let (carry_min_abs, carry_bits) =
696+
get_carry_max_abs_and_bits(max_overflow_bits, limb_bits);
697+
for &carry in carries.iter() {
698+
range_checker.add_count((carry + carry_min_abs as isize) as u32, carry_bits);
699+
}
700+
all_q.push(vec_isize_to_f::<F>(q_limbs));
701+
all_carry.push(vec_isize_to_f::<F>(carries));
702+
}
703+
for var in vars_overflow.iter() {
704+
for limb in var.limbs().iter() {
705+
range_checker.add_count(*limb as u32, limb_bits);
706+
}
707+
}
708+
709+
let input_limbs = input_overflow
710+
.iter()
711+
.map(|x| vec_isize_to_f::<F>(x.limbs().to_vec()))
712+
.collect::<Vec<_>>();
713+
let vars_limbs = vars_overflow
714+
.iter()
715+
.map(|x| vec_isize_to_f::<F>(x.limbs().to_vec()))
716+
.collect::<Vec<_>>();
717+
718+
sub_row.copy_from_slice(
719+
&[
720+
vec![F::ONE],
721+
input_limbs.concat(),
722+
vars_limbs.concat(),
723+
all_q.concat(),
724+
all_carry.concat(),
725+
flags.iter().map(|x| F::from_bool(*x)).collect::<Vec<_>>(),
726+
]
727+
.concat(),
728+
);
729+
}
730+
619731
pub fn load_vars<T: Clone>(&self, arr: &[T]) -> FieldExprCols<T> {
620732
assert!(self.builder.is_finalized());
621733
let is_valid = arr[0].clone();

crates/circuits/mod-builder/src/core_chip.rs

Lines changed: 94 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -171,13 +171,15 @@ where
171171

172172
pub struct FieldExpressionMetadata<F, A> {
173173
pub total_input_limbs: usize, // num_inputs * limbs_per_input
174+
pub total_var_limbs: usize, // num_variables * limbs_per_variable
174175
_phantom: PhantomData<(F, A)>,
175176
}
176177

177178
impl<F, A> Clone for FieldExpressionMetadata<F, A> {
178179
fn clone(&self) -> Self {
179180
Self {
180181
total_input_limbs: self.total_input_limbs,
182+
total_var_limbs: self.total_var_limbs,
181183
_phantom: PhantomData,
182184
}
183185
}
@@ -187,15 +189,17 @@ impl<F, A> Default for FieldExpressionMetadata<F, A> {
187189
fn default() -> Self {
188190
Self {
189191
total_input_limbs: 0,
192+
total_var_limbs: 0,
190193
_phantom: PhantomData,
191194
}
192195
}
193196
}
194197

195198
impl<F, A> FieldExpressionMetadata<F, A> {
196-
pub fn new(total_input_limbs: usize) -> Self {
199+
pub fn new(total_input_limbs: usize, total_var_limbs: usize) -> Self {
197200
Self {
198201
total_input_limbs,
202+
total_var_limbs,
199203
_phantom: PhantomData,
200204
}
201205
}
@@ -216,6 +220,7 @@ pub type FieldExpressionRecordLayout<F, A> = AdapterCoreLayout<FieldExpressionMe
216220
pub struct FieldExpressionCoreRecordMut<'a> {
217221
pub opcode: &'a mut u8,
218222
pub input_limbs: &'a mut [u8],
223+
pub var_limbs: &'a mut [u8],
219224
}
220225

221226
impl<'a, F, A> CustomBorrow<'a, FieldExpressionCoreRecordMut<'a>, FieldExpressionRecordLayout<F, A>>
@@ -226,14 +231,19 @@ impl<'a, F, A> CustomBorrow<'a, FieldExpressionCoreRecordMut<'a>, FieldExpressio
226231
layout: FieldExpressionRecordLayout<F, A>,
227232
) -> FieldExpressionCoreRecordMut<'a> {
228233
// SAFETY: The buffer length is the width of the trace which should be at least 1
229-
let (opcode_buf, input_limbs_buff) = unsafe { self.split_at_mut_unchecked(1) };
234+
let (opcode_buf, rest) = unsafe { self.split_at_mut_unchecked(1) };
230235

231236
// SAFETY: opcode_buf has exactly 1 element from split_at_mut_unchecked(1)
232237
let opcode_buf = unsafe { opcode_buf.get_unchecked_mut(0) };
233238

239+
// SAFETY: rest has at least total_input_limbs + total_var_limbs bytes
240+
let (input_limbs, var_limbs_buf) =
241+
unsafe { rest.split_at_mut_unchecked(layout.metadata.total_input_limbs) };
242+
234243
FieldExpressionCoreRecordMut {
235244
opcode: opcode_buf,
236-
input_limbs: &mut input_limbs_buff[..layout.metadata.total_input_limbs],
245+
input_limbs,
246+
var_limbs: &mut var_limbs_buf[..layout.metadata.total_var_limbs],
237247
}
238248
}
239249

@@ -244,7 +254,9 @@ impl<'a, F, A> CustomBorrow<'a, FieldExpressionCoreRecordMut<'a>, FieldExpressio
244254

245255
impl<F, A> SizedRecord<FieldExpressionRecordLayout<F, A>> for FieldExpressionCoreRecordMut<'_> {
246256
fn size(layout: &FieldExpressionRecordLayout<F, A>) -> usize {
247-
layout.metadata.total_input_limbs + 1
257+
1 // opcode
258+
+ layout.metadata.total_input_limbs
259+
+ layout.metadata.total_var_limbs
248260
}
249261

250262
fn alignment(_layout: &FieldExpressionRecordLayout<F, A>) -> usize {
@@ -257,9 +269,13 @@ impl<'a> FieldExpressionCoreRecordMut<'a> {
257269
pub fn new_from_execution_data(
258270
buffer: &'a mut [u8],
259271
inputs: &[BigUint],
272+
num_variables: usize,
260273
limbs_per_input: usize,
261274
) -> Self {
262-
let record_info = FieldExpressionMetadata::<(), ()>::new(inputs.len() * limbs_per_input);
275+
let record_info = FieldExpressionMetadata::<(), ()>::new(
276+
inputs.len() * limbs_per_input,
277+
num_variables * limbs_per_input,
278+
);
263279

264280
let record: Self = buffer.custom_borrow(FieldExpressionRecordLayout {
265281
metadata: record_info,
@@ -319,9 +335,11 @@ impl<A> FieldExpressionExecutor<A> {
319335
}
320336

321337
pub fn get_record_layout<F>(&self) -> FieldExpressionRecordLayout<F, A> {
338+
let num_limbs = self.expr.canonical_num_limbs();
322339
FieldExpressionRecordLayout {
323340
metadata: FieldExpressionMetadata::new(
324-
self.expr.builder.num_input * self.expr.canonical_num_limbs(),
341+
self.expr.builder.num_input * num_limbs,
342+
self.expr.builder.num_variables * num_limbs,
325343
),
326344
}
327345
}
@@ -372,9 +390,11 @@ impl<A> FieldExpressionFiller<A> {
372390
}
373391

374392
pub fn get_record_layout<F>(&self) -> FieldExpressionRecordLayout<F, A> {
393+
let num_limbs = self.expr.canonical_num_limbs();
375394
FieldExpressionRecordLayout {
376395
metadata: FieldExpressionMetadata::new(
377-
self.num_inputs() * self.expr.canonical_num_limbs(),
396+
self.num_inputs() * num_limbs,
397+
self.expr.builder.num_variables * num_limbs,
378398
),
379399
}
380400
}
@@ -410,14 +430,23 @@ where
410430
&data.0,
411431
);
412432

413-
let (writes, _, _) = run_field_expression(
433+
let (writes, _, _, vars) = run_field_expression(
414434
&self.expr,
415435
&self.local_opcode_idx,
416436
&self.opcode_flag_idx,
417437
core_record.input_limbs,
418438
*core_record.opcode as usize,
419439
);
420440

441+
// Store computed vars in the record to avoid recomputation during trace filling
442+
let field_element_limbs = self.expr.canonical_num_limbs();
443+
for (i, var) in vars.iter().enumerate() {
444+
let start = i * field_element_limbs;
445+
let end = start + field_element_limbs;
446+
let limbs = biguint_to_limbs_vec(var, field_element_limbs);
447+
core_record.var_limbs[start..end].copy_from_slice(&limbs);
448+
}
449+
421450
self.adapter.write(
422451
state.memory,
423452
instruction,
@@ -457,17 +486,32 @@ where
457486
let record: FieldExpressionCoreRecordMut =
458487
unsafe { get_record_from_slice(&mut core_row, self.get_record_layout::<F>()) };
459488

460-
let (_, inputs, flags) = run_field_expression(
489+
// Reconstruct inputs from record
490+
let field_element_limbs = self.expr.canonical_num_limbs();
491+
let inputs: Vec<BigUint> = record
492+
.input_limbs
493+
.chunks(field_element_limbs)
494+
.map(BigUint::from_bytes_le)
495+
.collect();
496+
497+
// Read precomputed vars from record (computed during preflight)
498+
let vars: Vec<BigUint> = record
499+
.var_limbs
500+
.chunks(field_element_limbs)
501+
.map(BigUint::from_bytes_le)
502+
.collect();
503+
504+
// Derive flags from opcode (cheap operation)
505+
let flags = derive_flags_from_opcode(
461506
&self.expr,
462507
&self.local_opcode_idx,
463508
&self.opcode_flag_idx,
464-
record.input_limbs,
465509
*record.opcode as usize,
466510
);
467511

468512
let range_checker = self.range_checker.as_ref();
469513
self.expr
470-
.generate_subrow((range_checker, inputs, flags), core_row);
514+
.generate_subrow_with_precomputed_vars(range_checker, &inputs, &flags, &vars, core_row);
471515
}
472516

473517
fn fill_dummy_trace_row(&self, row_slice: &mut [F]) {
@@ -493,7 +537,7 @@ fn run_field_expression(
493537
opcode_flag_idx: &[usize],
494538
data: &[u8],
495539
local_opcode_idx: usize,
496-
) -> (DynArray<u8>, Vec<BigUint>, Vec<bool>) {
540+
) -> (DynArray<u8>, Vec<BigUint>, Vec<bool>, Vec<BigUint>) {
497541
let field_element_limbs = expr.canonical_num_limbs();
498542
assert_eq!(data.len(), expr.builder.num_input * field_element_limbs);
499543

@@ -506,24 +550,12 @@ fn run_field_expression(
506550
inputs.push(input);
507551
}
508552

509-
let mut flags = vec![];
510-
if expr.needs_setup() {
511-
flags = vec![false; expr.builder.num_flags];
512-
513-
// Find which opcode this is in our local_opcode_idx list
514-
if let Some(opcode_position) = local_opcode_flags
515-
.iter()
516-
.position(|&idx| idx == local_opcode_idx)
517-
{
518-
// If this is NOT the last opcode (setup), set the corresponding flag
519-
if opcode_position < opcode_flag_idx.len() {
520-
let flag_idx = opcode_flag_idx[opcode_position];
521-
flags[flag_idx] = true;
522-
}
523-
// If opcode_position == step.opcode_flag_idx.len(), it's the setup operation
524-
// and all flags should remain false (which they already are)
525-
}
526-
}
553+
let flags = derive_flags_from_opcode(
554+
expr,
555+
local_opcode_flags,
556+
opcode_flag_idx,
557+
local_opcode_idx,
558+
);
527559

528560
let vars = expr.execute(inputs.clone(), flags.clone());
529561
assert_eq!(vars.len(), expr.builder.num_variables);
@@ -542,7 +574,38 @@ fn run_field_expression(
542574
.collect::<Vec<_>>()
543575
.into();
544576

545-
(writes, inputs, flags)
577+
(writes, inputs, flags, vars)
578+
}
579+
580+
/// Derives the flags vector from the local opcode index.
581+
/// This is used during preflight and trace filling to determine which operation is being performed.
582+
fn derive_flags_from_opcode(
583+
expr: &FieldExpr,
584+
local_opcode_flags: &[usize],
585+
opcode_flag_idx: &[usize],
586+
local_opcode_idx: usize,
587+
) -> Vec<bool> {
588+
if !expr.needs_setup() {
589+
return vec![];
590+
}
591+
592+
let mut flags = vec![false; expr.builder.num_flags];
593+
594+
// Find which opcode this is in our local_opcode_idx list
595+
if let Some(opcode_position) = local_opcode_flags
596+
.iter()
597+
.position(|&idx| idx == local_opcode_idx)
598+
{
599+
// If this is NOT the last opcode (setup), set the corresponding flag
600+
if opcode_position < opcode_flag_idx.len() {
601+
let flag_idx = opcode_flag_idx[opcode_position];
602+
flags[flag_idx] = true;
603+
}
604+
// If opcode_position == step.opcode_flag_idx.len(), it's the setup operation
605+
// and all flags should remain false (which they already are)
606+
}
607+
608+
flags
546609
}
547610

548611
#[inline(always)]

0 commit comments

Comments
 (0)