Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 112 additions & 0 deletions crates/circuits/mod-builder/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,118 @@ impl FieldExpr {
.collect()
}

/// Generates a trace subrow using precomputed `vars` to avoid redundant computation.
/// This is used during trace filling when vars have already been computed during preflight.
pub fn generate_subrow_with_precomputed_vars<F: PrimeField64>(
&self,
range_checker: &VariableRangeCheckerChip,
inputs: &[BigUint],
flags: &[bool],
vars: &[BigUint],
sub_row: &mut [F],
) {
assert!(self.builder.is_finalized());
assert_eq!(inputs.len(), self.num_input);
assert_eq!(vars.len(), self.num_variables);
assert_eq!(flags.len(), self.builder.num_flags);

let limb_bits = self.limb_bits;

// BigInt type is required for computing the quotient.
let input_bigint = inputs
.iter()
.map(|x| BigInt::from_biguint(Sign::Plus, x.clone()))
.collect::<Vec<BigInt>>();
let vars_bigint: Vec<BigInt> = vars
.iter()
.map(|x| BigInt::from_biguint(Sign::Plus, x.clone()))
.collect();

// OverflowInt type is required for computing the carries.
let input_overflow = inputs
.iter()
.map(|x| OverflowInt::<isize>::from_biguint(x, self.limb_bits, Some(self.num_limbs)))
.collect::<Vec<_>>();
let vars_overflow: Vec<OverflowInt<isize>> = vars
.iter()
.map(|x| OverflowInt::<isize>::from_biguint(x, self.limb_bits, Some(self.num_limbs)))
.collect();

// Note: in cases where the prime fits in less limbs than `num_limbs`, we use the smaller
// number of limbs.
let prime_overflow = OverflowInt::<isize>::from_biguint(&self.prime, self.limb_bits, None);

let constants: Vec<_> = self
.constants
.iter()
.map(|(_, limbs)| {
let limbs_isize: Vec<_> = limbs.iter().map(|i| *i as isize).collect();
OverflowInt::from_canonical_unsigned_limbs(limbs_isize, self.limb_bits)
})
.collect();

let mut all_q = vec![];
let mut all_carry = vec![];
for i in 0..self.constraints.len() {
// expr = q * p
let expr_bigint =
self.constraints[i].evaluate_bigint(&input_bigint, &vars_bigint, flags);
let q = &expr_bigint / &self.prime_bigint;
// If this is not true then the evaluated constraint is not divisible by p.
debug_assert_eq!(expr_bigint, &q * &self.prime_bigint);
let q_limbs = big_int_to_num_limbs(&q, limb_bits, self.q_limbs[i]);
assert_eq!(q_limbs.len(), self.q_limbs[i]); // If this fails, the q_limbs estimate is wrong.
for &q in q_limbs.iter() {
range_checker.add_count((q + (1 << limb_bits)) as u32, limb_bits + 1);
}
let q_overflow = OverflowInt::from_canonical_signed_limbs(q_limbs.clone(), limb_bits);
// compute carries of (expr - q * p)
let expr = self.constraints[i].evaluate_overflow_isize(
&input_overflow,
&vars_overflow,
&constants,
flags,
);
let expr = expr - q_overflow * prime_overflow.clone();
let carries = expr.calculate_carries(limb_bits);
assert_eq!(carries.len(), self.carry_limbs[i]); // If this fails, the carry limbs estimate is wrong.
let max_overflow_bits = expr.max_overflow_bits();
let (carry_min_abs, carry_bits) =
get_carry_max_abs_and_bits(max_overflow_bits, limb_bits);
for &carry in carries.iter() {
range_checker.add_count((carry + carry_min_abs as isize) as u32, carry_bits);
}
all_q.push(vec_isize_to_f::<F>(q_limbs));
all_carry.push(vec_isize_to_f::<F>(carries));
}
for var in vars_overflow.iter() {
for limb in var.limbs().iter() {
range_checker.add_count(*limb as u32, limb_bits);
}
}

let input_limbs = input_overflow
.iter()
.map(|x| vec_isize_to_f::<F>(x.limbs().to_vec()))
.collect::<Vec<_>>();
let vars_limbs = vars_overflow
.iter()
.map(|x| vec_isize_to_f::<F>(x.limbs().to_vec()))
.collect::<Vec<_>>();

sub_row.copy_from_slice(
&[
vec![F::ONE],
input_limbs.concat(),
vars_limbs.concat(),
all_q.concat(),
all_carry.concat(),
flags.iter().map(|x| F::from_bool(*x)).collect::<Vec<_>>(),
]
.concat(),
);
}

pub fn load_vars<T: Clone>(&self, arr: &[T]) -> FieldExprCols<T> {
assert!(self.builder.is_finalized());
let is_valid = arr[0].clone();
Expand Down
125 changes: 94 additions & 31 deletions crates/circuits/mod-builder/src/core_chip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,13 +171,15 @@ where

pub struct FieldExpressionMetadata<F, A> {
pub total_input_limbs: usize, // num_inputs * limbs_per_input
pub total_var_limbs: usize, // num_variables * limbs_per_variable
_phantom: PhantomData<(F, A)>,
}

impl<F, A> Clone for FieldExpressionMetadata<F, A> {
fn clone(&self) -> Self {
Self {
total_input_limbs: self.total_input_limbs,
total_var_limbs: self.total_var_limbs,
_phantom: PhantomData,
}
}
Expand All @@ -187,15 +189,17 @@ impl<F, A> Default for FieldExpressionMetadata<F, A> {
fn default() -> Self {
Self {
total_input_limbs: 0,
total_var_limbs: 0,
_phantom: PhantomData,
}
}
}

impl<F, A> FieldExpressionMetadata<F, A> {
pub fn new(total_input_limbs: usize) -> Self {
pub fn new(total_input_limbs: usize, total_var_limbs: usize) -> Self {
Self {
total_input_limbs,
total_var_limbs,
_phantom: PhantomData,
}
}
Expand All @@ -216,6 +220,7 @@ pub type FieldExpressionRecordLayout<F, A> = AdapterCoreLayout<FieldExpressionMe
pub struct FieldExpressionCoreRecordMut<'a> {
pub opcode: &'a mut u8,
pub input_limbs: &'a mut [u8],
pub var_limbs: &'a mut [u8],
}

impl<'a, F, A> CustomBorrow<'a, FieldExpressionCoreRecordMut<'a>, FieldExpressionRecordLayout<F, A>>
Expand All @@ -226,14 +231,19 @@ impl<'a, F, A> CustomBorrow<'a, FieldExpressionCoreRecordMut<'a>, FieldExpressio
layout: FieldExpressionRecordLayout<F, A>,
) -> FieldExpressionCoreRecordMut<'a> {
// SAFETY: The buffer length is the width of the trace which should be at least 1
let (opcode_buf, input_limbs_buff) = unsafe { self.split_at_mut_unchecked(1) };
let (opcode_buf, rest) = unsafe { self.split_at_mut_unchecked(1) };

// SAFETY: opcode_buf has exactly 1 element from split_at_mut_unchecked(1)
let opcode_buf = unsafe { opcode_buf.get_unchecked_mut(0) };

// SAFETY: rest has at least total_input_limbs + total_var_limbs bytes
let (input_limbs, var_limbs_buf) =
unsafe { rest.split_at_mut_unchecked(layout.metadata.total_input_limbs) };

FieldExpressionCoreRecordMut {
opcode: opcode_buf,
input_limbs: &mut input_limbs_buff[..layout.metadata.total_input_limbs],
input_limbs,
var_limbs: &mut var_limbs_buf[..layout.metadata.total_var_limbs],
}
}

Expand All @@ -244,7 +254,9 @@ impl<'a, F, A> CustomBorrow<'a, FieldExpressionCoreRecordMut<'a>, FieldExpressio

impl<F, A> SizedRecord<FieldExpressionRecordLayout<F, A>> for FieldExpressionCoreRecordMut<'_> {
fn size(layout: &FieldExpressionRecordLayout<F, A>) -> usize {
layout.metadata.total_input_limbs + 1
1 // opcode
+ layout.metadata.total_input_limbs
+ layout.metadata.total_var_limbs
}

fn alignment(_layout: &FieldExpressionRecordLayout<F, A>) -> usize {
Expand All @@ -257,9 +269,13 @@ impl<'a> FieldExpressionCoreRecordMut<'a> {
pub fn new_from_execution_data(
buffer: &'a mut [u8],
inputs: &[BigUint],
num_variables: usize,
limbs_per_input: usize,
) -> Self {
let record_info = FieldExpressionMetadata::<(), ()>::new(inputs.len() * limbs_per_input);
let record_info = FieldExpressionMetadata::<(), ()>::new(
inputs.len() * limbs_per_input,
num_variables * limbs_per_input,
);

let record: Self = buffer.custom_borrow(FieldExpressionRecordLayout {
metadata: record_info,
Expand Down Expand Up @@ -319,9 +335,11 @@ impl<A> FieldExpressionExecutor<A> {
}

pub fn get_record_layout<F>(&self) -> FieldExpressionRecordLayout<F, A> {
let num_limbs = self.expr.canonical_num_limbs();
FieldExpressionRecordLayout {
metadata: FieldExpressionMetadata::new(
self.expr.builder.num_input * self.expr.canonical_num_limbs(),
self.expr.builder.num_input * num_limbs,
self.expr.builder.num_variables * num_limbs,
),
}
}
Expand Down Expand Up @@ -372,9 +390,11 @@ impl<A> FieldExpressionFiller<A> {
}

pub fn get_record_layout<F>(&self) -> FieldExpressionRecordLayout<F, A> {
let num_limbs = self.expr.canonical_num_limbs();
FieldExpressionRecordLayout {
metadata: FieldExpressionMetadata::new(
self.num_inputs() * self.expr.canonical_num_limbs(),
self.num_inputs() * num_limbs,
self.expr.builder.num_variables * num_limbs,
),
}
}
Expand Down Expand Up @@ -410,14 +430,23 @@ where
&data.0,
);

let (writes, _, _) = run_field_expression(
let (writes, _, _, vars) = run_field_expression(
&self.expr,
&self.local_opcode_idx,
&self.opcode_flag_idx,
core_record.input_limbs,
*core_record.opcode as usize,
);

// Store computed vars in the record to avoid recomputation during trace filling
let field_element_limbs = self.expr.canonical_num_limbs();
for (i, var) in vars.iter().enumerate() {
let start = i * field_element_limbs;
let end = start + field_element_limbs;
let limbs = biguint_to_limbs_vec(var, field_element_limbs);
core_record.var_limbs[start..end].copy_from_slice(&limbs);
}

self.adapter.write(
state.memory,
instruction,
Expand Down Expand Up @@ -457,17 +486,32 @@ where
let record: FieldExpressionCoreRecordMut =
unsafe { get_record_from_slice(&mut core_row, self.get_record_layout::<F>()) };

let (_, inputs, flags) = run_field_expression(
// Reconstruct inputs from record
let field_element_limbs = self.expr.canonical_num_limbs();
let inputs: Vec<BigUint> = record
.input_limbs
.chunks(field_element_limbs)
.map(BigUint::from_bytes_le)
.collect();

// Read precomputed vars from record (computed during preflight)
let vars: Vec<BigUint> = record
.var_limbs
.chunks(field_element_limbs)
.map(BigUint::from_bytes_le)
.collect();

// Derive flags from opcode (cheap operation)
let flags = derive_flags_from_opcode(
&self.expr,
&self.local_opcode_idx,
&self.opcode_flag_idx,
record.input_limbs,
*record.opcode as usize,
);

let range_checker = self.range_checker.as_ref();
self.expr
.generate_subrow((range_checker, inputs, flags), core_row);
.generate_subrow_with_precomputed_vars(range_checker, &inputs, &flags, &vars, core_row);
}

fn fill_dummy_trace_row(&self, row_slice: &mut [F]) {
Expand All @@ -493,7 +537,7 @@ fn run_field_expression(
opcode_flag_idx: &[usize],
data: &[u8],
local_opcode_idx: usize,
) -> (DynArray<u8>, Vec<BigUint>, Vec<bool>) {
) -> (DynArray<u8>, Vec<BigUint>, Vec<bool>, Vec<BigUint>) {
let field_element_limbs = expr.canonical_num_limbs();
assert_eq!(data.len(), expr.builder.num_input * field_element_limbs);

Expand All @@ -506,24 +550,12 @@ fn run_field_expression(
inputs.push(input);
}

let mut flags = vec![];
if expr.needs_setup() {
flags = vec![false; expr.builder.num_flags];

// Find which opcode this is in our local_opcode_idx list
if let Some(opcode_position) = local_opcode_flags
.iter()
.position(|&idx| idx == local_opcode_idx)
{
// If this is NOT the last opcode (setup), set the corresponding flag
if opcode_position < opcode_flag_idx.len() {
let flag_idx = opcode_flag_idx[opcode_position];
flags[flag_idx] = true;
}
// If opcode_position == step.opcode_flag_idx.len(), it's the setup operation
// and all flags should remain false (which they already are)
}
}
let flags = derive_flags_from_opcode(
expr,
local_opcode_flags,
opcode_flag_idx,
local_opcode_idx,
);

let vars = expr.execute(inputs.clone(), flags.clone());
assert_eq!(vars.len(), expr.builder.num_variables);
Expand All @@ -542,7 +574,38 @@ fn run_field_expression(
.collect::<Vec<_>>()
.into();

(writes, inputs, flags)
(writes, inputs, flags, vars)
}

/// Derives the flags vector from the local opcode index.
/// This is used during preflight and trace filling to determine which operation is being performed.
fn derive_flags_from_opcode(
expr: &FieldExpr,
local_opcode_flags: &[usize],
opcode_flag_idx: &[usize],
local_opcode_idx: usize,
) -> Vec<bool> {
if !expr.needs_setup() {
return vec![];
}

let mut flags = vec![false; expr.builder.num_flags];

// Find which opcode this is in our local_opcode_idx list
if let Some(opcode_position) = local_opcode_flags
.iter()
.position(|&idx| idx == local_opcode_idx)
{
// If this is NOT the last opcode (setup), set the corresponding flag
if opcode_position < opcode_flag_idx.len() {
let flag_idx = opcode_flag_idx[opcode_position];
flags[flag_idx] = true;
}
// If opcode_position == step.opcode_flag_idx.len(), it's the setup operation
// and all flags should remain false (which they already are)
}

flags
}

#[inline(always)]
Expand Down
Loading
Loading