Skip to content

Commit 6a83ba3

Browse files
fix: add back num_air_id_lookups
1 parent 09a7977 commit 6a83ba3

File tree

10 files changed

+46
-25
lines changed

10 files changed

+46
-25
lines changed

crates/continuations-v2/src/tests/mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,8 @@ fn test_internal_recursive_vk_stabilization(def_hook_commit_set: bool) -> Result
248248
let (_, app_vk) = engine.keygen(&config.create_airs()?.into_airs().collect_vec());
249249
let def_hook_commit = def_hook_commit_set.then_some([F::ZERO; DIGEST_SIZE]);
250250

251-
let leaf_prover = InnerProver::<DEFAULT_MAX_NUM_PROOFS>::new::<Engine>(
251+
const MAX_LEAF_NUM_PROOFS: usize = 3;
252+
let leaf_prover = InnerProver::<MAX_LEAF_NUM_PROOFS>::new::<Engine>(
252253
Arc::new(app_vk),
253254
leaf_system_params(),
254255
false,

crates/recursion/cuda/include/types.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ typedef struct {
2020
size_t cached_idx;
2121
size_t starting_cidx;
2222
size_t total_interactions;
23+
size_t num_air_id_lookups;
2324
} TraceMetadata;
2425

2526
typedef struct {

crates/recursion/cuda/src/proof_shape/air.cu

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ template <typename T, size_t MAX_CACHED> struct ProofShapeCols {
6161
T n_max;
6262
T is_n_max_greater;
6363

64+
T num_air_id_lookups;
65+
6466
// T idx_flags[IDX_FLAGS];
6567
T cached_commits[MAX_CACHED][DIGEST_SIZE];
6668
};
@@ -103,6 +105,12 @@ __device__ __forceinline__ void fill_present_row(
103105
size_t lifted_height = max(height, (size_t)(1 << l_skip));
104106
COL_WRITE_VALUE(row, typename Cols<MAX_CACHED>::template Type, is_present, Fp::one());
105107
COL_WRITE_VALUE(row, typename Cols<MAX_CACHED>::template Type, height, height);
108+
COL_WRITE_VALUE(
109+
row,
110+
typename Cols<MAX_CACHED>::template Type,
111+
num_air_id_lookups,
112+
trace_data.num_air_id_lookups
113+
);
106114

107115
Decomp lifted_height_decomp, num_interactions_decomp, total_interactions_decomp;
108116
decompose(lifted_height_decomp, lifted_height);
@@ -191,6 +199,7 @@ __device__ __forceinline__ void fill_non_present_row(
191199
COL_WRITE_VALUE(row, typename Cols<MAX_CACHED>::template Type, starting_cidx, final_cidx);
192200
COL_WRITE_VALUE(row, typename Cols<MAX_CACHED>::template Type, is_present, Fp::zero());
193201
COL_WRITE_VALUE(row, typename Cols<MAX_CACHED>::template Type, height, Fp::zero());
202+
COL_WRITE_VALUE(row, typename Cols<MAX_CACHED>::template Type, num_air_id_lookups, Fp::zero());
194203
row.fill_zero(
195204
COL_INDEX(typename Cols<MAX_CACHED>::template Type, lifted_height_limbs), NUM_LIMBS
196205
);
@@ -232,6 +241,7 @@ __device__ __forceinline__ void fill_summary_row(
232241
COL_WRITE_VALUE(row, typename Cols<MAX_CACHED>::template Type, is_last, Fp::one());
233242
COL_WRITE_VALUE(row, typename Cols<MAX_CACHED>::template Type, sorted_idx, Fp::zero());
234243
COL_WRITE_VALUE(row, typename Cols<MAX_CACHED>::template Type, is_present, Fp::zero());
244+
COL_WRITE_VALUE(row, typename Cols<MAX_CACHED>::template Type, num_air_id_lookups, Fp::zero());
235245
row.fill_zero(cached_commits_idx, MAX_CACHED * DIGEST_SIZE);
236246

237247
Decomp interaction_decomp, max_interaction_decomp;

crates/recursion/src/cuda/preflight.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@ use crate::{
1010
cuda::{
1111
to_device_or_nullptr,
1212
types::{TraceHeight, TraceMetadata},
13-
},
14-
system::Preflight,
13+
}, proof_shape::proof_shape::compute_air_shape_lookup_counts, system::Preflight
1514
};
1615

1716
/*
@@ -106,6 +105,8 @@ impl PreflightGpu {
106105
let mut total_interactions = 0;
107106
let l_skip = vk.inner.params.l_skip;
108107

108+
let bc_air_shape_lookups = compute_air_shape_lookup_counts(vk);
109+
109110
let (sorted_trace_heights, sorted_trace_metadata): (Vec<_>, Vec<_>) = preflight
110111
.proof_shape
111112
.sorted_trace_vdata
@@ -119,6 +120,7 @@ impl PreflightGpu {
119120
cached_idx: sorted_cached_commits.len(),
120121
starting_cidx: cidx,
121122
total_interactions,
123+
num_air_id_lookups: bc_air_shape_lookups[*air_idx],
122124
};
123125
cidx += vdata.cached_commitments.len()
124126
+ vk.inner.per_air[*air_idx].preprocessed_data.is_some() as usize;

crates/recursion/src/cuda/types.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ pub struct TraceMetadata {
1313
pub cached_idx: usize,
1414
pub starting_cidx: usize,
1515
pub total_interactions: usize,
16+
pub num_air_id_lookups: usize,
1617
}
1718

1819
#[repr(C)]

crates/recursion/src/proof_shape/mod.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ pub struct AirMetadata {
4646
need_rot: bool,
4747
num_public_values: usize,
4848
num_interactions: usize,
49-
num_dag_nodes: usize,
5049
main_width: usize,
5150
cached_widths: Vec<usize>,
5251
preprocessed_width: Option<usize>,
@@ -115,7 +114,6 @@ impl ProofShapeModule {
115114
need_rot: avk.params.need_rot,
116115
num_public_values: avk.params.num_public_values,
117116
num_interactions: avk.num_interactions,
118-
num_dag_nodes: avk.num_dag_nodes,
119117
main_width: avk.params.width.common_main,
120118
cached_widths: avk.params.width.cached_mains.clone(),
121119
preprocessed_width: avk.params.width.preprocessed,

crates/recursion/src/proof_shape/proof_shape/air.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ pub struct ProofShapeCols<F, const NUM_LIMBS: usize> {
9090
/// Computed as max(0, n0, n1, ...) where ni = log_height_i - l_skip for each present trace.
9191
pub n_max: F,
9292
pub is_n_max_greater: F,
93+
94+
pub num_air_id_lookups: F,
9395
}
9496

9597
// Variable-length columns are stored at the end
@@ -261,7 +263,6 @@ where
261263
let mut main_common_width = AB::Expr::ZERO;
262264
let mut preprocessed_stacked_width = AB::Expr::ZERO;
263265
let mut cached_widths = vec![AB::Expr::ZERO; self.max_cached];
264-
let mut num_dag_nodes = AB::Expr::ZERO;
265266

266267
// Select values for CommitmentsBus
267268
let mut preprocessed_commit = [AB::Expr::ZERO; DIGEST_SIZE];
@@ -279,7 +280,6 @@ where
279280
air_idx += is_current_air.clone() * AB::F::from_usize(i);
280281
need_rot += is_current_air.clone() * AB::F::from_bool(air_data.need_rot);
281282
main_common_width += is_current_air.clone() * AB::F::from_usize(air_data.main_width);
282-
num_dag_nodes += is_current_air.clone() * AB::F::from_usize(air_data.num_dag_nodes);
283283

284284
if air_data.num_public_values != 0 {
285285
has_pvs += is_current_air.clone();
@@ -498,7 +498,7 @@ where
498498
property_idx: AirShapeProperty::AirId.to_field(),
499499
value: air_idx.clone(),
500500
},
501-
local.is_present * (num_dag_nodes.clone() + AB::Expr::TWO),
501+
local.is_present * (local.num_air_id_lookups + AB::Expr::TWO),
502502
);
503503

504504
self.air_shape_bus.add_key_with_lookups(
@@ -555,7 +555,7 @@ where
555555
n_abs: n_abs.clone(),
556556
n_sign_bit: local.n_sign_bit.into(),
557557
},
558-
local.is_present * (num_dag_nodes + AB::F::ONE),
558+
local.is_present * (local.num_air_id_lookups + AB::F::ONE),
559559
);
560560

561561
///////////////////////////////////////////////////////////////////////////////////////////

crates/recursion/src/proof_shape/proof_shape/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ mod air;
22
mod trace;
33

44
pub use air::*;
5-
pub(in crate::proof_shape) use trace::*;
5+
pub(crate) use trace::*;
66

77
#[cfg(feature = "cuda")]
88
pub(crate) mod cuda;

crates/recursion/src/proof_shape/proof_shape/trace.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,29 @@ use crate::{
1313
},
1414
system::{Preflight, POW_CHECKER_HEIGHT},
1515
tracegen::RowMajorChip,
16+
utils::interaction_length,
1617
};
1718

19+
pub(crate) fn compute_air_shape_lookup_counts(
20+
child_vk: &MultiStarkVerifyingKey<BabyBearPoseidon2Config>,
21+
) -> Vec<usize> {
22+
child_vk
23+
.inner
24+
.per_air
25+
.iter()
26+
.map(|avk| {
27+
let dag = &avk.symbolic_constraints;
28+
dag.constraints.nodes.len()
29+
+ avk.unused_variables.len()
30+
+ dag
31+
.interactions
32+
.iter()
33+
.map(interaction_length)
34+
.sum::<usize>()
35+
})
36+
.collect::<Vec<_>>()
37+
}
38+
1839
#[derive(derive_new::new)]
1940
pub(in crate::proof_shape) struct ProofShapeChip<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
2041
idx_encoder: Arc<Encoder>,
@@ -69,6 +90,7 @@ impl<const NUM_LIMBS: usize, const LIMB_BITS: usize> RowMajorChip<F>
6990
let mut total_interactions = 0usize;
7091
let mut cidx = 1usize;
7192
let mut num_present = 0usize;
93+
let bc_air_shape_lookups = compute_air_shape_lookup_counts(child_vk);
7294

7395
// Present AIRs
7496
for (idx, vdata) in &preflight.proof_shape.sorted_trace_vdata {
@@ -110,6 +132,7 @@ impl<const NUM_LIMBS: usize, const LIMB_BITS: usize> RowMajorChip<F>
110132
total_interactions += num_interactions;
111133

112134
cols.n_max = F::from_usize(preflight.proof_shape.n_max);
135+
cols.num_air_id_lookups = F::from_usize(bc_air_shape_lookups[*idx]);
113136

114137
let vcols: &mut ProofShapeVarColsMut<'_, F> = &mut borrow_var_cols_mut(
115138
&mut chunk[cols_width..],

crates/recursion/src/system/frame.rs

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@ use openvm_stark_backend::{
88
};
99
use openvm_stark_sdk::config::baby_bear_poseidon2::{BabyBearPoseidon2Config, Digest, F};
1010

11-
use crate::utils::interaction_length;
12-
1311
/*
1412
* Modified versions of the STARK and multi-STARK verifying keys for AirModule
1513
* implementations. AirModules should use MultiStarkVerifyingKeyFrame instead
@@ -22,7 +20,6 @@ use crate::utils::interaction_length;
2220
pub struct StarkVkeyFrame {
2321
pub preprocessed_data: Option<VerifierSinglePreprocessedData<Digest>>,
2422
pub params: StarkVerifyingParams,
25-
pub num_dag_nodes: usize,
2623
pub num_interactions: usize,
2724
pub max_constraint_degree: u8,
2825
pub is_required: bool,
@@ -40,7 +37,6 @@ impl From<&StarkVerifyingKey<F, Digest>> for StarkVkeyFrame {
4037
Self {
4138
preprocessed_data: vk.preprocessed_data.clone(),
4239
params: vk.params.clone(),
43-
num_dag_nodes: compute_num_dag_nodes(vk),
4440
num_interactions: vk.num_interactions(),
4541
max_constraint_degree: vk.max_constraint_degree,
4642
is_required: vk.is_required,
@@ -57,14 +53,3 @@ impl From<&MultiStarkVerifyingKey<BabyBearPoseidon2Config>> for MultiStarkVkeyFr
5753
}
5854
}
5955
}
60-
61-
fn compute_num_dag_nodes<F, DIGEST>(vk: &StarkVerifyingKey<F, DIGEST>) -> usize {
62-
let dag = &vk.symbolic_constraints;
63-
dag.constraints.nodes.len()
64-
+ vk.unused_variables.len()
65-
+ dag
66-
.interactions
67-
.iter()
68-
.map(interaction_length)
69-
.sum::<usize>()
70-
}

0 commit comments

Comments
 (0)