Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 59 additions & 24 deletions crates/vm/src/arch/execution_mode/metered/ctx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,45 +49,41 @@ impl<const PAGE_BITS: usize> MeteredCtx<PAGE_BITS> {
})
.unzip();

let memory_ctx = MemoryCtx::new(config);
let segmentation_ctx =
SegmentationCtx::new(air_names, widths, interactions, config.segmentation_limits);
let memory_ctx = MemoryCtx::new(config, segmentation_ctx.segment_check_insns);

// Assert that the indices are correct
debug_assert!(
air_names[memory_ctx.boundary_idx].contains("Boundary"),
segmentation_ctx.air_names[memory_ctx.boundary_idx].contains("Boundary"),
"air_name={}",
air_names[memory_ctx.boundary_idx]
segmentation_ctx.air_names[memory_ctx.boundary_idx]
);
if let Some(merkle_tree_index) = memory_ctx.merkle_tree_index {
debug_assert!(
air_names[merkle_tree_index].contains("Merkle"),
segmentation_ctx.air_names[merkle_tree_index].contains("Merkle"),
"air_name={}",
air_names[merkle_tree_index]
segmentation_ctx.air_names[merkle_tree_index]
);
}
debug_assert!(
air_names[memory_ctx.adapter_offset].contains("AccessAdapterAir<2>"),
segmentation_ctx.air_names[memory_ctx.adapter_offset].contains("AccessAdapterAir<2>"),
"air_name={}",
air_names[memory_ctx.adapter_offset]
segmentation_ctx.air_names[memory_ctx.adapter_offset]
);

let segmentation_ctx =
SegmentationCtx::new(air_names, widths, interactions, config.segmentation_limits);

let mut ctx = Self {
trace_heights,
is_trace_height_constant,
memory_ctx,
segmentation_ctx,
suspend_on_segment: false,
};
if !config.continuation_enabled {
// force single segment
ctx.segmentation_ctx.segment_check_insns = u64::MAX;
ctx.segmentation_ctx.instrets_until_check = u64::MAX;
}

// Add merkle height contributions for all registers
ctx.memory_ctx.add_register_merkle_heights();
ctx.memory_ctx
.lazy_update_boundary_heights(&mut ctx.trace_heights);

ctx
}
Expand All @@ -98,9 +94,8 @@ impl<const PAGE_BITS: usize> MeteredCtx<PAGE_BITS> {
self.segmentation_ctx.set_max_trace_height(max_trace_height);
let max_check_freq = (max_trace_height / 2) as u64;
if max_check_freq < self.segmentation_ctx.segment_check_insns {
self.segmentation_ctx.segment_check_insns = max_check_freq;
self = self.with_segment_check_insns(max_check_freq);
}
self.segmentation_ctx.instrets_until_check = self.segmentation_ctx.segment_check_insns;
self
}

Expand All @@ -114,6 +109,20 @@ impl<const PAGE_BITS: usize> MeteredCtx<PAGE_BITS> {
self
}

pub fn with_segment_check_insns(mut self, segment_check_insns: u64) -> Self {
self.segmentation_ctx.segment_check_insns = segment_check_insns;
self.segmentation_ctx.instrets_until_check = segment_check_insns;

// Update memory context with new segment check instructions
let page_indices_since_checkpoint_cap =
MemoryCtx::<PAGE_BITS>::calculate_checkpoint_capacity(segment_check_insns);

self.memory_ctx.page_indices_since_checkpoint =
vec![0; page_indices_since_checkpoint_cap].into_boxed_slice();
self.memory_ctx.page_indices_since_checkpoint_len = 0;
self
}

pub fn segments(&self) -> &[Segment] {
&self.segmentation_ctx.segments
}
Expand All @@ -122,12 +131,6 @@ impl<const PAGE_BITS: usize> MeteredCtx<PAGE_BITS> {
self.segmentation_ctx.segments
}

fn reset_segment(&mut self) {
self.memory_ctx.clear();
// Add merkle height contributions for all registers
self.memory_ctx.add_register_merkle_heights();
}

#[inline(always)]
pub fn check_and_segment(&mut self) -> bool {
// We track the segmentation check by instrets_until_check instead of instret in order to
Expand All @@ -147,8 +150,40 @@ impl<const PAGE_BITS: usize> MeteredCtx<PAGE_BITS> {
);

if did_segment {
self.reset_segment();
// Initialize contexts for new segment
self.segmentation_ctx
.initialize_segment(&mut self.trace_heights, &self.is_trace_height_constant);
self.memory_ctx.initialize_segment(&mut self.trace_heights);

// Check if the new segment is within limits
if self.segmentation_ctx.should_segment(
self.segmentation_ctx.instret,
&self.trace_heights,
&self.is_trace_height_constant,
) {
let trace_heights_str = self
.trace_heights
.iter()
.zip(self.segmentation_ctx.air_names.iter())
.filter(|(&height, _)| height > 0)
.map(|(&height, name)| format!(" {name} = {height}"))
.collect::<Vec<_>>()
.join("\n");
tracing::warn!(
"Segment initialized with heights that exceed limits\n\
instret={}\n\
trace_heights=[\n{}\n]",
self.segmentation_ctx.instret,
trace_heights_str
);
}
}

// Update checkpoints
self.segmentation_ctx
.update_checkpoint(self.segmentation_ctx.instret, &self.trace_heights);
self.memory_ctx.update_checkpoint();

did_segment
}

Expand Down
139 changes: 113 additions & 26 deletions crates/vm/src/arch/execution_mode/metered/memory_ctx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ use openvm_instructions::riscv::{RV32_NUM_REGISTERS, RV32_REGISTER_AS, RV32_REGI

use crate::{arch::SystemConfig, system::memory::dimensions::MemoryDimensions};

/// Upper bound on number of memory pages accessed per instruction. Used for buffer allocation.
pub const MAX_MEM_PAGE_OPS_PER_INSN: usize = 1 << 16;

#[derive(Clone, Debug)]
pub struct BitSet {
words: Box<[u64]>,
Expand Down Expand Up @@ -99,7 +102,6 @@ impl BitSet {

#[derive(Clone, Debug)]
pub struct MemoryCtx<const PAGE_BITS: usize> {
pub page_indices: BitSet,
memory_dimensions: MemoryDimensions,
min_block_size_bits: Vec<u8>,
pub boundary_idx: usize,
Expand All @@ -108,22 +110,26 @@ pub struct MemoryCtx<const PAGE_BITS: usize> {
continuations_enabled: bool,
chunk: u32,
chunk_bits: u32,
pub page_access_count: usize,
// Note: 32 is the maximum access adapter size.
pub page_indices: BitSet,
pub addr_space_access_count: RVec<usize>,
pub page_indices_since_checkpoint: Box<[u32]>,
pub page_indices_since_checkpoint_len: usize,
}

impl<const PAGE_BITS: usize> MemoryCtx<PAGE_BITS> {
pub fn new(config: &SystemConfig) -> Self {
pub fn new(config: &SystemConfig, segment_check_insns: u64) -> Self {
let chunk = config.initial_block_size() as u32;
let chunk_bits = chunk.ilog2();

let memory_dimensions = config.memory_config.memory_dimensions();
let merkle_height = memory_dimensions.overall_height();

let bitset_size = 1 << (merkle_height.saturating_sub(PAGE_BITS));
let addr_space_size = (1 << memory_dimensions.addr_space_height) + 1;
let page_indices_since_checkpoint_cap =
Self::calculate_checkpoint_capacity(segment_check_insns);

Self {
// Address height already considers `chunk_bits`.
page_indices: BitSet::new(1 << (merkle_height.saturating_sub(PAGE_BITS))),
min_block_size_bits: config.memory_config.min_block_size_bits(),
boundary_idx: config.memory_boundary_air_id(),
merkle_tree_index: config.memory_merkle_air_id(),
Expand All @@ -132,14 +138,17 @@ impl<const PAGE_BITS: usize> MemoryCtx<PAGE_BITS> {
chunk_bits,
memory_dimensions,
continuations_enabled: config.continuation_enabled,
page_access_count: 0,
addr_space_access_count: vec![0; (1 << memory_dimensions.addr_space_height) + 1].into(),
page_indices: BitSet::new(bitset_size),
addr_space_access_count: vec![0; addr_space_size].into(),
page_indices_since_checkpoint: vec![0; page_indices_since_checkpoint_cap]
.into_boxed_slice(),
page_indices_since_checkpoint_len: 0,
}
}

#[inline(always)]
pub fn clear(&mut self) {
self.page_indices.clear();
pub(super) fn calculate_checkpoint_capacity(segment_check_insns: u64) -> usize {
segment_check_insns as usize * MAX_MEM_PAGE_OPS_PER_INSN
}

#[inline(always)]
Expand Down Expand Up @@ -177,10 +186,23 @@ impl<const PAGE_BITS: usize> MemoryCtx<PAGE_BITS> {
let end_block_id = start_block_id + num_blocks;
let start_page_id = start_block_id >> PAGE_BITS;
let end_page_id = ((end_block_id - 1) >> PAGE_BITS) + 1;
assert!(
self.page_indices_since_checkpoint_len + (end_page_id - start_page_id) as usize
<= self.page_indices_since_checkpoint.len(),
"more than {MAX_MEM_PAGE_OPS_PER_INSN} memory pages accessed in a single instruction"
);

for page_id in start_page_id..end_page_id {
// Append page_id to page_indices_since_checkpoint
let len = self.page_indices_since_checkpoint_len;
debug_assert!(len < self.page_indices_since_checkpoint.len());
// SAFETY: len is within bounds, and we extend length by 1 after writing.
unsafe {
*self.page_indices_since_checkpoint.as_mut_ptr().add(len) = page_id;
}
self.page_indices_since_checkpoint_len = len + 1;

if self.page_indices.insert(page_id as usize) {
self.page_access_count += 1;
// SAFETY: address_space passed is usually a hardcoded constant or derived from an
// Instruction where it is bounds checked before passing
unsafe {
Expand Down Expand Up @@ -235,13 +257,69 @@ impl<const PAGE_BITS: usize> MemoryCtx<PAGE_BITS> {
}
}

/// Resolve all lazy updates of each memory access for memory adapters/poseidon2/merkle chip.
/// Initialize state for a new segment
#[inline(always)]
pub(crate) fn lazy_update_boundary_heights(&mut self, trace_heights: &mut [u32]) {
debug_assert!(self.boundary_idx < trace_heights.len());
pub(crate) fn initialize_segment(&mut self, trace_heights: &mut [u32]) {
// Clear page indices for the new segment
self.page_indices.clear();

// Reset trace heights for memory chips as 0
// SAFETY: boundary_idx is a compile time constant within bounds
unsafe {
*trace_heights.get_unchecked_mut(self.boundary_idx) = 0;
}
if let Some(merkle_tree_idx) = self.merkle_tree_index {
// SAFETY: merkle_tree_idx is guaranteed to be in bounds
unsafe {
*trace_heights.get_unchecked_mut(merkle_tree_idx) = 0;
}
let poseidon2_idx = trace_heights.len() - 2;
// SAFETY: poseidon2_idx is trace_heights.len() - 2, guaranteed to be in bounds
unsafe {
*trace_heights.get_unchecked_mut(poseidon2_idx) = 0;
}
}
Comment on lines +266 to +281
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't reset access adapter heights here because access adapters are shared between memory and normal traces, so it wouldn't be correct to reset them to 0. Ideally, we would remove only the memory chips' contribution to the access adapter heights.


// Apply height updates for all pages accessed since last checkpoint, and
// initialize page_indices for the new segment.
let mut addr_space_access_count = vec![0; self.addr_space_access_count.len()];
let pages_len = self.page_indices_since_checkpoint_len;
for i in 0..pages_len {
// SAFETY: i is within 0..pages_len and pages_len is the slice length.
let page_id = unsafe { *self.page_indices_since_checkpoint.get_unchecked(i) } as usize;
if self.page_indices.insert(page_id) {
let (addr_space, _) = self
.memory_dimensions
.index_to_label((page_id as u64) << PAGE_BITS);
let addr_space_idx = addr_space as usize;
debug_assert!(addr_space_idx < addr_space_access_count.len());
// SAFETY: addr_space_idx is bounds checked in debug and derived from a valid page
// id.
unsafe {
*addr_space_access_count.get_unchecked_mut(addr_space_idx) += 1;
}
}
}
self.apply_height_updates(trace_heights, &addr_space_access_count);

// Add merkle height contributions for all registers
self.add_register_merkle_heights();
self.lazy_update_boundary_heights(trace_heights);
}

/// Updates the checkpoint with current safe state
#[inline(always)]
pub(crate) fn update_checkpoint(&mut self) {
self.page_indices_since_checkpoint_len = 0;
}

/// Apply height updates given page counts
#[inline(always)]
fn apply_height_updates(&self, trace_heights: &mut [u32], addr_space_access_count: &[usize]) {
let page_access_count: usize = addr_space_access_count.iter().sum();

// On page fault, assume we add all leaves in a page
let leaves = (self.page_access_count << PAGE_BITS) as u32;
let leaves = (page_access_count << PAGE_BITS) as u32;
// SAFETY: boundary_idx is a compile time constant within bounds
unsafe {
*trace_heights.get_unchecked_mut(self.boundary_idx) += leaves;
Expand All @@ -261,15 +339,16 @@ impl<const PAGE_BITS: usize> MemoryCtx<PAGE_BITS> {
let nodes = (((1 << PAGE_BITS) - 1) + (merkle_height - PAGE_BITS)) as u32;
// SAFETY: merkle_tree_idx is guaranteed to be in bounds
unsafe {
*trace_heights.get_unchecked_mut(poseidon2_idx) += nodes * 2;
*trace_heights.get_unchecked_mut(merkle_tree_idx) += nodes * 2;
*trace_heights.get_unchecked_mut(poseidon2_idx) +=
nodes * page_access_count as u32 * 2;
*trace_heights.get_unchecked_mut(merkle_tree_idx) +=
nodes * page_access_count as u32 * 2;
Comment on lines -264 to +345
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this was a bug earlier: the update should have been `nodes * page_access_count * 2` instead of `nodes * 2`.

}
}
self.page_access_count = 0;

for address_space in 0..self.addr_space_access_count.len() {
for address_space in 0..addr_space_access_count.len() {
// SAFETY: address_space is from 0 to len(), guaranteed to be in bounds
let x = unsafe { *self.addr_space_access_count.get_unchecked(address_space) };
let x = unsafe { *addr_space_access_count.get_unchecked(address_space) };
if x > 0 {
// Initial **and** final handling of touched pages requires send (resp. receive) in
// chunk-sized units for the merkle chip
Expand All @@ -281,15 +360,23 @@ impl<const PAGE_BITS: usize> MemoryCtx<PAGE_BITS> {
self.chunk_bits,
(x << (PAGE_BITS + 1)) as u32,
);
// SAFETY: address_space is from 0 to len(), guaranteed to be in bounds
unsafe {
*self
.addr_space_access_count
.get_unchecked_mut(address_space) = 0;
}
Comment on lines -284 to -289
Copy link
Collaborator Author

@shuklaayush shuklaayush Dec 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved outside this function; see below.

}
}
}

/// Resolve all lazy updates of each memory access for memory adapters/poseidon2/merkle chip.
#[inline(always)]
pub(crate) fn lazy_update_boundary_heights(&mut self, trace_heights: &mut [u32]) {
self.apply_height_updates(trace_heights, &self.addr_space_access_count);
// SAFETY: Resetting array elements to 0 is always safe
unsafe {
std::ptr::write_bytes(
self.addr_space_access_count.as_mut_ptr(),
0,
self.addr_space_access_count.len(),
);
}
}
}

#[cfg(test)]
Expand Down
Loading