@@ -3,6 +3,9 @@ use openvm_instructions::riscv::{RV32_NUM_REGISTERS, RV32_REGISTER_AS, RV32_REGI
33
44use crate :: { arch:: SystemConfig , system:: memory:: dimensions:: MemoryDimensions } ;
55
6+ /// Upper bound on number of memory pages accessed per instruction. Used for buffer allocation.
7+ pub const MAX_MEM_PAGE_OPS_PER_INSN : usize = 1 << 16 ;
8+
69#[ derive( Clone , Debug ) ]
710pub struct BitSet {
811 words : Box < [ u64 ] > ,
@@ -99,7 +102,6 @@ impl BitSet {
99102
100103#[ derive( Clone , Debug ) ]
101104pub struct MemoryCtx < const PAGE_BITS : usize > {
102- pub page_indices : BitSet ,
103105 memory_dimensions : MemoryDimensions ,
104106 min_block_size_bits : Vec < u8 > ,
105107 pub boundary_idx : usize ,
@@ -108,22 +110,26 @@ pub struct MemoryCtx<const PAGE_BITS: usize> {
108110 continuations_enabled : bool ,
109111 chunk : u32 ,
110112 chunk_bits : u32 ,
111- pub page_access_count : usize ,
112- // Note: 32 is the maximum access adapter size.
113+ pub page_indices : BitSet ,
113114 pub addr_space_access_count : RVec < usize > ,
115+ pub page_indices_since_checkpoint : Box < [ u32 ] > ,
116+ pub page_indices_since_checkpoint_len : usize ,
114117}
115118
116119impl < const PAGE_BITS : usize > MemoryCtx < PAGE_BITS > {
117- pub fn new ( config : & SystemConfig ) -> Self {
120+ pub fn new ( config : & SystemConfig , segment_check_insns : u64 ) -> Self {
118121 let chunk = config. initial_block_size ( ) as u32 ;
119122 let chunk_bits = chunk. ilog2 ( ) ;
120123
121124 let memory_dimensions = config. memory_config . memory_dimensions ( ) ;
122125 let merkle_height = memory_dimensions. overall_height ( ) ;
123126
127+ let bitset_size = 1 << ( merkle_height. saturating_sub ( PAGE_BITS ) ) ;
128+ let addr_space_size = ( 1 << memory_dimensions. addr_space_height ) + 1 ;
129+ let page_indices_since_checkpoint_cap =
130+ Self :: calculate_checkpoint_capacity ( segment_check_insns) ;
131+
124132 Self {
125- // Address height already considers `chunk_bits`.
126- page_indices : BitSet :: new ( 1 << ( merkle_height. saturating_sub ( PAGE_BITS ) ) ) ,
127133 min_block_size_bits : config. memory_config . min_block_size_bits ( ) ,
128134 boundary_idx : config. memory_boundary_air_id ( ) ,
129135 merkle_tree_index : config. memory_merkle_air_id ( ) ,
@@ -132,14 +138,17 @@ impl<const PAGE_BITS: usize> MemoryCtx<PAGE_BITS> {
132138 chunk_bits,
133139 memory_dimensions,
134140 continuations_enabled : config. continuation_enabled ,
135- page_access_count : 0 ,
136- addr_space_access_count : vec ! [ 0 ; ( 1 << memory_dimensions. addr_space_height) + 1 ] . into ( ) ,
141+ page_indices : BitSet :: new ( bitset_size) ,
142+ addr_space_access_count : vec ! [ 0 ; addr_space_size] . into ( ) ,
143+ page_indices_since_checkpoint : vec ! [ 0 ; page_indices_since_checkpoint_cap]
144+ . into_boxed_slice ( ) ,
145+ page_indices_since_checkpoint_len : 0 ,
137146 }
138147 }
139148
140149 #[ inline( always) ]
141- pub fn clear ( & mut self ) {
142- self . page_indices . clear ( ) ;
150+ pub ( super ) fn calculate_checkpoint_capacity ( segment_check_insns : u64 ) -> usize {
151+ segment_check_insns as usize * MAX_MEM_PAGE_OPS_PER_INSN
143152 }
144153
145154 #[ inline( always) ]
@@ -177,10 +186,23 @@ impl<const PAGE_BITS: usize> MemoryCtx<PAGE_BITS> {
177186 let end_block_id = start_block_id + num_blocks;
178187 let start_page_id = start_block_id >> PAGE_BITS ;
179188 let end_page_id = ( ( end_block_id - 1 ) >> PAGE_BITS ) + 1 ;
189+ assert ! (
190+ self . page_indices_since_checkpoint_len + ( end_page_id - start_page_id) as usize
191+ <= self . page_indices_since_checkpoint. len( ) ,
192+ "more than {MAX_MEM_PAGE_OPS_PER_INSN} memory pages accessed in a single instruction"
193+ ) ;
180194
181195 for page_id in start_page_id..end_page_id {
196+ // Append page_id to page_indices_since_checkpoint
197+ let len = self . page_indices_since_checkpoint_len ;
198+ debug_assert ! ( len < self . page_indices_since_checkpoint. len( ) ) ;
199+ // SAFETY: len is within bounds, and we extend length by 1 after writing.
200+ unsafe {
201+ * self . page_indices_since_checkpoint . as_mut_ptr ( ) . add ( len) = page_id;
202+ }
203+ self . page_indices_since_checkpoint_len = len + 1 ;
204+
182205 if self . page_indices . insert ( page_id as usize ) {
183- self . page_access_count += 1 ;
184206 // SAFETY: address_space passed is usually a hardcoded constant or derived from an
185207 // Instruction where it is bounds checked before passing
186208 unsafe {
@@ -235,13 +257,69 @@ impl<const PAGE_BITS: usize> MemoryCtx<PAGE_BITS> {
235257 }
236258 }
237259
238- /// Resolve all lazy updates of each memory access for memory adapters/poseidon2/merkle chip.
260+ /// Initialize state for a new segment
239261 #[ inline( always) ]
240- pub ( crate ) fn lazy_update_boundary_heights ( & mut self , trace_heights : & mut [ u32 ] ) {
241- debug_assert ! ( self . boundary_idx < trace_heights. len( ) ) ;
262+ pub ( crate ) fn initialize_segment ( & mut self , trace_heights : & mut [ u32 ] ) {
263+ // Clear page indices for the new segment
264+ self . page_indices . clear ( ) ;
265+
266+ // Reset trace heights for memory chips as 0
267+ // SAFETY: boundary_idx is a compile time constant within bounds
268+ unsafe {
269+ * trace_heights. get_unchecked_mut ( self . boundary_idx ) = 0 ;
270+ }
271+ if let Some ( merkle_tree_idx) = self . merkle_tree_index {
272+ // SAFETY: merkle_tree_idx is guaranteed to be in bounds
273+ unsafe {
274+ * trace_heights. get_unchecked_mut ( merkle_tree_idx) = 0 ;
275+ }
276+ let poseidon2_idx = trace_heights. len ( ) - 2 ;
277+ // SAFETY: poseidon2_idx is trace_heights.len() - 2, guaranteed to be in bounds
278+ unsafe {
279+ * trace_heights. get_unchecked_mut ( poseidon2_idx) = 0 ;
280+ }
281+ }
282+
283+ // Apply height updates for all pages accessed since last checkpoint, and
284+ // initialize page_indices for the new segment.
285+ let mut addr_space_access_count = vec ! [ 0 ; self . addr_space_access_count. len( ) ] ;
286+ let pages_len = self . page_indices_since_checkpoint_len ;
287+ for i in 0 ..pages_len {
288+ // SAFETY: i is within 0..pages_len and pages_len is the slice length.
289+ let page_id = unsafe { * self . page_indices_since_checkpoint . get_unchecked ( i) } as usize ;
290+ if self . page_indices . insert ( page_id) {
291+ let ( addr_space, _) = self
292+ . memory_dimensions
293+ . index_to_label ( ( page_id as u64 ) << PAGE_BITS ) ;
294+ let addr_space_idx = addr_space as usize ;
295+ debug_assert ! ( addr_space_idx < addr_space_access_count. len( ) ) ;
296+ // SAFETY: addr_space_idx is bounds checked in debug and derived from a valid page
297+ // id.
298+ unsafe {
299+ * addr_space_access_count. get_unchecked_mut ( addr_space_idx) += 1 ;
300+ }
301+ }
302+ }
303+ self . apply_height_updates ( trace_heights, & addr_space_access_count) ;
304+
305+ // Add merkle height contributions for all registers
306+ self . add_register_merkle_heights ( ) ;
307+ self . lazy_update_boundary_heights ( trace_heights) ;
308+ }
309+
310+ /// Updates the checkpoint with current safe state
311+ #[ inline( always) ]
312+ pub ( crate ) fn update_checkpoint ( & mut self ) {
313+ self . page_indices_since_checkpoint_len = 0 ;
314+ }
315+
316+ /// Apply height updates given page counts
317+ #[ inline( always) ]
318+ fn apply_height_updates ( & self , trace_heights : & mut [ u32 ] , addr_space_access_count : & [ usize ] ) {
319+ let page_access_count: usize = addr_space_access_count. iter ( ) . sum ( ) ;
242320
243321 // On page fault, assume we add all leaves in a page
244- let leaves = ( self . page_access_count << PAGE_BITS ) as u32 ;
322+ let leaves = ( page_access_count << PAGE_BITS ) as u32 ;
245323 // SAFETY: boundary_idx is a compile time constant within bounds
246324 unsafe {
247325 * trace_heights. get_unchecked_mut ( self . boundary_idx ) += leaves;
@@ -261,15 +339,16 @@ impl<const PAGE_BITS: usize> MemoryCtx<PAGE_BITS> {
261339 let nodes = ( ( ( 1 << PAGE_BITS ) - 1 ) + ( merkle_height - PAGE_BITS ) ) as u32 ;
262340 // SAFETY: merkle_tree_idx is guaranteed to be in bounds
263341 unsafe {
264- * trace_heights. get_unchecked_mut ( poseidon2_idx) += nodes * 2 ;
265- * trace_heights. get_unchecked_mut ( merkle_tree_idx) += nodes * 2 ;
342+ * trace_heights. get_unchecked_mut ( poseidon2_idx) +=
343+ nodes * page_access_count as u32 * 2 ;
344+ * trace_heights. get_unchecked_mut ( merkle_tree_idx) +=
345+ nodes * page_access_count as u32 * 2 ;
266346 }
267347 }
268- self . page_access_count = 0 ;
269348
270- for address_space in 0 ..self . addr_space_access_count . len ( ) {
349+ for address_space in 0 ..addr_space_access_count. len ( ) {
271350 // SAFETY: address_space is from 0 to len(), guaranteed to be in bounds
272- let x = unsafe { * self . addr_space_access_count . get_unchecked ( address_space) } ;
351+ let x = unsafe { * addr_space_access_count. get_unchecked ( address_space) } ;
273352 if x > 0 {
274353 // Initial **and** final handling of touched pages requires send (resp. receive) in
275354 // chunk-sized units for the merkle chip
@@ -281,15 +360,23 @@ impl<const PAGE_BITS: usize> MemoryCtx<PAGE_BITS> {
281360 self . chunk_bits ,
282361 ( x << ( PAGE_BITS + 1 ) ) as u32 ,
283362 ) ;
284- // SAFETY: address_space is from 0 to len(), guaranteed to be in bounds
285- unsafe {
286- * self
287- . addr_space_access_count
288- . get_unchecked_mut ( address_space) = 0 ;
289- }
290363 }
291364 }
292365 }
366+
367+ /// Resolve all lazy updates of each memory access for memory adapters/poseidon2/merkle chip.
368+ #[ inline( always) ]
369+ pub ( crate ) fn lazy_update_boundary_heights ( & mut self , trace_heights : & mut [ u32 ] ) {
370+ self . apply_height_updates ( trace_heights, & self . addr_space_access_count ) ;
371+ // SAFETY: Resetting array elements to 0 is always safe
372+ unsafe {
373+ std:: ptr:: write_bytes (
374+ self . addr_space_access_count . as_mut_ptr ( ) ,
375+ 0 ,
376+ self . addr_space_access_count . len ( ) ,
377+ ) ;
378+ }
379+ }
293380}
294381
295382#[ cfg( test) ]
0 commit comments