11use std::borrow::Borrow;
22
3- use openvm_circuit_primitives::utils::{assert_array_eq, not};
3+ use openvm_circuit_primitives::{
4+ utils::{assert_array_eq, not},
5+ SubAir,
6+ };
47use openvm_stark_backend::{
58 interaction::InteractionBuilder, BaseAirWithPublicValues, PartitionedBaseAir,
69};
@@ -18,24 +21,56 @@ use crate::{
1821 },
1922 bus::{ExpressionClaimNMaxBus, ExpressionClaimNMaxMessage, HyperdimBus, HyperdimBusMessage},
2023 primitives::bus::{PowerCheckerBus, PowerCheckerBusMessage},
24+ subairs::nested_for_loop::{NestedForLoopIoCols, NestedForLoopSubAir},
2125 utils::{base_to_ext, ext_field_add, ext_field_multiply, ext_field_multiply_scalar},
2226};
2327
2428/// For each proof, this AIR will receive 2t interaction claims and t constraint claims.
2529/// (2 interaction claims and 1 constraint claim per trace).
2630/// These values are folded (algebraic batching) with mu into a single value, which
2731/// should match the final sumcheck claim.
32+ ///
33+ /// Rows are structured as a nested loop: for each proof, group 0 (interactions) comes first,
34+ /// then group 1 (constraints). Within the interaction group, each trace occupies 2 rows
35+ /// (numerator then denominator). Within the constraint group, each trace occupies 1 row.
36+ /// `NestedForLoopSubAir<4>` enforces canonical nested enumeration over
37+ /// `(proof_idx, group_idx, trace_idx, idx_parity)`.
38+ ///
39+ /// Example for t = 2 traces (one proof):
40+ ///
41+ /// ```text
42+ /// row | is_first | group_idx | is_first_in_group | trace_idx | idx_parity | idx
43+ /// ----|----------|-----------|-------------------|-----------|------------|----
44+ /// 0 | 1 | 0 | 1 | 0 | 0 | 0 ← numerator trace 0
45+ /// 1 | 0 | 0 | 0 | 0 | 1 | 1 ← denominator trace 0
46+ /// 2 | 0 | 0 | 0 | 1 | 0 | 2 ← numerator trace 1
47+ /// 3 | 0 | 0 | 0 | 1 | 1 | 3 ← denominator trace 1
48+ /// 4 | 0 | 1 | 1 | 0 | 0 | 0 ← constraint trace 0
49+ /// 5 | 0 | 1 | 0 | 1 | 0 | 1 ← constraint trace 1
50+ /// ```
51+ ///
52+ /// `is_interaction` is derived as `1 - group_idx` (not a separate column).
2853#[derive(AlignedBorrow, Copy, Clone, Debug)]
2954#[repr(C)]
3055pub struct ExpressionClaimCols<T> {
56+ // --- Loop structure (enforced by NestedForLoopSubAir<4>) ---
3157 pub is_valid: T,
58+ /// First row of a proof. Marks proof boundaries.
3259 pub is_first: T,
3360 pub proof_idx: T,
61+ /// 0 = interaction group, 1 = constraint group. Monotone within a proof.
62+ pub group_idx: T,
63+ /// Marks the first row of each group (set at proof start and interaction→constraint boundary).
64+ pub is_first_in_group: T,
3465
35- pub is_interaction: T,
36- /// Index within the proof, 0 ~ 2t-1 are interaction claims, 0~t-1 are constraint claims.
66+ // --- Claim indexing (derived from loop counters) ---
67+ /// Claim index within its group. For interactions: `2 * trace_idx + idx_parity` (0..2t).
68+ /// For constraints: `trace_idx` (0..t).
3769 pub idx: T,
70+ /// 0 = numerator, 1 = denominator. Always 0 on constraint rows. Alternates on interaction
71+ /// rows.
3872 pub idx_parity: T,
73+ /// Sorted trace index within the group. Monotone non-decreasing; resets at group boundaries.
3974 pub trace_idx: T,
4075 /// The received evaluation claim. Note that for interactions, this is without norm_factor and
4176 /// eq_sharp_ns. These are interactions_evals (without norm_factor and eq_sharp_ns) and
@@ -90,66 +125,159 @@ where
90125 let local: &ExpressionClaimCols<AB::Var> = (*local).borrow();
91126 let next: &ExpressionClaimCols<AB::Var> = (*next).borrow();
92127
93- builder.assert_bool(local.is_valid);
94- builder.assert_bool(local.is_first);
95- builder.assert_bool(local.is_interaction);
96- builder.assert_bool(local.idx_parity);
97- builder.assert_bool(local.n_sign);
128+ // === Loop structure via NestedForLoopSubAir<4> ===
129+ // Enforces canonical nested enumeration for:
130+ // [proof_idx, group_idx, trace_idx, idx_parity]
131+ // with first-flags:
132+ // [is_first, is_first_in_group, is_valid - idx_parity, is_valid].
133+ type LoopSubAir = NestedForLoopSubAir<4>;
134+ let local_is_trace_start = local.is_valid - local.idx_parity;
135+ let next_is_trace_start = next.is_valid - next.idx_parity;
136+ LoopSubAir {}.eval(
137+ builder,
138+ (
139+ NestedForLoopIoCols {
140+ is_enabled: local.is_valid.into(),
141+ counter: [
142+ local.proof_idx.into(),
143+ local.group_idx.into(),
144+ local.trace_idx.into(),
145+ local.idx_parity.into(),
146+ ],
147+ is_first: [
148+ local.is_first.into(),
149+ local.is_first_in_group.into(),
150+ local_is_trace_start,
151+ local.is_valid.into(),
152+ ],
153+ },
154+ NestedForLoopIoCols {
155+ is_enabled: next.is_valid.into(),
156+ counter: [
157+ next.proof_idx.into(),
158+ next.group_idx.into(),
159+ next.trace_idx.into(),
160+ next.idx_parity.into(),
161+ ],
162+ is_first: [
163+ next.is_first.into(),
164+ next.is_first_in_group.into(),
165+ next_is_trace_start,
166+ next.is_valid.into(),
167+ ],
168+ },
169+ ),
170+ );
171+
172+ // Derived expressions:
173+ // is_interaction: true for interaction group (group_idx == 0)
174+ let is_interaction: AB::Expr = AB::Expr::ONE - local.group_idx.into();
175+ // is_same_proof: next row is valid and within the same proof
176+ let is_same_proof: AB::Expr = LoopSubAir::local_is_transition(next.is_valid, next.is_first);
177+ // is_last_in_proof: current row is the last row of its proof
178+ let is_last_in_proof: AB::Expr =
179+ LoopSubAir::local_is_last(local.is_valid, next.is_valid, next.is_first);
180+
181+ // Each proof starts with group 0 (interactions) and ends with 1 (constraints).
182+ // Start with group 0 is guaranteed by NestedForLoop
98183 builder
99- .when(local.is_first)
100- .assert_one(local.is_interaction);
101- builder.when(local.is_first).assert_zero(local.idx_parity);
184+ .when(is_last_in_proof.clone())
185+ .assert_one(local.group_idx);
186+
187+ // === Claim indexing constraints ===
188+ builder.assert_bool(local.n_sign);
189+ // idx_parity alternates 0/1.
102190 builder
103- .when(local.is_interaction)
191+ .when(local.is_valid)
192+ .when(is_interaction.clone())
104193 .assert_eq(local.idx_parity + next.idx_parity, AB::Expr::ONE);
194+ // only group 0 can have idx_parity set.
195+ builder.when(local.idx_parity).assert_zero(local.group_idx);
196+
197+ // idx binding to trace_idx / idx_parity
198+ // Interaction rows: idx = 2 * trace_idx + idx_parity
199+ builder.when(is_interaction.clone()).assert_eq(
200+ local.idx,
201+ local.trace_idx * AB::Expr::TWO + local.idx_parity,
202+ );
203+ // Constraint rows: idx = trace_idx
105204 builder
106- .when(local.idx_parity)
107- .assert_one(local.is_interaction);
205+ .when(local.group_idx)
206+ .assert_eq(local.idx, local.trace_idx);
207+
208+ // === mu constancy within a proof ===
209+ assert_array_eq(
210+ &mut builder.when(is_same_proof.clone()),
211+ next.mu,
212+ local.mu.map(Into::into),
213+ );
214+
215+ // === Hyperdim metadata constancy within numerator/denominator pairs ===
216+ // A numerator row (idx_parity=0, is_interaction=1) is always followed by its
217+ // denominator (idx_parity=1) due to the alternation constraint. Ensure they
218+ // share the same trace metadata so the hyperdim lookup on the numerator binds both.
219+ builder
220+ .when(local.is_valid)
221+ .when(is_interaction.clone())
222+ .when(not(local.idx_parity))
223+ .assert_eq(next.n_abs, local.n_abs);
224+ builder
225+ .when(local.is_valid)
226+ .when(is_interaction.clone())
227+ .when(not(local.idx_parity))
228+ .assert_eq(next.n_sign, local.n_sign);
108229
109230 // === cum sum folding ===
110- // cur_sum = next_cur_sum * mu + value * multiplier
231+ // Fold recurrence within a proof: cur_sum = value * multiplier + next_cur_sum * mu
111232 assert_array_eq(
112- &mut builder.when(local.is_valid * not(next.is_first) ),
233+ &mut builder.when(is_same_proof ),
113234 local.cur_sum,
114235 ext_field_add::<AB::Expr>(
115236 ext_field_multiply::<AB::Expr>(local.value, local.multiplier),
116237 ext_field_multiply::<AB::Expr>(next.cur_sum, local.mu),
117238 ),
118239 );
240+ // Terminal base case: last row of each proof's fold
241+ assert_array_eq(
242+ &mut builder.when(is_last_in_proof),
243+ local.cur_sum,
244+ ext_field_multiply::<AB::Expr>(local.value, local.multiplier),
245+ );
246+
119247 // multiplier = 1 if not interaction
120248 assert_array_eq(
121- &mut builder.when(not( local.is_interaction) ).when(local.is_valid),
249+ &mut builder.when(local.group_idx ).when(local.is_valid),
122250 local.multiplier,
123251 base_to_ext::<AB::Expr>(AB::Expr::ONE),
124252 );
125253
126254 // IF negative n and numerator
127255 assert_array_eq(
128- &mut builder.when(local.n_sign * (local. is_interaction - local.idx_parity)),
256+ &mut builder.when(local.n_sign * (is_interaction.clone() - local.idx_parity)),
129257 ext_field_multiply_scalar::<AB::Expr>(local.multiplier, local.n_abs_pow),
130258 local.eq_sharp_ns,
131259 );
132- // ELSE 1
260+ // ELSE 1: positive n, interaction row
133261 assert_array_eq(
134- &mut builder.when(local. is_interaction * (AB::Expr::ONE - local.n_sign)),
262+ &mut builder.when(is_interaction.clone() * (AB::Expr::ONE - local.n_sign)),
135263 local.multiplier,
136264 local.eq_sharp_ns,
137265 );
138- // ELSE 2
266+ // ELSE 2: denominator row
139267 assert_array_eq(
140268 &mut builder.when(local.idx_parity),
141269 local.multiplier,
142270 local.eq_sharp_ns,
143271 );
144272
145- // === interactions ===
273+ // === bus interactions ===
146274 self.expr_claim_bus.receive(
147275 builder,
148276 local.proof_idx,
149277 ExpressionClaimMessage {
150- is_interaction: local. is_interaction,
151- idx: local.idx,
152- value: local.value,
278+ is_interaction: is_interaction.clone() ,
279+ idx: local.idx.into() ,
280+ value: local.value.map(Into::into) ,
153281 },
154282 local.is_valid,
155283 );
@@ -193,7 +321,7 @@ where
193321 n_abs: local.n_abs.into(),
194322 n_sign_bit: local.n_sign.into(),
195323 },
196- local.is_valid * (local. is_interaction - local.idx_parity),
324+ local.is_valid * (is_interaction.clone() - local.idx_parity),
197325 );
198326
199327 self.eq_n_outer_bus.lookup_key(
@@ -204,7 +332,7 @@ where
204332 n: local.n_abs * (AB::Expr::ONE - local.n_sign),
205333 value: local.eq_sharp_ns.map(Into::into),
206334 },
207- local.is_valid * local. is_interaction,
335+ local.is_valid * is_interaction.clone() ,
208336 );
209337
210338 self.pow_checker_bus.lookup_key(
@@ -213,7 +341,7 @@ where
213341 log: local.n_abs.into(),
214342 exp: local.n_abs_pow.into(),
215343 },
216- local.is_valid * local. is_interaction,
344+ local.is_valid * is_interaction,
217345 );
218346 }
219347}
0 commit comments