diff --git a/crates/circuits/mod-builder/src/builder.rs b/crates/circuits/mod-builder/src/builder.rs index 75c31f579a..4a25564f87 100644 --- a/crates/circuits/mod-builder/src/builder.rs +++ b/crates/circuits/mod-builder/src/builder.rs @@ -616,6 +616,118 @@ impl FieldExpr { .collect() } + /// Generates a trace subrow using precomputed `vars` to avoid redundant computation. + /// This is used during trace filling when vars have already been computed during preflight. + pub fn generate_subrow_with_precomputed_vars( + &self, + range_checker: &VariableRangeCheckerChip, + inputs: &[BigUint], + flags: &[bool], + vars: &[BigUint], + sub_row: &mut [F], + ) { + assert!(self.builder.is_finalized()); + assert_eq!(inputs.len(), self.num_input); + assert_eq!(vars.len(), self.num_variables); + assert_eq!(flags.len(), self.builder.num_flags); + + let limb_bits = self.limb_bits; + + // BigInt type is required for computing the quotient. + let input_bigint = inputs + .iter() + .map(|x| BigInt::from_biguint(Sign::Plus, x.clone())) + .collect::>(); + let vars_bigint: Vec = vars + .iter() + .map(|x| BigInt::from_biguint(Sign::Plus, x.clone())) + .collect(); + + // OverflowInt type is required for computing the carries. + let input_overflow = inputs + .iter() + .map(|x| OverflowInt::::from_biguint(x, self.limb_bits, Some(self.num_limbs))) + .collect::>(); + let vars_overflow: Vec> = vars + .iter() + .map(|x| OverflowInt::::from_biguint(x, self.limb_bits, Some(self.num_limbs))) + .collect(); + + // Note: in cases where the prime fits in less limbs than `num_limbs`, we use the smaller + // number of limbs. 
+ let prime_overflow = OverflowInt::::from_biguint(&self.prime, self.limb_bits, None); + + let constants: Vec<_> = self + .constants + .iter() + .map(|(_, limbs)| { + let limbs_isize: Vec<_> = limbs.iter().map(|i| *i as isize).collect(); + OverflowInt::from_canonical_unsigned_limbs(limbs_isize, self.limb_bits) + }) + .collect(); + + let mut all_q = vec![]; + let mut all_carry = vec![]; + for i in 0..self.constraints.len() { + // expr = q * p + let expr_bigint = + self.constraints[i].evaluate_bigint(&input_bigint, &vars_bigint, flags); + let q = &expr_bigint / &self.prime_bigint; + // If this is not true then the evaluated constraint is not divisible by p. + debug_assert_eq!(expr_bigint, &q * &self.prime_bigint); + let q_limbs = big_int_to_num_limbs(&q, limb_bits, self.q_limbs[i]); + assert_eq!(q_limbs.len(), self.q_limbs[i]); // If this fails, the q_limbs estimate is wrong. + for &q in q_limbs.iter() { + range_checker.add_count((q + (1 << limb_bits)) as u32, limb_bits + 1); + } + let q_overflow = OverflowInt::from_canonical_signed_limbs(q_limbs.clone(), limb_bits); + // compute carries of (expr - q * p) + let expr = self.constraints[i].evaluate_overflow_isize( + &input_overflow, + &vars_overflow, + &constants, + flags, + ); + let expr = expr - q_overflow * prime_overflow.clone(); + let carries = expr.calculate_carries(limb_bits); + assert_eq!(carries.len(), self.carry_limbs[i]); // If this fails, the carry limbs estimate is wrong. 
+ let max_overflow_bits = expr.max_overflow_bits(); + let (carry_min_abs, carry_bits) = + get_carry_max_abs_and_bits(max_overflow_bits, limb_bits); + for &carry in carries.iter() { + range_checker.add_count((carry + carry_min_abs as isize) as u32, carry_bits); + } + all_q.push(vec_isize_to_f::(q_limbs)); + all_carry.push(vec_isize_to_f::(carries)); + } + for var in vars_overflow.iter() { + for limb in var.limbs().iter() { + range_checker.add_count(*limb as u32, limb_bits); + } + } + + let input_limbs = input_overflow + .iter() + .map(|x| vec_isize_to_f::(x.limbs().to_vec())) + .collect::>(); + let vars_limbs = vars_overflow + .iter() + .map(|x| vec_isize_to_f::(x.limbs().to_vec())) + .collect::>(); + + sub_row.copy_from_slice( + &[ + vec![F::ONE], + input_limbs.concat(), + vars_limbs.concat(), + all_q.concat(), + all_carry.concat(), + flags.iter().map(|x| F::from_bool(*x)).collect::>(), + ] + .concat(), + ); + } + pub fn load_vars(&self, arr: &[T]) -> FieldExprCols { assert!(self.builder.is_finalized()); let is_valid = arr[0].clone(); diff --git a/crates/circuits/mod-builder/src/core_chip.rs b/crates/circuits/mod-builder/src/core_chip.rs index 719e1e0cf4..1211667eef 100644 --- a/crates/circuits/mod-builder/src/core_chip.rs +++ b/crates/circuits/mod-builder/src/core_chip.rs @@ -171,6 +171,7 @@ where pub struct FieldExpressionMetadata { pub total_input_limbs: usize, // num_inputs * limbs_per_input + pub total_var_limbs: usize, // num_variables * limbs_per_variable _phantom: PhantomData<(F, A)>, } @@ -178,6 +179,7 @@ impl Clone for FieldExpressionMetadata { fn clone(&self) -> Self { Self { total_input_limbs: self.total_input_limbs, + total_var_limbs: self.total_var_limbs, _phantom: PhantomData, } } @@ -187,15 +189,17 @@ impl Default for FieldExpressionMetadata { fn default() -> Self { Self { total_input_limbs: 0, + total_var_limbs: 0, _phantom: PhantomData, } } } impl FieldExpressionMetadata { - pub fn new(total_input_limbs: usize) -> Self { + pub fn 
new(total_input_limbs: usize, total_var_limbs: usize) -> Self { Self { total_input_limbs, + total_var_limbs, _phantom: PhantomData, } } @@ -216,6 +220,7 @@ pub type FieldExpressionRecordLayout = AdapterCoreLayout { pub opcode: &'a mut u8, pub input_limbs: &'a mut [u8], + pub var_limbs: &'a mut [u8], } impl<'a, F, A> CustomBorrow<'a, FieldExpressionCoreRecordMut<'a>, FieldExpressionRecordLayout> @@ -226,14 +231,19 @@ impl<'a, F, A> CustomBorrow<'a, FieldExpressionCoreRecordMut<'a>, FieldExpressio layout: FieldExpressionRecordLayout, ) -> FieldExpressionCoreRecordMut<'a> { // SAFETY: The buffer length is the width of the trace which should be at least 1 - let (opcode_buf, input_limbs_buff) = unsafe { self.split_at_mut_unchecked(1) }; + let (opcode_buf, rest) = unsafe { self.split_at_mut_unchecked(1) }; // SAFETY: opcode_buf has exactly 1 element from split_at_mut_unchecked(1) let opcode_buf = unsafe { opcode_buf.get_unchecked_mut(0) }; + // SAFETY: rest has at least total_input_limbs + total_var_limbs bytes + let (input_limbs, var_limbs_buf) = + unsafe { rest.split_at_mut_unchecked(layout.metadata.total_input_limbs) }; + FieldExpressionCoreRecordMut { opcode: opcode_buf, - input_limbs: &mut input_limbs_buff[..layout.metadata.total_input_limbs], + input_limbs, + var_limbs: &mut var_limbs_buf[..layout.metadata.total_var_limbs], } } @@ -244,7 +254,9 @@ impl<'a, F, A> CustomBorrow<'a, FieldExpressionCoreRecordMut<'a>, FieldExpressio impl SizedRecord> for FieldExpressionCoreRecordMut<'_> { fn size(layout: &FieldExpressionRecordLayout) -> usize { - layout.metadata.total_input_limbs + 1 + 1 // opcode + + layout.metadata.total_input_limbs + + layout.metadata.total_var_limbs } fn alignment(_layout: &FieldExpressionRecordLayout) -> usize { @@ -257,9 +269,13 @@ impl<'a> FieldExpressionCoreRecordMut<'a> { pub fn new_from_execution_data( buffer: &'a mut [u8], inputs: &[BigUint], + num_variables: usize, limbs_per_input: usize, ) -> Self { - let record_info = 
FieldExpressionMetadata::<(), ()>::new(inputs.len() * limbs_per_input); + let record_info = FieldExpressionMetadata::<(), ()>::new( + inputs.len() * limbs_per_input, + num_variables * limbs_per_input, + ); let record: Self = buffer.custom_borrow(FieldExpressionRecordLayout { metadata: record_info, @@ -319,9 +335,11 @@ impl FieldExpressionExecutor { } pub fn get_record_layout(&self) -> FieldExpressionRecordLayout { + let num_limbs = self.expr.canonical_num_limbs(); FieldExpressionRecordLayout { metadata: FieldExpressionMetadata::new( - self.expr.builder.num_input * self.expr.canonical_num_limbs(), + self.expr.builder.num_input * num_limbs, + self.expr.builder.num_variables * num_limbs, ), } } @@ -372,9 +390,11 @@ impl FieldExpressionFiller { } pub fn get_record_layout(&self) -> FieldExpressionRecordLayout { + let num_limbs = self.expr.canonical_num_limbs(); FieldExpressionRecordLayout { metadata: FieldExpressionMetadata::new( - self.num_inputs() * self.expr.canonical_num_limbs(), + self.num_inputs() * num_limbs, + self.expr.builder.num_variables * num_limbs, ), } } @@ -410,7 +430,7 @@ where &data.0, ); - let (writes, _, _) = run_field_expression( + let (writes, _, _, vars) = run_field_expression( &self.expr, &self.local_opcode_idx, &self.opcode_flag_idx, @@ -418,6 +438,15 @@ where *core_record.opcode as usize, ); + // Store computed vars in the record to avoid recomputation during trace filling + let field_element_limbs = self.expr.canonical_num_limbs(); + for (i, var) in vars.iter().enumerate() { + let start = i * field_element_limbs; + let end = start + field_element_limbs; + let limbs = biguint_to_limbs_vec(var, field_element_limbs); + core_record.var_limbs[start..end].copy_from_slice(&limbs); + } + self.adapter.write( state.memory, instruction, @@ -457,17 +486,32 @@ where let record: FieldExpressionCoreRecordMut = unsafe { get_record_from_slice(&mut core_row, self.get_record_layout::()) }; - let (_, inputs, flags) = run_field_expression( + // Reconstruct inputs 
from record + let field_element_limbs = self.expr.canonical_num_limbs(); + let inputs: Vec = record + .input_limbs + .chunks(field_element_limbs) + .map(BigUint::from_bytes_le) + .collect(); + + // Read precomputed vars from record (computed during preflight) + let vars: Vec = record + .var_limbs + .chunks(field_element_limbs) + .map(BigUint::from_bytes_le) + .collect(); + + // Derive flags from opcode (cheap operation) + let flags = derive_flags_from_opcode( &self.expr, &self.local_opcode_idx, &self.opcode_flag_idx, - record.input_limbs, *record.opcode as usize, ); let range_checker = self.range_checker.as_ref(); self.expr - .generate_subrow((range_checker, inputs, flags), core_row); + .generate_subrow_with_precomputed_vars(range_checker, &inputs, &flags, &vars, core_row); } fn fill_dummy_trace_row(&self, row_slice: &mut [F]) { @@ -493,7 +537,7 @@ fn run_field_expression( opcode_flag_idx: &[usize], data: &[u8], local_opcode_idx: usize, -) -> (DynArray, Vec, Vec) { +) -> (DynArray, Vec, Vec, Vec) { let field_element_limbs = expr.canonical_num_limbs(); assert_eq!(data.len(), expr.builder.num_input * field_element_limbs); @@ -506,24 +550,12 @@ fn run_field_expression( inputs.push(input); } - let mut flags = vec![]; - if expr.needs_setup() { - flags = vec![false; expr.builder.num_flags]; - - // Find which opcode this is in our local_opcode_idx list - if let Some(opcode_position) = local_opcode_flags - .iter() - .position(|&idx| idx == local_opcode_idx) - { - // If this is NOT the last opcode (setup), set the corresponding flag - if opcode_position < opcode_flag_idx.len() { - let flag_idx = opcode_flag_idx[opcode_position]; - flags[flag_idx] = true; - } - // If opcode_position == step.opcode_flag_idx.len(), it's the setup operation - // and all flags should remain false (which they already are) - } - } + let flags = derive_flags_from_opcode( + expr, + local_opcode_flags, + opcode_flag_idx, + local_opcode_idx, + ); let vars = expr.execute(inputs.clone(), flags.clone()); 
assert_eq!(vars.len(), expr.builder.num_variables); @@ -542,7 +574,38 @@ fn run_field_expression( .collect::>() .into(); - (writes, inputs, flags) + (writes, inputs, flags, vars) +} + +/// Derives the flags vector from the local opcode index; returns an empty vector when `expr.needs_setup()` is false. +/// This is used during preflight and trace filling to determine which operation is being performed. +fn derive_flags_from_opcode( + expr: &FieldExpr, + local_opcode_flags: &[usize], + opcode_flag_idx: &[usize], + local_opcode_idx: usize, +) -> Vec { + if !expr.needs_setup() { + return vec![]; + } + + let mut flags = vec![false; expr.builder.num_flags]; + + // Find the position of `local_opcode_idx` within the `local_opcode_flags` list + if let Some(opcode_position) = local_opcode_flags + .iter() + .position(|&idx| idx == local_opcode_idx) + { + // If this is NOT the last opcode (setup), set the corresponding flag + if opcode_position < opcode_flag_idx.len() { + let flag_idx = opcode_flag_idx[opcode_position]; + flags[flag_idx] = true; + } + // If opcode_position == opcode_flag_idx.len(), it's the setup operation + // and all flags should remain false (which they already are) + } + + flags } #[inline(always)] diff --git a/crates/circuits/mod-builder/src/tests.rs b/crates/circuits/mod-builder/src/tests.rs index e9bffa1653..6c40138af6 100644 --- a/crates/circuits/mod-builder/src/tests.rs +++ b/crates/circuits/mod-builder/src/tests.rs @@ -7,7 +7,7 @@ use openvm_stark_backend::{ p3_air::BaseAir, p3_field::FieldAlgebra, p3_matrix::dense::RowMajorMatrix, }; use openvm_stark_sdk::{ - any_rap_arc_vec, config::baby_bear_blake3::BabyBearBlake3Engine, engine::StarkFriEngine, + any_rap_arc_vec, config::baby_bear_poseidon2::BabyBearPoseidon2Engine, engine::StarkFriEngine, p3_baby_bear::BabyBear, }; @@ -60,10 +60,11 @@ fn generate_recorded_trace( flags: Vec, width: usize, ) -> Vec { - let mut buffer = vec![0u8; 1024]; + let mut buffer = vec![0u8; 2048]; let mut record = FieldExpressionCoreRecordMut::new_from_execution_data( &mut buffer, inputs, +
expr.num_variables, expr.canonical_num_limbs(), ); let data: Vec = inputs @@ -91,7 +92,7 @@ fn verify_stark_with_traces( ) { let trace_matrix = RowMajorMatrix::new(trace, width); let range_trace = range_checker.generate_trace(); - BabyBearBlake3Engine::run_simple_test_no_pis_fast( + BabyBearPoseidon2Engine::run_simple_test_no_pis_fast( any_rap_arc_vec![expr, range_checker.air], vec![trace_matrix, range_trace], ) @@ -429,10 +430,11 @@ fn test_recorded_execution_records() { let flags: Vec = vec![]; // Test record creation and reconstruction - let mut buffer = vec![0u8; 1024]; + let mut buffer = vec![0u8; 2048]; let mut record = FieldExpressionCoreRecordMut::new_from_execution_data( &mut buffer, &inputs, + expr.num_variables, expr.canonical_num_limbs(), ); let data: Vec = inputs @@ -515,10 +517,11 @@ fn test_record_arena_allocation_patterns() { ]; // Test record creation with various input sizes - let mut buffer = vec![0u8; 1024]; + let mut buffer = vec![0u8; 2048]; let mut record = FieldExpressionCoreRecordMut::new_from_execution_data( &mut buffer, &inputs, + expr.num_variables, expr.canonical_num_limbs(), ); let data: Vec = inputs @@ -530,9 +533,10 @@ fn test_record_arena_allocation_patterns() { // Test with maximum inputs let max_inputs = vec![BigUint::one(); 40]; // MAX_INPUT_LIMBS / 4 - let mut max_buffer = vec![0u8; 2048]; + let mut max_buffer = vec![0u8; 4096]; + // Using 10 as an arbitrary number of variables for this test let max_record = - FieldExpressionCoreRecordMut::new_from_execution_data(&mut max_buffer, &max_inputs, 4); + FieldExpressionCoreRecordMut::new_from_execution_data(&mut max_buffer, &max_inputs, 10, 4); assert_eq!(*max_record.opcode, 0); // Test input reconstruction @@ -570,10 +574,11 @@ fn test_tracestep_tracefiller_roundtrip() { let vars_direct = expr.execute(inputs.clone(), vec![]); // Test record creation and reconstruction roundtrip - let mut buffer = vec![0u8; 1024]; + let mut buffer = vec![0u8; 2048]; let mut record = 
FieldExpressionCoreRecordMut::new_from_execution_data( &mut buffer, &inputs, + expr.num_variables, expr.canonical_num_limbs(), ); let data: Vec = inputs