Skip to content

Commit 51dd038

Browse files
fix: keep track of memory accesses since last checkpoint (#2332)
This PR primarily adds a `page_indices_since_checkpoint` buffer to track page accesses between checkpoints. This buffer is used to initialize the memory trace heights of the next segment. Resolves INT-5778. --------- Co-authored-by: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com>
1 parent d7eab70 commit 51dd038

File tree

5 files changed

+225
-111
lines changed

5 files changed

+225
-111
lines changed

crates/vm/src/arch/execution_mode/metered/ctx.rs

Lines changed: 59 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -49,45 +49,41 @@ impl<const PAGE_BITS: usize> MeteredCtx<PAGE_BITS> {
4949
})
5050
.unzip();
5151

52-
let memory_ctx = MemoryCtx::new(config);
52+
let segmentation_ctx =
53+
SegmentationCtx::new(air_names, widths, interactions, config.segmentation_limits);
54+
let memory_ctx = MemoryCtx::new(config, segmentation_ctx.segment_check_insns);
5355

5456
// Assert that the indices are correct
5557
debug_assert!(
56-
air_names[memory_ctx.boundary_idx].contains("Boundary"),
58+
segmentation_ctx.air_names[memory_ctx.boundary_idx].contains("Boundary"),
5759
"air_name={}",
58-
air_names[memory_ctx.boundary_idx]
60+
segmentation_ctx.air_names[memory_ctx.boundary_idx]
5961
);
6062
if let Some(merkle_tree_index) = memory_ctx.merkle_tree_index {
6163
debug_assert!(
62-
air_names[merkle_tree_index].contains("Merkle"),
64+
segmentation_ctx.air_names[merkle_tree_index].contains("Merkle"),
6365
"air_name={}",
64-
air_names[merkle_tree_index]
66+
segmentation_ctx.air_names[merkle_tree_index]
6567
);
6668
}
6769
debug_assert!(
68-
air_names[memory_ctx.adapter_offset].contains("AccessAdapterAir<2>"),
70+
segmentation_ctx.air_names[memory_ctx.adapter_offset].contains("AccessAdapterAir<2>"),
6971
"air_name={}",
70-
air_names[memory_ctx.adapter_offset]
72+
segmentation_ctx.air_names[memory_ctx.adapter_offset]
7173
);
7274

73-
let segmentation_ctx =
74-
SegmentationCtx::new(air_names, widths, interactions, config.segmentation_limits);
75-
7675
let mut ctx = Self {
7776
trace_heights,
7877
is_trace_height_constant,
7978
memory_ctx,
8079
segmentation_ctx,
8180
suspend_on_segment: false,
8281
};
83-
if !config.continuation_enabled {
84-
// force single segment
85-
ctx.segmentation_ctx.segment_check_insns = u64::MAX;
86-
ctx.segmentation_ctx.instrets_until_check = u64::MAX;
87-
}
8882

8983
// Add merkle height contributions for all registers
9084
ctx.memory_ctx.add_register_merkle_heights();
85+
ctx.memory_ctx
86+
.lazy_update_boundary_heights(&mut ctx.trace_heights);
9187

9288
ctx
9389
}
@@ -98,9 +94,8 @@ impl<const PAGE_BITS: usize> MeteredCtx<PAGE_BITS> {
9894
self.segmentation_ctx.set_max_trace_height(max_trace_height);
9995
let max_check_freq = (max_trace_height / 2) as u64;
10096
if max_check_freq < self.segmentation_ctx.segment_check_insns {
101-
self.segmentation_ctx.segment_check_insns = max_check_freq;
97+
self = self.with_segment_check_insns(max_check_freq);
10298
}
103-
self.segmentation_ctx.instrets_until_check = self.segmentation_ctx.segment_check_insns;
10499
self
105100
}
106101

@@ -114,6 +109,20 @@ impl<const PAGE_BITS: usize> MeteredCtx<PAGE_BITS> {
114109
self
115110
}
116111

112+
pub fn with_segment_check_insns(mut self, segment_check_insns: u64) -> Self {
113+
self.segmentation_ctx.segment_check_insns = segment_check_insns;
114+
self.segmentation_ctx.instrets_until_check = segment_check_insns;
115+
116+
// Update memory context with new segment check instructions
117+
let page_indices_since_checkpoint_cap =
118+
MemoryCtx::<PAGE_BITS>::calculate_checkpoint_capacity(segment_check_insns);
119+
120+
self.memory_ctx.page_indices_since_checkpoint =
121+
vec![0; page_indices_since_checkpoint_cap].into_boxed_slice();
122+
self.memory_ctx.page_indices_since_checkpoint_len = 0;
123+
self
124+
}
125+
117126
pub fn segments(&self) -> &[Segment] {
118127
&self.segmentation_ctx.segments
119128
}
@@ -122,12 +131,6 @@ impl<const PAGE_BITS: usize> MeteredCtx<PAGE_BITS> {
122131
self.segmentation_ctx.segments
123132
}
124133

125-
fn reset_segment(&mut self) {
126-
self.memory_ctx.clear();
127-
// Add merkle height contributions for all registers
128-
self.memory_ctx.add_register_merkle_heights();
129-
}
130-
131134
#[inline(always)]
132135
pub fn check_and_segment(&mut self) -> bool {
133136
// We track the segmentation check by instrets_until_check instead of instret in order to
@@ -147,8 +150,40 @@ impl<const PAGE_BITS: usize> MeteredCtx<PAGE_BITS> {
147150
);
148151

149152
if did_segment {
150-
self.reset_segment();
153+
// Initialize contexts for new segment
154+
self.segmentation_ctx
155+
.initialize_segment(&mut self.trace_heights, &self.is_trace_height_constant);
156+
self.memory_ctx.initialize_segment(&mut self.trace_heights);
157+
158+
// Check if the new segment is within limits
159+
if self.segmentation_ctx.should_segment(
160+
self.segmentation_ctx.instret,
161+
&self.trace_heights,
162+
&self.is_trace_height_constant,
163+
) {
164+
let trace_heights_str = self
165+
.trace_heights
166+
.iter()
167+
.zip(self.segmentation_ctx.air_names.iter())
168+
.filter(|(&height, _)| height > 0)
169+
.map(|(&height, name)| format!(" {name} = {height}"))
170+
.collect::<Vec<_>>()
171+
.join("\n");
172+
tracing::warn!(
173+
"Segment initialized with heights that exceed limits\n\
174+
instret={}\n\
175+
trace_heights=[\n{}\n]",
176+
self.segmentation_ctx.instret,
177+
trace_heights_str
178+
);
179+
}
151180
}
181+
182+
// Update checkpoints
183+
self.segmentation_ctx
184+
.update_checkpoint(self.segmentation_ctx.instret, &self.trace_heights);
185+
self.memory_ctx.update_checkpoint();
186+
152187
did_segment
153188
}
154189

crates/vm/src/arch/execution_mode/metered/memory_ctx.rs

Lines changed: 113 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ use openvm_instructions::riscv::{RV32_NUM_REGISTERS, RV32_REGISTER_AS, RV32_REGI
33

44
use crate::{arch::SystemConfig, system::memory::dimensions::MemoryDimensions};
55

6+
/// Upper bound on number of memory pages accessed per instruction. Used for buffer allocation.
7+
pub const MAX_MEM_PAGE_OPS_PER_INSN: usize = 1 << 16;
8+
69
#[derive(Clone, Debug)]
710
pub struct BitSet {
811
words: Box<[u64]>,
@@ -99,7 +102,6 @@ impl BitSet {
99102

100103
#[derive(Clone, Debug)]
101104
pub struct MemoryCtx<const PAGE_BITS: usize> {
102-
pub page_indices: BitSet,
103105
memory_dimensions: MemoryDimensions,
104106
min_block_size_bits: Vec<u8>,
105107
pub boundary_idx: usize,
@@ -108,22 +110,26 @@ pub struct MemoryCtx<const PAGE_BITS: usize> {
108110
continuations_enabled: bool,
109111
chunk: u32,
110112
chunk_bits: u32,
111-
pub page_access_count: usize,
112-
// Note: 32 is the maximum access adapter size.
113+
pub page_indices: BitSet,
113114
pub addr_space_access_count: RVec<usize>,
115+
pub page_indices_since_checkpoint: Box<[u32]>,
116+
pub page_indices_since_checkpoint_len: usize,
114117
}
115118

116119
impl<const PAGE_BITS: usize> MemoryCtx<PAGE_BITS> {
117-
pub fn new(config: &SystemConfig) -> Self {
120+
pub fn new(config: &SystemConfig, segment_check_insns: u64) -> Self {
118121
let chunk = config.initial_block_size() as u32;
119122
let chunk_bits = chunk.ilog2();
120123

121124
let memory_dimensions = config.memory_config.memory_dimensions();
122125
let merkle_height = memory_dimensions.overall_height();
123126

127+
let bitset_size = 1 << (merkle_height.saturating_sub(PAGE_BITS));
128+
let addr_space_size = (1 << memory_dimensions.addr_space_height) + 1;
129+
let page_indices_since_checkpoint_cap =
130+
Self::calculate_checkpoint_capacity(segment_check_insns);
131+
124132
Self {
125-
// Address height already considers `chunk_bits`.
126-
page_indices: BitSet::new(1 << (merkle_height.saturating_sub(PAGE_BITS))),
127133
min_block_size_bits: config.memory_config.min_block_size_bits(),
128134
boundary_idx: config.memory_boundary_air_id(),
129135
merkle_tree_index: config.memory_merkle_air_id(),
@@ -132,14 +138,17 @@ impl<const PAGE_BITS: usize> MemoryCtx<PAGE_BITS> {
132138
chunk_bits,
133139
memory_dimensions,
134140
continuations_enabled: config.continuation_enabled,
135-
page_access_count: 0,
136-
addr_space_access_count: vec![0; (1 << memory_dimensions.addr_space_height) + 1].into(),
141+
page_indices: BitSet::new(bitset_size),
142+
addr_space_access_count: vec![0; addr_space_size].into(),
143+
page_indices_since_checkpoint: vec![0; page_indices_since_checkpoint_cap]
144+
.into_boxed_slice(),
145+
page_indices_since_checkpoint_len: 0,
137146
}
138147
}
139148

140149
#[inline(always)]
141-
pub fn clear(&mut self) {
142-
self.page_indices.clear();
150+
pub(super) fn calculate_checkpoint_capacity(segment_check_insns: u64) -> usize {
151+
segment_check_insns as usize * MAX_MEM_PAGE_OPS_PER_INSN
143152
}
144153

145154
#[inline(always)]
@@ -177,10 +186,23 @@ impl<const PAGE_BITS: usize> MemoryCtx<PAGE_BITS> {
177186
let end_block_id = start_block_id + num_blocks;
178187
let start_page_id = start_block_id >> PAGE_BITS;
179188
let end_page_id = ((end_block_id - 1) >> PAGE_BITS) + 1;
189+
assert!(
190+
self.page_indices_since_checkpoint_len + (end_page_id - start_page_id) as usize
191+
<= self.page_indices_since_checkpoint.len(),
192+
"more than {MAX_MEM_PAGE_OPS_PER_INSN} memory pages accessed in a single instruction"
193+
);
180194

181195
for page_id in start_page_id..end_page_id {
196+
// Append page_id to page_indices_since_checkpoint
197+
let len = self.page_indices_since_checkpoint_len;
198+
debug_assert!(len < self.page_indices_since_checkpoint.len());
199+
// SAFETY: len is within bounds, and we extend length by 1 after writing.
200+
unsafe {
201+
*self.page_indices_since_checkpoint.as_mut_ptr().add(len) = page_id;
202+
}
203+
self.page_indices_since_checkpoint_len = len + 1;
204+
182205
if self.page_indices.insert(page_id as usize) {
183-
self.page_access_count += 1;
184206
// SAFETY: address_space passed is usually a hardcoded constant or derived from an
185207
// Instruction where it is bounds checked before passing
186208
unsafe {
@@ -235,13 +257,69 @@ impl<const PAGE_BITS: usize> MemoryCtx<PAGE_BITS> {
235257
}
236258
}
237259

238-
/// Resolve all lazy updates of each memory access for memory adapters/poseidon2/merkle chip.
260+
/// Initialize state for a new segment
239261
#[inline(always)]
240-
pub(crate) fn lazy_update_boundary_heights(&mut self, trace_heights: &mut [u32]) {
241-
debug_assert!(self.boundary_idx < trace_heights.len());
262+
pub(crate) fn initialize_segment(&mut self, trace_heights: &mut [u32]) {
263+
// Clear page indices for the new segment
264+
self.page_indices.clear();
265+
266+
// Reset trace heights for memory chips as 0
267+
// SAFETY: boundary_idx is a compile time constant within bounds
268+
unsafe {
269+
*trace_heights.get_unchecked_mut(self.boundary_idx) = 0;
270+
}
271+
if let Some(merkle_tree_idx) = self.merkle_tree_index {
272+
// SAFETY: merkle_tree_idx is guaranteed to be in bounds
273+
unsafe {
274+
*trace_heights.get_unchecked_mut(merkle_tree_idx) = 0;
275+
}
276+
let poseidon2_idx = trace_heights.len() - 2;
277+
// SAFETY: poseidon2_idx is trace_heights.len() - 2, guaranteed to be in bounds
278+
unsafe {
279+
*trace_heights.get_unchecked_mut(poseidon2_idx) = 0;
280+
}
281+
}
282+
283+
// Apply height updates for all pages accessed since last checkpoint, and
284+
// initialize page_indices for the new segment.
285+
let mut addr_space_access_count = vec![0; self.addr_space_access_count.len()];
286+
let pages_len = self.page_indices_since_checkpoint_len;
287+
for i in 0..pages_len {
288+
// SAFETY: i is within 0..pages_len and pages_len is the slice length.
289+
let page_id = unsafe { *self.page_indices_since_checkpoint.get_unchecked(i) } as usize;
290+
if self.page_indices.insert(page_id) {
291+
let (addr_space, _) = self
292+
.memory_dimensions
293+
.index_to_label((page_id as u64) << PAGE_BITS);
294+
let addr_space_idx = addr_space as usize;
295+
debug_assert!(addr_space_idx < addr_space_access_count.len());
296+
// SAFETY: addr_space_idx is bounds checked in debug and derived from a valid page
297+
// id.
298+
unsafe {
299+
*addr_space_access_count.get_unchecked_mut(addr_space_idx) += 1;
300+
}
301+
}
302+
}
303+
self.apply_height_updates(trace_heights, &addr_space_access_count);
304+
305+
// Add merkle height contributions for all registers
306+
self.add_register_merkle_heights();
307+
self.lazy_update_boundary_heights(trace_heights);
308+
}
309+
310+
/// Updates the checkpoint with current safe state
311+
#[inline(always)]
312+
pub(crate) fn update_checkpoint(&mut self) {
313+
self.page_indices_since_checkpoint_len = 0;
314+
}
315+
316+
/// Apply height updates given page counts
317+
#[inline(always)]
318+
fn apply_height_updates(&self, trace_heights: &mut [u32], addr_space_access_count: &[usize]) {
319+
let page_access_count: usize = addr_space_access_count.iter().sum();
242320

243321
// On page fault, assume we add all leaves in a page
244-
let leaves = (self.page_access_count << PAGE_BITS) as u32;
322+
let leaves = (page_access_count << PAGE_BITS) as u32;
245323
// SAFETY: boundary_idx is a compile time constant within bounds
246324
unsafe {
247325
*trace_heights.get_unchecked_mut(self.boundary_idx) += leaves;
@@ -261,15 +339,16 @@ impl<const PAGE_BITS: usize> MemoryCtx<PAGE_BITS> {
261339
let nodes = (((1 << PAGE_BITS) - 1) + (merkle_height - PAGE_BITS)) as u32;
262340
// SAFETY: merkle_tree_idx is guaranteed to be in bounds
263341
unsafe {
264-
*trace_heights.get_unchecked_mut(poseidon2_idx) += nodes * 2;
265-
*trace_heights.get_unchecked_mut(merkle_tree_idx) += nodes * 2;
342+
*trace_heights.get_unchecked_mut(poseidon2_idx) +=
343+
nodes * page_access_count as u32 * 2;
344+
*trace_heights.get_unchecked_mut(merkle_tree_idx) +=
345+
nodes * page_access_count as u32 * 2;
266346
}
267347
}
268-
self.page_access_count = 0;
269348

270-
for address_space in 0..self.addr_space_access_count.len() {
349+
for address_space in 0..addr_space_access_count.len() {
271350
// SAFETY: address_space is from 0 to len(), guaranteed to be in bounds
272-
let x = unsafe { *self.addr_space_access_count.get_unchecked(address_space) };
351+
let x = unsafe { *addr_space_access_count.get_unchecked(address_space) };
273352
if x > 0 {
274353
// Initial **and** final handling of touched pages requires send (resp. receive) in
275354
// chunk-sized units for the merkle chip
@@ -281,15 +360,23 @@ impl<const PAGE_BITS: usize> MemoryCtx<PAGE_BITS> {
281360
self.chunk_bits,
282361
(x << (PAGE_BITS + 1)) as u32,
283362
);
284-
// SAFETY: address_space is from 0 to len(), guaranteed to be in bounds
285-
unsafe {
286-
*self
287-
.addr_space_access_count
288-
.get_unchecked_mut(address_space) = 0;
289-
}
290363
}
291364
}
292365
}
366+
367+
/// Resolve all lazy updates of each memory access for memory adapters/poseidon2/merkle chip.
368+
#[inline(always)]
369+
pub(crate) fn lazy_update_boundary_heights(&mut self, trace_heights: &mut [u32]) {
370+
self.apply_height_updates(trace_heights, &self.addr_space_access_count);
371+
// SAFETY: Resetting array elements to 0 is always safe
372+
unsafe {
373+
std::ptr::write_bytes(
374+
self.addr_space_access_count.as_mut_ptr(),
375+
0,
376+
self.addr_space_access_count.len(),
377+
);
378+
}
379+
}
293380
}
294381

295382
#[cfg(test)]

0 commit comments

Comments
 (0)