Skip to content

Commit 46bb8d8

Browse files
authored
fix: ensure segments are below thresholds (#2138)
Replace the optimistic execution/segmentation with a checkpointing approach that checkpoints the last `trace_height`/`instret` value that is below the thresholds and use these values for the segments. This should make the segmentation more predictable for downstream usage since the segments should satisfy the thresholds (with the only caveat being if segmentation happens when there is no checkpoint to fall back to i.e. if we overshoot the threshold before the first segmentation check) This requires storing some extra state and results in a higher segment count compared to earlier for the same thresholds. Also makes execution slightly slower since we're doing some extra work now [benchmark run](https://github.com/axiom-crypto/openvm-reth-benchmark/actions/runs/17750607194) with 0.7B max cells [benchmark run](https://github.com/axiom-crypto/openvm-reth-benchmark/actions/runs/17771662984) with 1.2B max cells
1 parent 849e2ac commit 46bb8d8

File tree

2 files changed

+113
-56
lines changed

2 files changed

+113
-56
lines changed

crates/vm/src/arch/execution_mode/metered/ctx.rs

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -116,11 +116,6 @@ impl<const PAGE_BITS: usize> MeteredCtx<PAGE_BITS> {
116116

117117
fn reset_segment(&mut self) {
118118
self.memory_ctx.clear();
119-
for (i, &is_constant) in self.is_trace_height_constant.iter().enumerate() {
120-
if !is_constant {
121-
self.trace_heights[i] = 0;
122-
}
123-
}
124119
// Add merkle height contributions for all registers
125120
self.memory_ctx.add_register_merkle_heights();
126121
}
@@ -143,7 +138,7 @@ impl<const PAGE_BITS: usize> MeteredCtx<PAGE_BITS> {
143138
.lazy_update_boundary_heights(&mut self.trace_heights);
144139
let did_segment = self.segmentation_ctx.check_and_segment(
145140
instret,
146-
&self.trace_heights,
141+
&mut self.trace_heights,
147142
&self.is_trace_height_constant,
148143
);
149144

crates/vm/src/arch/execution_mode/metered/segment_ctx.rs

Lines changed: 112 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize};
55

66
pub const DEFAULT_SEGMENT_CHECK_INSNS: u64 = 1000;
77

8-
pub const DEFAULT_MAX_TRACE_HEIGHT: u32 = (1 << 23) - 10000;
8+
pub const DEFAULT_MAX_TRACE_HEIGHT: u32 = 1 << 23;
99
pub const DEFAULT_MAX_CELLS: usize = 2_000_000_000; // 2B
1010
const DEFAULT_MAX_INTERACTIONS: usize = BabyBear::ORDER_U32 as usize;
1111

@@ -46,6 +46,10 @@ pub struct SegmentationCtx {
4646
pub instret_last_segment_check: u64,
4747
#[getset(set_with = "pub")]
4848
pub segment_check_insns: u64,
49+
/// Checkpoint of trace heights at last known state where all thresholds satisfied
50+
pub(crate) checkpoint_trace_heights: Vec<u32>,
51+
/// Instruction count at the checkpoint
52+
checkpoint_instret: u64,
4953
}
5054

5155
impl SegmentationCtx {
@@ -58,6 +62,7 @@ impl SegmentationCtx {
5862
assert_eq!(air_names.len(), widths.len());
5963
assert_eq!(air_names.len(), interactions.len());
6064

65+
let num_airs = air_names.len();
6166
Self {
6267
segments: Vec::new(),
6368
air_names,
@@ -66,6 +71,8 @@ impl SegmentationCtx {
6671
segmentation_limits,
6772
segment_check_insns: DEFAULT_SEGMENT_CHECK_INSNS,
6873
instret_last_segment_check: 0,
74+
checkpoint_trace_heights: vec![0; num_airs],
75+
checkpoint_instret: 0,
6976
}
7077
}
7178

@@ -77,6 +84,7 @@ impl SegmentationCtx {
7784
assert_eq!(air_names.len(), widths.len());
7885
assert_eq!(air_names.len(), interactions.len());
7986

87+
let num_airs = air_names.len();
8088
Self {
8189
segments: Vec::new(),
8290
air_names,
@@ -85,6 +93,8 @@ impl SegmentationCtx {
8593
segmentation_limits: SegmentationLimits::default(),
8694
segment_check_insns: DEFAULT_SEGMENT_CHECK_INSNS,
8795
instret_last_segment_check: 0,
96+
checkpoint_trace_heights: vec![0; num_airs],
97+
checkpoint_instret: 0,
8898
}
8999
}
90100

@@ -100,37 +110,6 @@ impl SegmentationCtx {
100110
self.segmentation_limits.max_interactions = max_interactions;
101111
}
102112

103-
/// Calculate the total cells used based on trace heights and widths
104-
#[inline(always)]
105-
fn calculate_total_cells(&self, trace_heights: &[u32]) -> usize {
106-
debug_assert_eq!(trace_heights.len(), self.widths.len());
107-
108-
// SAFETY: Length equality is asserted during initialization
109-
let widths_slice = unsafe { self.widths.get_unchecked(..trace_heights.len()) };
110-
111-
trace_heights
112-
.iter()
113-
.zip(widths_slice)
114-
.map(|(&height, &width)| height as usize * width)
115-
.sum()
116-
}
117-
118-
/// Calculate the total interactions based on trace heights and interaction counts
119-
#[inline(always)]
120-
fn calculate_total_interactions(&self, trace_heights: &[u32]) -> usize {
121-
debug_assert_eq!(trace_heights.len(), self.interactions.len());
122-
123-
// SAFETY: Length equality is asserted during initialization
124-
let interactions_slice = unsafe { self.interactions.get_unchecked(..trace_heights.len()) };
125-
126-
trace_heights
127-
.iter()
128-
.zip(interactions_slice)
129-
// We add 1 for the zero messages from the padding rows
130-
.map(|(&height, &interactions)| (height + 1) as usize * interactions)
131-
.sum()
132-
}
133-
134113
#[inline(always)]
135114
fn should_segment(
136115
&self,
@@ -140,6 +119,8 @@ impl SegmentationCtx {
140119
) -> bool {
141120
debug_assert_eq!(trace_heights.len(), is_trace_height_constant.len());
142121
debug_assert_eq!(trace_heights.len(), self.air_names.len());
122+
debug_assert_eq!(trace_heights.len(), self.widths.len());
123+
debug_assert_eq!(trace_heights.len(), self.interactions.len());
143124

144125
let instret_start = self
145126
.segments
@@ -152,44 +133,51 @@ impl SegmentationCtx {
152133
return false;
153134
}
154135

155-
for (i, (&height, is_constant)) in trace_heights
136+
let mut total_cells = 0;
137+
for (i, ((padded_height, width), is_constant)) in trace_heights
156138
.iter()
139+
.map(|&height| height.next_power_of_two())
140+
.zip(self.widths.iter())
157141
.zip(is_trace_height_constant.iter())
158142
.enumerate()
159143
{
160-
// Only segment if the height is not constant and exceeds the maximum height
161-
if !is_constant && height > self.segmentation_limits.max_trace_height {
162-
let air_name = &self.air_names[i];
144+
// Only segment if the height is not constant and exceeds the maximum height after
145+
// padding
146+
if !is_constant && padded_height > self.segmentation_limits.max_trace_height {
147+
let air_name = unsafe { self.air_names.get_unchecked(i) };
163148
tracing::info!(
164-
"Segment {:2} | instret {:9} | chip {} ({}) height ({:8}) > max ({:8})",
165-
self.segments.len(),
149+
"instret {:9} | chip {} ({}) height ({:8}) > max ({:8})",
166150
instret,
167151
i,
168152
air_name,
169-
height,
153+
padded_height,
170154
self.segmentation_limits.max_trace_height
171155
);
172156
return true;
173157
}
158+
total_cells += padded_height as usize * width;
174159
}
175160

176-
let total_cells = self.calculate_total_cells(trace_heights);
177161
if total_cells > self.segmentation_limits.max_cells {
178162
tracing::info!(
179-
"Segment {:2} | instret {:9} | total cells ({:10}) > max ({:10})",
180-
self.segments.len(),
163+
"instret {:9} | total cells ({:10}) > max ({:10})",
181164
instret,
182165
total_cells,
183166
self.segmentation_limits.max_cells
184167
);
185168
return true;
186169
}
187170

188-
let total_interactions = self.calculate_total_interactions(trace_heights);
171+
// All padding rows contribute a single message to the interactions (+1) since
172+
// we assume chips don't send/receive with nonzero multiplicity on padding rows.
173+
let total_interactions: usize = trace_heights
174+
.iter()
175+
.zip(self.interactions.iter())
176+
.map(|(&height, &interactions)| (height + 1) as usize * interactions)
177+
.sum();
189178
if total_interactions > self.segmentation_limits.max_interactions {
190179
tracing::info!(
191-
"Segment {:2} | instret {:9} | total interactions ({:11}) > max ({:11})",
192-
self.segments.len(),
180+
"instret {:9} | total interactions ({:11}) > max ({:11})",
193181
instret,
194182
total_interactions,
195183
self.segmentation_limits.max_interactions
@@ -204,16 +192,84 @@ impl SegmentationCtx {
204192
pub fn check_and_segment(
205193
&mut self,
206194
instret: u64,
207-
trace_heights: &[u32],
195+
trace_heights: &mut [u32],
208196
is_trace_height_constant: &[bool],
209197
) -> bool {
210-
let ret = self.should_segment(instret, trace_heights, is_trace_height_constant);
211-
if ret {
212-
self.segment(instret, trace_heights);
198+
let should_seg = self.should_segment(instret, trace_heights, is_trace_height_constant);
199+
200+
if should_seg {
201+
self.create_segment_from_checkpoint(instret, trace_heights, is_trace_height_constant);
202+
} else {
203+
self.update_checkpoint(instret, trace_heights);
213204
}
205+
214206
self.instret_last_segment_check = instret;
207+
should_seg
208+
}
215209

216-
ret
210+
#[inline(always)]
211+
fn create_segment_from_checkpoint(
212+
&mut self,
213+
instret: u64,
214+
trace_heights: &mut [u32],
215+
is_trace_height_constant: &[bool],
216+
) {
217+
let instret_start = self
218+
.segments
219+
.last()
220+
.map_or(0, |s| s.instret_start + s.num_insns);
221+
222+
let (segment_instret, segment_heights) = if self.checkpoint_instret > instret_start {
223+
(
224+
self.checkpoint_instret,
225+
self.checkpoint_trace_heights.clone(),
226+
)
227+
} else {
228+
// No valid checkpoint, use current values
229+
(instret, trace_heights.to_vec())
230+
};
231+
232+
// Reset current trace heights and checkpoint
233+
self.reset_trace_heights(trace_heights, &segment_heights, is_trace_height_constant);
234+
self.checkpoint_instret = 0;
235+
236+
tracing::info!(
237+
"Segment {:2} | instret {:9} | {} instructions",
238+
self.segments.len(),
239+
instret_start,
240+
segment_instret - instret_start
241+
);
242+
self.segments.push(Segment {
243+
instret_start,
244+
num_insns: segment_instret - instret_start,
245+
trace_heights: segment_heights,
246+
});
247+
}
248+
249+
/// Resets trace heights by subtracting segment heights
250+
#[inline(always)]
251+
fn reset_trace_heights(
252+
&self,
253+
trace_heights: &mut [u32],
254+
segment_heights: &[u32],
255+
is_trace_height_constant: &[bool],
256+
) {
257+
for ((trace_height, &segment_height), &is_trace_height_constant) in trace_heights
258+
.iter_mut()
259+
.zip(segment_heights.iter())
260+
.zip(is_trace_height_constant.iter())
261+
{
262+
if !is_trace_height_constant {
263+
*trace_height = trace_height.checked_sub(segment_height).unwrap();
264+
}
265+
}
266+
}
267+
268+
/// Updates the checkpoint with current safe state
269+
#[inline(always)]
270+
fn update_checkpoint(&mut self, instret: u64, trace_heights: &[u32]) {
271+
self.checkpoint_trace_heights.copy_from_slice(trace_heights);
272+
self.checkpoint_instret = instret;
217273
}
218274

219275
/// Try segment if there is at least one cycle
@@ -227,6 +283,12 @@ impl SegmentationCtx {
227283

228284
debug_assert!(num_insns > 0, "Segment should contain at least one cycle");
229285

286+
tracing::info!(
287+
"Segment {:2} | instret {:9} | {} instructions [FINAL]",
288+
self.segments.len(),
289+
instret_start,
290+
num_insns
291+
);
230292
self.segments.push(Segment {
231293
instret_start,
232294
num_insns,

0 commit comments

Comments
 (0)