Skip to content

Commit 9c6ccf3

Browse files
authored
fix: expression claim constraints (#2498)
fix: INT-6590 fix: INT-6396, INT-6397, INT-6398, INT-6399, INT-6402, INT-6439
1 parent 89b8865 commit 9c6ccf3

File tree

2 files changed

+162
-37
lines changed

2 files changed

+162
-37
lines changed

crates/recursion/src/batch_constraint/expression_claim/air.rs

Lines changed: 156 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
use std::borrow::Borrow;
22

3-
use openvm_circuit_primitives::utils::{assert_array_eq, not};
3+
use openvm_circuit_primitives::{
4+
utils::{assert_array_eq, not},
5+
SubAir,
6+
};
47
use openvm_stark_backend::{
58
interaction::InteractionBuilder, BaseAirWithPublicValues, PartitionedBaseAir,
69
};
@@ -18,24 +21,56 @@ use crate::{
1821
},
1922
bus::{ExpressionClaimNMaxBus, ExpressionClaimNMaxMessage, HyperdimBus, HyperdimBusMessage},
2023
primitives::bus::{PowerCheckerBus, PowerCheckerBusMessage},
24+
subairs::nested_for_loop::{NestedForLoopIoCols, NestedForLoopSubAir},
2125
utils::{base_to_ext, ext_field_add, ext_field_multiply, ext_field_multiply_scalar},
2226
};
2327

2428
/// For each proof, this AIR will receive 2t interaction claims and t constraint claims.
2529
/// (2 interaction claims and 1 constraint claim per trace).
2630
/// These values are folded (algebraic batching) with mu into a single value, which
2731
/// should match the final sumcheck claim.
32+
///
33+
/// Rows are structured as a nested loop: for each proof, group 0 (interactions) comes first,
34+
/// then group 1 (constraints). Within the interaction group, each trace occupies 2 rows
35+
/// (numerator then denominator). Within the constraint group, each trace occupies 1 row.
36+
/// `NestedForLoopSubAir<4>` enforces canonical nested enumeration over
37+
/// `(proof_idx, group_idx, trace_idx, idx_parity)`.
38+
///
39+
/// Example for t = 2 traces (one proof):
40+
///
41+
/// ```text
42+
/// row | is_first | group_idx | is_first_in_group | trace_idx | idx_parity | idx
43+
/// ----|----------|-----------|-------------------|-----------|------------|----
44+
/// 0 | 1 | 0 | 1 | 0 | 0 | 0 ← numerator trace 0
45+
/// 1 | 0 | 0 | 0 | 0 | 1 | 1 ← denominator trace 0
46+
/// 2 | 0 | 0 | 0 | 1 | 0 | 2 ← numerator trace 1
47+
/// 3 | 0 | 0 | 0 | 1 | 1 | 3 ← denominator trace 1
48+
/// 4 | 0 | 1 | 1 | 0 | 0 | 0 ← constraint trace 0
49+
/// 5 | 0 | 1 | 0 | 1 | 0 | 1 ← constraint trace 1
50+
/// ```
51+
///
52+
/// `is_interaction` is derived as `1 - group_idx` (not a separate column).
2853
#[derive(AlignedBorrow, Copy, Clone, Debug)]
2954
#[repr(C)]
3055
pub struct ExpressionClaimCols<T> {
56+
// --- Loop structure (enforced by NestedForLoopSubAir<4>) ---
3157
pub is_valid: T,
58+
/// First row of a proof. Marks proof boundaries.
3259
pub is_first: T,
3360
pub proof_idx: T,
61+
/// 0 = interaction group, 1 = constraint group. Monotone within a proof.
62+
pub group_idx: T,
63+
/// Marks the first row of each group (set at proof start and interaction→constraint boundary).
64+
pub is_first_in_group: T,
3465

35-
pub is_interaction: T,
36-
/// Index within the proof, 0 ~ 2t-1 are interaction claims, 0~t-1 are constraint claims.
66+
// --- Claim indexing (derived from loop counters) ---
67+
/// Claim index within its group. For interactions: `2 * trace_idx + idx_parity` (0..2t).
68+
/// For constraints: `trace_idx` (0..t).
3769
pub idx: T,
70+
/// 0 = numerator, 1 = denominator. Always 0 on constraint rows. Alternates on interaction
71+
/// rows.
3872
pub idx_parity: T,
73+
/// Sorted trace index within the group. Monotone non-decreasing; resets at group boundaries.
3974
pub trace_idx: T,
4075
/// The received evaluation claim. Note that for interactions, this is without norm_factor and
4176
/// eq_sharp_ns. These are interactions_evals (without norm_factor and eq_sharp_ns) and
@@ -90,66 +125,159 @@ where
90125
let local: &ExpressionClaimCols<AB::Var> = (*local).borrow();
91126
let next: &ExpressionClaimCols<AB::Var> = (*next).borrow();
92127

93-
builder.assert_bool(local.is_valid);
94-
builder.assert_bool(local.is_first);
95-
builder.assert_bool(local.is_interaction);
96-
builder.assert_bool(local.idx_parity);
97-
builder.assert_bool(local.n_sign);
128+
// === Loop structure via NestedForLoopSubAir<4> ===
129+
// Enforces canonical nested enumeration for:
130+
// [proof_idx, group_idx, trace_idx, idx_parity]
131+
// with first-flags:
132+
// [is_first, is_first_in_group, is_valid - idx_parity, is_valid].
133+
type LoopSubAir = NestedForLoopSubAir<4>;
134+
let local_is_trace_start = local.is_valid - local.idx_parity;
135+
let next_is_trace_start = next.is_valid - next.idx_parity;
136+
LoopSubAir {}.eval(
137+
builder,
138+
(
139+
NestedForLoopIoCols {
140+
is_enabled: local.is_valid.into(),
141+
counter: [
142+
local.proof_idx.into(),
143+
local.group_idx.into(),
144+
local.trace_idx.into(),
145+
local.idx_parity.into(),
146+
],
147+
is_first: [
148+
local.is_first.into(),
149+
local.is_first_in_group.into(),
150+
local_is_trace_start,
151+
local.is_valid.into(),
152+
],
153+
},
154+
NestedForLoopIoCols {
155+
is_enabled: next.is_valid.into(),
156+
counter: [
157+
next.proof_idx.into(),
158+
next.group_idx.into(),
159+
next.trace_idx.into(),
160+
next.idx_parity.into(),
161+
],
162+
is_first: [
163+
next.is_first.into(),
164+
next.is_first_in_group.into(),
165+
next_is_trace_start,
166+
next.is_valid.into(),
167+
],
168+
},
169+
),
170+
);
171+
172+
// Derived expressions:
173+
// is_interaction: true for interaction group (group_idx == 0)
174+
let is_interaction: AB::Expr = AB::Expr::ONE - local.group_idx.into();
175+
// is_same_proof: next row is valid and within the same proof
176+
let is_same_proof: AB::Expr = LoopSubAir::local_is_transition(next.is_valid, next.is_first);
177+
// is_last_in_proof: current row is the last row of its proof
178+
let is_last_in_proof: AB::Expr =
179+
LoopSubAir::local_is_last(local.is_valid, next.is_valid, next.is_first);
180+
181+
// Each proof starts with group 0 (interactions) and ends with 1 (constraints).
182+
// Start with group 0 is guaranteed by NestedForLoop
98183
builder
99-
.when(local.is_first)
100-
.assert_one(local.is_interaction);
101-
builder.when(local.is_first).assert_zero(local.idx_parity);
184+
.when(is_last_in_proof.clone())
185+
.assert_one(local.group_idx);
186+
187+
// === Claim indexing constraints ===
188+
builder.assert_bool(local.n_sign);
189+
// idx_parity alternates 0/1.
102190
builder
103-
.when(local.is_interaction)
191+
.when(local.is_valid)
192+
.when(is_interaction.clone())
104193
.assert_eq(local.idx_parity + next.idx_parity, AB::Expr::ONE);
194+
// only group 0 can have idx_parity set.
195+
builder.when(local.idx_parity).assert_zero(local.group_idx);
196+
197+
// idx binding to trace_idx / idx_parity
198+
// Interaction rows: idx = 2 * trace_idx + idx_parity
199+
builder.when(is_interaction.clone()).assert_eq(
200+
local.idx,
201+
local.trace_idx * AB::Expr::TWO + local.idx_parity,
202+
);
203+
// Constraint rows: idx = trace_idx
105204
builder
106-
.when(local.idx_parity)
107-
.assert_one(local.is_interaction);
205+
.when(local.group_idx)
206+
.assert_eq(local.idx, local.trace_idx);
207+
208+
// === mu constancy within a proof ===
209+
assert_array_eq(
210+
&mut builder.when(is_same_proof.clone()),
211+
next.mu,
212+
local.mu.map(Into::into),
213+
);
214+
215+
// === Hyperdim metadata constancy within numerator/denominator pairs ===
216+
// A numerator row (idx_parity=0, is_interaction=1) is always followed by its
217+
// denominator (idx_parity=1) due to the alternation constraint. Ensure they
218+
// share the same trace metadata so the hyperdim lookup on the numerator binds both.
219+
builder
220+
.when(local.is_valid)
221+
.when(is_interaction.clone())
222+
.when(not(local.idx_parity))
223+
.assert_eq(next.n_abs, local.n_abs);
224+
builder
225+
.when(local.is_valid)
226+
.when(is_interaction.clone())
227+
.when(not(local.idx_parity))
228+
.assert_eq(next.n_sign, local.n_sign);
108229

109230
// === cum sum folding ===
110-
// cur_sum = next_cur_sum * mu + value * multiplier
231+
// Fold recurrence within a proof: cur_sum = value * multiplier + next_cur_sum * mu
111232
assert_array_eq(
112-
&mut builder.when(local.is_valid * not(next.is_first)),
233+
&mut builder.when(is_same_proof),
113234
local.cur_sum,
114235
ext_field_add::<AB::Expr>(
115236
ext_field_multiply::<AB::Expr>(local.value, local.multiplier),
116237
ext_field_multiply::<AB::Expr>(next.cur_sum, local.mu),
117238
),
118239
);
240+
// Terminal base case: last row of each proof's fold
241+
assert_array_eq(
242+
&mut builder.when(is_last_in_proof),
243+
local.cur_sum,
244+
ext_field_multiply::<AB::Expr>(local.value, local.multiplier),
245+
);
246+
119247
// multiplier = 1 if not interaction
120248
assert_array_eq(
121-
&mut builder.when(not(local.is_interaction)).when(local.is_valid),
249+
&mut builder.when(local.group_idx).when(local.is_valid),
122250
local.multiplier,
123251
base_to_ext::<AB::Expr>(AB::Expr::ONE),
124252
);
125253

126254
// IF negative n and numerator
127255
assert_array_eq(
128-
&mut builder.when(local.n_sign * (local.is_interaction - local.idx_parity)),
256+
&mut builder.when(local.n_sign * (is_interaction.clone() - local.idx_parity)),
129257
ext_field_multiply_scalar::<AB::Expr>(local.multiplier, local.n_abs_pow),
130258
local.eq_sharp_ns,
131259
);
132-
// ELSE 1
260+
// ELSE 1: positive n, interaction row
133261
assert_array_eq(
134-
&mut builder.when(local.is_interaction * (AB::Expr::ONE - local.n_sign)),
262+
&mut builder.when(is_interaction.clone() * (AB::Expr::ONE - local.n_sign)),
135263
local.multiplier,
136264
local.eq_sharp_ns,
137265
);
138-
// ELSE 2
266+
// ELSE 2: denominator row
139267
assert_array_eq(
140268
&mut builder.when(local.idx_parity),
141269
local.multiplier,
142270
local.eq_sharp_ns,
143271
);
144272

145-
// === interactions ===
273+
// === bus interactions ===
146274
self.expr_claim_bus.receive(
147275
builder,
148276
local.proof_idx,
149277
ExpressionClaimMessage {
150-
is_interaction: local.is_interaction,
151-
idx: local.idx,
152-
value: local.value,
278+
is_interaction: is_interaction.clone(),
279+
idx: local.idx.into(),
280+
value: local.value.map(Into::into),
153281
},
154282
local.is_valid,
155283
);
@@ -193,7 +321,7 @@ where
193321
n_abs: local.n_abs.into(),
194322
n_sign_bit: local.n_sign.into(),
195323
},
196-
local.is_valid * (local.is_interaction - local.idx_parity),
324+
local.is_valid * (is_interaction.clone() - local.idx_parity),
197325
);
198326

199327
self.eq_n_outer_bus.lookup_key(
@@ -204,7 +332,7 @@ where
204332
n: local.n_abs * (AB::Expr::ONE - local.n_sign),
205333
value: local.eq_sharp_ns.map(Into::into),
206334
},
207-
local.is_valid * local.is_interaction,
335+
local.is_valid * is_interaction.clone(),
208336
);
209337

210338
self.pow_checker_bus.lookup_key(
@@ -213,7 +341,7 @@ where
213341
log: local.n_abs.into(),
214342
exp: local.n_abs_pow.into(),
215343
},
216-
local.is_valid * local.is_interaction,
344+
local.is_valid * is_interaction,
217345
);
218346
}
219347
}

crates/recursion/src/batch_constraint/expression_claim/trace.rs

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,11 @@ impl RowMajorChip<F> for ExpressionClaimTraceGenerator {
9292
cols.is_first = F::from_bool(i == 0);
9393
cols.is_valid = F::ONE;
9494
cols.proof_idx = F::from_usize(pidx);
95-
cols.is_interaction = F::from_bool(is_interaction);
95+
// group_idx: 0 for interactions, 1 for constraints
96+
cols.group_idx = F::from_bool(!is_interaction);
97+
// is_first_in_group: true at start of proof (i==0) and at start of
98+
// constraint group (i == 2*num_present)
99+
cols.is_first_in_group = F::from_bool(i == 0 || i == 2 * num_present);
96100
cols.num_multilinear_sumcheck_rounds = F::from_usize(num_rounds);
97101
cols.idx = F::from_usize(if i < 2 * num_present {
98102
i
@@ -133,7 +137,7 @@ impl RowMajorChip<F> for ExpressionClaimTraceGenerator {
133137
.for_each(|chunk| {
134138
let cols: &mut ExpressionClaimCols<_> = chunk.borrow_mut();
135139
// if it's interaction, we need to multiply by eq_sharp_ns and norm_factor
136-
let multiplier = if cols.is_interaction == F::ONE {
140+
let multiplier = if cols.group_idx == F::ZERO {
137141
let mut mult =
138142
EF::from_basis_coefficients_slice(&cols.eq_sharp_ns).unwrap();
139143
if cols.n_sign == F::ONE && cols.idx.as_canonical_u32() % 2 == 0 {
@@ -153,13 +157,6 @@ impl RowMajorChip<F> for ExpressionClaimTraceGenerator {
153157

154158
cur_height += claims.len();
155159
}
156-
trace[cur_height * width..]
157-
.par_chunks_mut(width)
158-
.enumerate()
159-
.for_each(|(i, chunk)| {
160-
let cols: &mut ExpressionClaimCols<F> = chunk.borrow_mut();
161-
cols.proof_idx = F::from_usize(preflights.len() + i);
162-
});
163160
Some(RowMajorMatrix::new(trace, width))
164161
}
165162
}

0 commit comments

Comments
 (0)