diff --git a/crates/vm/src/arch/execution_mode/metered/ctx.rs b/crates/vm/src/arch/execution_mode/metered/ctx.rs index 8428438ca7..ce1144bfbf 100644 --- a/crates/vm/src/arch/execution_mode/metered/ctx.rs +++ b/crates/vm/src/arch/execution_mode/metered/ctx.rs @@ -49,30 +49,29 @@ impl MeteredCtx { }) .unzip(); - let memory_ctx = MemoryCtx::new(config); + let segmentation_ctx = + SegmentationCtx::new(air_names, widths, interactions, config.segmentation_limits); + let memory_ctx = MemoryCtx::new(config, segmentation_ctx.segment_check_insns); // Assert that the indices are correct debug_assert!( - air_names[memory_ctx.boundary_idx].contains("Boundary"), + segmentation_ctx.air_names[memory_ctx.boundary_idx].contains("Boundary"), "air_name={}", - air_names[memory_ctx.boundary_idx] + segmentation_ctx.air_names[memory_ctx.boundary_idx] ); if let Some(merkle_tree_index) = memory_ctx.merkle_tree_index { debug_assert!( - air_names[merkle_tree_index].contains("Merkle"), + segmentation_ctx.air_names[merkle_tree_index].contains("Merkle"), "air_name={}", - air_names[merkle_tree_index] + segmentation_ctx.air_names[merkle_tree_index] ); } debug_assert!( - air_names[memory_ctx.adapter_offset].contains("AccessAdapterAir<2>"), + segmentation_ctx.air_names[memory_ctx.adapter_offset].contains("AccessAdapterAir<2>"), "air_name={}", - air_names[memory_ctx.adapter_offset] + segmentation_ctx.air_names[memory_ctx.adapter_offset] ); - let segmentation_ctx = - SegmentationCtx::new(air_names, widths, interactions, config.segmentation_limits); - let mut ctx = Self { trace_heights, is_trace_height_constant, @@ -80,14 +79,11 @@ impl MeteredCtx { segmentation_ctx, suspend_on_segment: false, }; - if !config.continuation_enabled { - // force single segment - ctx.segmentation_ctx.segment_check_insns = u64::MAX; - ctx.segmentation_ctx.instrets_until_check = u64::MAX; - } // Add merkle height contributions for all registers ctx.memory_ctx.add_register_merkle_heights(); + ctx.memory_ctx + .lazy_update_boundary_heights(&mut ctx.trace_heights); ctx } @@ -98,9 +94,8 @@ impl MeteredCtx { self.segmentation_ctx.set_max_trace_height(max_trace_height); let max_check_freq = (max_trace_height / 2) as u64; if max_check_freq < self.segmentation_ctx.segment_check_insns { - self.segmentation_ctx.segment_check_insns = max_check_freq; + self = self.with_segment_check_insns(max_check_freq); } - self.segmentation_ctx.instrets_until_check = self.segmentation_ctx.segment_check_insns; self } @@ -114,6 +109,20 @@ impl MeteredCtx { self } + pub fn with_segment_check_insns(mut self, segment_check_insns: u64) -> Self { + self.segmentation_ctx.segment_check_insns = segment_check_insns; + self.segmentation_ctx.instrets_until_check = segment_check_insns; + + // Update memory context with new segment check instructions + let page_indices_since_checkpoint_cap = + MemoryCtx::::calculate_checkpoint_capacity(segment_check_insns); + + self.memory_ctx.page_indices_since_checkpoint = + vec![0; page_indices_since_checkpoint_cap].into_boxed_slice(); + self.memory_ctx.page_indices_since_checkpoint_len = 0; + self + } + pub fn segments(&self) -> &[Segment] { &self.segmentation_ctx.segments } @@ -122,12 +131,6 @@ impl MeteredCtx { self.segmentation_ctx.segments } - fn reset_segment(&mut self) { - self.memory_ctx.clear(); - // Add merkle height contributions for all registers - self.memory_ctx.add_register_merkle_heights(); - } - #[inline(always)] pub fn check_and_segment(&mut self) -> bool { // We track the segmentation check by instrets_until_check instead of instret in order to @@ -147,8 +150,40 @@ impl MeteredCtx { ); if did_segment { - self.reset_segment(); + // Initialize contexts for new segment + self.segmentation_ctx + .initialize_segment(&mut self.trace_heights, &self.is_trace_height_constant); + self.memory_ctx.initialize_segment(&mut self.trace_heights); + + // Check if the new segment is within limits + if self.segmentation_ctx.should_segment( + self.segmentation_ctx.instret, + &self.trace_heights, + &self.is_trace_height_constant, + ) { + let trace_heights_str = self + .trace_heights + .iter() + .zip(self.segmentation_ctx.air_names.iter()) + .filter(|(&height, _)| height > 0) + .map(|(&height, name)| format!(" {name} = {height}")) + .collect::>() + .join("\n"); + tracing::warn!( + "Segment initialized with heights that exceed limits\n\ + instret={}\n\ + trace_heights=[\n{}\n]", + self.segmentation_ctx.instret, + trace_heights_str + ); + } } + + // Update checkpoints + self.segmentation_ctx + .update_checkpoint(self.segmentation_ctx.instret, &self.trace_heights); + self.memory_ctx.update_checkpoint(); + did_segment } diff --git a/crates/vm/src/arch/execution_mode/metered/memory_ctx.rs b/crates/vm/src/arch/execution_mode/metered/memory_ctx.rs index 3429177d11..68e2c6a665 100644 --- a/crates/vm/src/arch/execution_mode/metered/memory_ctx.rs +++ b/crates/vm/src/arch/execution_mode/metered/memory_ctx.rs @@ -3,6 +3,9 @@ use openvm_instructions::riscv::{RV32_NUM_REGISTERS, RV32_REGISTER_AS, RV32_REGI use crate::{arch::SystemConfig, system::memory::dimensions::MemoryDimensions}; +/// Upper bound on number of memory pages accessed per instruction. Used for buffer allocation. +pub const MAX_MEM_PAGE_OPS_PER_INSN: usize = 1 << 16; + #[derive(Clone, Debug)] pub struct BitSet { words: Box<[u64]>, @@ -99,7 +102,6 @@ impl BitSet { #[derive(Clone, Debug)] pub struct MemoryCtx { - pub page_indices: BitSet, memory_dimensions: MemoryDimensions, min_block_size_bits: Vec, pub boundary_idx: usize, @@ -108,22 +110,26 @@ pub struct MemoryCtx { continuations_enabled: bool, chunk: u32, chunk_bits: u32, - pub page_access_count: usize, - // Note: 32 is the maximum access adapter size. + pub page_indices: BitSet, pub addr_space_access_count: RVec, + pub page_indices_since_checkpoint: Box<[u32]>, + pub page_indices_since_checkpoint_len: usize, } impl MemoryCtx { - pub fn new(config: &SystemConfig) -> Self { + pub fn new(config: &SystemConfig, segment_check_insns: u64) -> Self { let chunk = config.initial_block_size() as u32; let chunk_bits = chunk.ilog2(); let memory_dimensions = config.memory_config.memory_dimensions(); let merkle_height = memory_dimensions.overall_height(); + let bitset_size = 1 << (merkle_height.saturating_sub(PAGE_BITS)); + let addr_space_size = (1 << memory_dimensions.addr_space_height) + 1; + let page_indices_since_checkpoint_cap = + Self::calculate_checkpoint_capacity(segment_check_insns); + Self { - // Address height already considers `chunk_bits`. - page_indices: BitSet::new(1 << (merkle_height.saturating_sub(PAGE_BITS))), min_block_size_bits: config.memory_config.min_block_size_bits(), boundary_idx: config.memory_boundary_air_id(), merkle_tree_index: config.memory_merkle_air_id(), @@ -132,14 +138,17 @@ impl MemoryCtx { chunk_bits, memory_dimensions, continuations_enabled: config.continuation_enabled, - page_access_count: 0, - addr_space_access_count: vec![0; (1 << memory_dimensions.addr_space_height) + 1].into(), + page_indices: BitSet::new(bitset_size), + addr_space_access_count: vec![0; addr_space_size].into(), + page_indices_since_checkpoint: vec![0; page_indices_since_checkpoint_cap] + .into_boxed_slice(), + page_indices_since_checkpoint_len: 0, } } #[inline(always)] - pub fn clear(&mut self) { - self.page_indices.clear(); + pub(super) fn calculate_checkpoint_capacity(segment_check_insns: u64) -> usize { + segment_check_insns as usize * MAX_MEM_PAGE_OPS_PER_INSN } #[inline(always)] @@ -177,10 +186,23 @@ impl MemoryCtx { let end_block_id = start_block_id + num_blocks; let start_page_id = start_block_id >> PAGE_BITS; let end_page_id = ((end_block_id - 1) >> PAGE_BITS) + 1; + assert!( + self.page_indices_since_checkpoint_len + (end_page_id - start_page_id) as usize + <= self.page_indices_since_checkpoint.len(), + "more than {MAX_MEM_PAGE_OPS_PER_INSN} memory pages accessed in a single instruction" + ); for page_id in start_page_id..end_page_id { + // Append page_id to page_indices_since_checkpoint + let len = self.page_indices_since_checkpoint_len; + debug_assert!(len < self.page_indices_since_checkpoint.len()); + // SAFETY: len is within bounds, and we extend length by 1 after writing. + unsafe { + *self.page_indices_since_checkpoint.as_mut_ptr().add(len) = page_id; + } + self.page_indices_since_checkpoint_len = len + 1; + if self.page_indices.insert(page_id as usize) { - self.page_access_count += 1; // SAFETY: address_space passed is usually a hardcoded constant or derived from an // Instruction where it is bounds checked before passing unsafe { @@ -235,13 +257,69 @@ impl MemoryCtx { } } - /// Resolve all lazy updates of each memory access for memory adapters/poseidon2/merkle chip. + /// Initialize state for a new segment #[inline(always)] - pub(crate) fn lazy_update_boundary_heights(&mut self, trace_heights: &mut [u32]) { - debug_assert!(self.boundary_idx < trace_heights.len()); + pub(crate) fn initialize_segment(&mut self, trace_heights: &mut [u32]) { + // Clear page indices for the new segment + self.page_indices.clear(); + + // Reset trace heights for memory chips as 0 + // SAFETY: boundary_idx is a compile time constant within bounds + unsafe { + *trace_heights.get_unchecked_mut(self.boundary_idx) = 0; + } + if let Some(merkle_tree_idx) = self.merkle_tree_index { + // SAFETY: merkle_tree_idx is guaranteed to be in bounds + unsafe { + *trace_heights.get_unchecked_mut(merkle_tree_idx) = 0; + } + let poseidon2_idx = trace_heights.len() - 2; + // SAFETY: poseidon2_idx is trace_heights.len() - 2, guaranteed to be in bounds + unsafe { + *trace_heights.get_unchecked_mut(poseidon2_idx) = 0; + } + } + + // Apply height updates for all pages accessed since last checkpoint, and + // initialize page_indices for the new segment. + let mut addr_space_access_count = vec![0; self.addr_space_access_count.len()]; + let pages_len = self.page_indices_since_checkpoint_len; + for i in 0..pages_len { + // SAFETY: i is within 0..pages_len and pages_len is the slice length. + let page_id = unsafe { *self.page_indices_since_checkpoint.get_unchecked(i) } as usize; + if self.page_indices.insert(page_id) { + let (addr_space, _) = self + .memory_dimensions + .index_to_label((page_id as u64) << PAGE_BITS); + let addr_space_idx = addr_space as usize; + debug_assert!(addr_space_idx < addr_space_access_count.len()); + // SAFETY: addr_space_idx is bounds checked in debug and derived from a valid page + // id. + unsafe { + *addr_space_access_count.get_unchecked_mut(addr_space_idx) += 1; + } + } + } + self.apply_height_updates(trace_heights, &addr_space_access_count); + + // Add merkle height contributions for all registers + self.add_register_merkle_heights(); + self.lazy_update_boundary_heights(trace_heights); + } + + /// Updates the checkpoint with current safe state + #[inline(always)] + pub(crate) fn update_checkpoint(&mut self) { + self.page_indices_since_checkpoint_len = 0; + } + + /// Apply height updates given page counts + #[inline(always)] + fn apply_height_updates(&self, trace_heights: &mut [u32], addr_space_access_count: &[usize]) { + let page_access_count: usize = addr_space_access_count.iter().sum(); // On page fault, assume we add all leaves in a page - let leaves = (self.page_access_count << PAGE_BITS) as u32; + let leaves = (page_access_count << PAGE_BITS) as u32; // SAFETY: boundary_idx is a compile time constant within bounds unsafe { *trace_heights.get_unchecked_mut(self.boundary_idx) += leaves; @@ -261,15 +339,16 @@ impl MemoryCtx { let nodes = (((1 << PAGE_BITS) - 1) + (merkle_height - PAGE_BITS)) as u32; // SAFETY: merkle_tree_idx is guaranteed to be in bounds unsafe { - *trace_heights.get_unchecked_mut(poseidon2_idx) += nodes * 2; - *trace_heights.get_unchecked_mut(merkle_tree_idx) += nodes * 2; + *trace_heights.get_unchecked_mut(poseidon2_idx) += + nodes * page_access_count as u32 * 2; + *trace_heights.get_unchecked_mut(merkle_tree_idx) += + nodes * page_access_count as u32 * 2; } } - self.page_access_count = 0; - for address_space in 0..self.addr_space_access_count.len() { + for address_space in 0..addr_space_access_count.len() { // SAFETY: address_space is from 0 to len(), guaranteed to be in bounds - let x = unsafe { *self.addr_space_access_count.get_unchecked(address_space) }; + let x = unsafe { *addr_space_access_count.get_unchecked(address_space) }; if x > 0 { // Initial **and** final handling of touched pages requires send (resp. receive) in // chunk-sized units for the merkle chip @@ -281,15 +360,23 @@ impl MemoryCtx { self.chunk_bits, (x << (PAGE_BITS + 1)) as u32, ); - // SAFETY: address_space is from 0 to len(), guaranteed to be in bounds - unsafe { - *self - .addr_space_access_count - .get_unchecked_mut(address_space) = 0; - } } } } + + /// Resolve all lazy updates of each memory access for memory adapters/poseidon2/merkle chip. + #[inline(always)] + pub(crate) fn lazy_update_boundary_heights(&mut self, trace_heights: &mut [u32]) { + self.apply_height_updates(trace_heights, &self.addr_space_access_count); + // SAFETY: Resetting array elements to 0 is always safe + unsafe { + std::ptr::write_bytes( + self.addr_space_access_count.as_mut_ptr(), + 0, + self.addr_space_access_count.len(), + ); + } + } } #[cfg(test)] diff --git a/crates/vm/src/arch/execution_mode/metered/segment_ctx.rs b/crates/vm/src/arch/execution_mode/metered/segment_ctx.rs index db1aecfa31..0b34ad3f99 100644 --- a/crates/vm/src/arch/execution_mode/metered/segment_ctx.rs +++ b/crates/vm/src/arch/execution_mode/metered/segment_ctx.rs @@ -68,8 +68,7 @@ pub struct SegmentationCtx { pub(crate) segmentation_limits: SegmentationLimits, pub instret: u64, pub instrets_until_check: u64, - #[getset(set_with = "pub")] - pub segment_check_insns: u64, + pub(super) segment_check_insns: u64, /// Checkpoint of trace heights at last known state where all thresholds satisfied pub(crate) checkpoint_trace_heights: Vec, /// Instruction count at the checkpoint @@ -176,7 +175,7 @@ impl SegmentationCtx { } #[inline(always)] - fn should_segment( + pub(crate) fn should_segment( &self, instret: u64, trace_heights: &[u32], @@ -257,20 +256,13 @@ impl SegmentationCtx { let should_seg = self.should_segment(instret, trace_heights, is_trace_height_constant); if should_seg { - self.create_segment_from_checkpoint(instret, trace_heights, is_trace_height_constant); - } else { - self.update_checkpoint(instret, trace_heights); + self.create_segment_from_checkpoint(instret, trace_heights); } should_seg } #[inline(always)] - fn create_segment_from_checkpoint( - &mut self, - instret: u64, - trace_heights: &mut [u32], - is_trace_height_constant: &[bool], - ) { + fn create_segment_from_checkpoint(&mut self, instret: u64, trace_heights: &mut [u32]) { let instret_start = self .segments .last() @@ -296,14 +288,26 @@ impl SegmentationCtx { (instret, trace_heights.to_vec()) }; - // Reset current trace heights and checkpoint - self.reset_trace_heights(trace_heights, &segment_heights, is_trace_height_constant); - self.checkpoint_instret = 0; - let num_insns = segment_instret - instret_start; self.create_segment::(instret_start, num_insns, segment_heights); } + /// Initialize state for a new segment + #[inline(always)] + pub(crate) fn initialize_segment( + &mut self, + trace_heights: &mut [u32], + is_trace_height_constant: &[bool], + ) { + // Reset trace heights by subtracting the last segment's heights + let last_segment = self.segments.last().unwrap(); + self.reset_trace_heights( + trace_heights, + &last_segment.trace_heights, + is_trace_height_constant, + ); + } + /// Resets trace heights by subtracting segment heights #[inline(always)] fn reset_trace_heights( @@ -325,7 +329,7 @@ impl SegmentationCtx { /// Updates the checkpoint with current safe state #[inline(always)] - fn update_checkpoint(&mut self, instret: u64, trace_heights: &[u32]) { + pub(crate) fn update_checkpoint(&mut self, instret: u64, trace_heights: &[u32]) { self.checkpoint_trace_heights.copy_from_slice(trace_heights); self.checkpoint_instret = instret; } diff --git a/extensions/native/circuit/tests/integration_test.rs b/extensions/native/circuit/tests/integration_test.rs index d0d56f6d4c..6cb64077e8 100644 --- a/extensions/native/circuit/tests/integration_test.rs +++ b/extensions/native/circuit/tests/integration_test.rs @@ -11,7 +11,6 @@ use openvm_circuit::system::cuda::extensions::SystemGpuBuilder as SystemBuilder; use openvm_circuit::{arch::RowMajorMatrixArena, system::SystemCpuBuilder as SystemBuilder}; use openvm_circuit::{ arch::{ - execution_mode::metered::segment_ctx::{SegmentationLimits, DEFAULT_SEGMENT_CHECK_INSNS}, hasher::{poseidon2::vm_poseidon2_hasher, Hasher}, verify_segments, verify_single, AirInventory, ContinuationVmProver, PreflightExecutionOutput, SingleSegmentVmProver, VirtualMachine, VmCircuitConfig, @@ -948,41 +947,6 @@ fn test_vm_execute_native_chips() { .expect("Failed to execute"); } -// This test ensures that metered execution never segments when continuations is disabled -#[test] -fn test_single_segment_executor_no_segmentation() { - setup_tracing(); - - let mut config = test_native_config(); - config - .system - .set_segmentation_limits(SegmentationLimits::default().with_max_trace_height(1)); - - let engine = TestEngine::new(FriParameters::new_for_testing(3)); - let (vm, _) = - VirtualMachine::new_with_keygen(engine, NativeBuilder::default(), config).unwrap(); - let instructions: Vec<_> = (0..2 * DEFAULT_SEGMENT_CHECK_INSNS) - .map(|_| Instruction::large_from_isize(ADD.global_opcode(), 0, 0, 1, 4, 0, 0, 0)) - .chain(std::iter::once(Instruction::from_isize( - TERMINATE.global_opcode(), - 0, - 0, - 0, - 0, - 0, - ))) - .collect(); - - let exe = VmExe::new(Program::from_instructions(&instructions)); - let executor_idx_to_air_idx = vm.executor_idx_to_air_idx(); - let metered_ctx = vm.build_metered_ctx(&exe); - vm.executor() - .metered_instance(&exe, &executor_idx_to_air_idx) - .unwrap() - .execute_metered(vec![], metered_ctx) - .unwrap(); -} - #[test] fn test_vm_execute_metered_cost_native_chips() { type F = BabyBear; diff --git a/extensions/rv32im/circuit/src/common/mod.rs b/extensions/rv32im/circuit/src/common/mod.rs index 0a58b7310b..827eb6b76f 100644 --- a/extensions/rv32im/circuit/src/common/mod.rs +++ b/extensions/rv32im/circuit/src/common/mod.rs @@ -212,8 +212,15 @@ mod aot { // let end_page_id = ((end_block_id - 1) >> PAGE_BITS) + 1; // for page_id in start_page_id..end_page_id { + // // Append page_id to page_indices_since_checkpoint + // let len = self.page_indices_since_checkpoint_len; + // // SAFETY: len is within bounds, and we extend length by 1 after writing. + // unsafe { + // *self.page_indices_since_checkpoint.as_mut_ptr().add(len) = page_id; + // } + // self.page_indices_since_checkpoint_len = len + 1; + // // if self.page_indices.insert(page_id as usize) { - // self.page_access_count += 1; // // SAFETY: address_space passed is usually a hardcoded constant or derived // from an // Instruction where it is bounds checked before passing // unsafe { @@ -266,11 +273,33 @@ mod aot { + offset_of!(MeteredCtx, memory_ctx); let page_indices_ptr_offset = memory_ctx_offset + offset_of!(MemoryCtx, page_indices); - let page_access_count_offset = - memory_ctx_offset + offset_of!(MemoryCtx, page_access_count); let addr_space_access_count_ptr_offset = memory_ctx_offset + offset_of!(MemoryCtx, addr_space_access_count); + let page_indices_since_checkpoint_ptr_offset = memory_ctx_offset + + offset_of!(MemoryCtx, page_indices_since_checkpoint); + let page_indices_since_checkpoint_len_offset = memory_ctx_offset + + offset_of!( + MemoryCtx, + page_indices_since_checkpoint_len + ); let inserted_label = format!(".asm_execute_pc_{pc}_inserted"); + + // Append page_id to page_indices_since_checkpoint + asm_str += &format!( + " mov {reg1}, [{REG_EXEC_STATE_PTR} + {page_indices_since_checkpoint_len_offset}]\n" + ); + asm_str += &format!( + " mov {reg2}, [{REG_EXEC_STATE_PTR} + {page_indices_since_checkpoint_ptr_offset}]\n" + ); + let ptr_reg_32 = convert_x86_reg(ptr_reg, Width::W32).ok_or_else(|| { + AotError::Other(format!("unsupported ptr_reg for 32-bit store: {ptr_reg}")) + })?; + asm_str += &format!(" mov dword ptr [{reg2} + {reg1} * 4], {ptr_reg_32}\n"); + asm_str += &format!(" add {reg1}, 1\n"); + asm_str += &format!( + " mov [{REG_EXEC_STATE_PTR} + {page_indices_since_checkpoint_len_offset}], {reg1}\n" + ); + // The next section is the implementation of `BitSet::insert` in ASM. // pub fn insert(&mut self, index: usize) -> bool { // let word_index = index >> 6; @@ -307,11 +336,6 @@ mod aot { // `*word += mask` asm_str += &format!(" add {ptr_reg}, {reg2}\n"); asm_str += &format!(" mov [{reg1}], {ptr_reg}\n"); - // reg1 = &self.page_access_count` - asm_str += - &format!(" lea {reg1}, [{REG_EXEC_STATE_PTR} + {page_access_count_offset}]\n"); - // self.page_access_count += 1; - asm_str += &format!(" add dword ptr [{reg1}], 1\n"); // reg1 = &addr_space_access_count.as_ptr() asm_str += &format!( " lea {reg1}, [{REG_EXEC_STATE_PTR} + {addr_space_access_count_ptr_offset}]\n"