Commit e5e4dbc

perf(new-execution): use unchecked ops in e2 (#1902)

- seems to speed up e2 by 10-15% [Benchmark comparison](https://github.com/axiom-crypto/openvm-reth-benchmark/actions/runs/16527317747)

1 parent 0f1213d
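
Every hunk below follows the same pattern: keep the invariant as a `debug_assert!` so debug and test builds still catch violations, then perform the slice access with an unchecked operation so release builds skip the bounds check. A minimal standalone sketch of the pattern (illustrative only, not the actual `MeteredCtx` code):

```rust
/// Sketch of the checked-in-debug, unchecked-in-release pattern used below.
/// Callers must guarantee `idx < heights.len()`.
#[inline(always)]
fn bump_height(heights: &mut [u32], idx: usize, delta: u32) {
    debug_assert!(idx < heights.len(), "idx out of bounds");
    // SAFETY: the caller upholds the bounds invariant, so the unchecked
    // access never reads or writes outside the slice.
    unsafe {
        let slot = heights.get_unchecked_mut(idx);
        *slot = slot.wrapping_add(delta);
    }
}

fn main() {
    let mut heights = vec![0u32; 4];
    bump_height(&mut heights, 2, 5);
    assert_eq!(heights[2], 5);
}
```

In release builds the `debug_assert!` compiles away and `get_unchecked_mut` elides the bounds-check branch; in debug builds a bad index still panics with a clear message.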

File tree: 3 files changed, +167 -45 lines

crates/vm/src/arch/execution_mode/metered/ctx.rs

Lines changed: 27 additions & 8 deletions
```diff
@@ -1,3 +1,5 @@
+use std::num::NonZero;
+
 use getset::WithSetters;
 use openvm_instructions::riscv::{
     RV32_IMM_AS, RV32_NUM_REGISTERS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS,
@@ -110,6 +112,7 @@ impl<const PAGE_BITS: usize> MeteredCtx<PAGE_BITS> {
         ctx
     }
 
+    #[inline(always)]
     fn add_register_merkle_heights(&mut self) {
         if self.continuations_enabled {
             self.memory_ctx.update_boundary_merkle_heights(
@@ -159,15 +162,19 @@ impl<const PAGE_BITS: usize> MeteredCtx<PAGE_BITS> {
         self.add_register_merkle_heights();
     }
 
+    #[inline(always)]
     pub fn check_and_segment(&mut self, instret: u64) {
-        // Avoid checking segment too often.
-        if instret
-            < self
-                .instret_last_segment_check
-                .saturating_add(self.segment_check_insns)
-        {
+        let threshold = self
+            .instret_last_segment_check
+            .wrapping_add(self.segment_check_insns);
+        debug_assert!(
+            threshold >= self.instret_last_segment_check,
+            "overflow in segment check threshold calculation"
+        );
+        if instret < threshold {
             return;
         }
+
         self.memory_ctx
             .lazy_update_boundary_heights(&mut self.trace_heights);
         let did_segment = self.segmentation_ctx.check_and_segment(
@@ -205,14 +212,16 @@ impl<const PAGE_BITS: usize> E1ExecutionCtx for MeteredCtx<PAGE_BITS> {
             address_space != RV32_IMM_AS,
             "address space must not be immediate"
         );
+        debug_assert!(size > 0, "size must be greater than 0, got {}", size);
         debug_assert!(
             size.is_power_of_two(),
             "size must be a power of 2, got {}",
             size
         );
 
         // Handle access adapter updates
-        let size_bits = size.ilog2();
+        // SAFETY: size passed is always a non-zero power of 2
+        let size_bits = unsafe { NonZero::new_unchecked(size).ilog2() };
         self.memory_ctx
             .update_adapter_heights(&mut self.trace_heights, address_space, size_bits);
 
@@ -247,6 +256,16 @@ impl<const PAGE_BITS: usize> E1ExecutionCtx for MeteredCtx<PAGE_BITS> {
 impl<const PAGE_BITS: usize> E2ExecutionCtx for MeteredCtx<PAGE_BITS> {
     #[inline(always)]
     fn on_height_change(&mut self, chip_idx: usize, height_delta: u32) {
-        self.trace_heights[chip_idx] += height_delta;
+        debug_assert!(
+            chip_idx < self.trace_heights.len(),
+            "chip_idx out of bounds"
+        );
+        // SAFETY: chip_idx is created in executor_idx_to_air_idx and is always within bounds
+        unsafe {
+            *self.trace_heights.get_unchecked_mut(chip_idx) = self
+                .trace_heights
+                .get_unchecked(chip_idx)
+                .wrapping_add(height_delta);
+        }
     }
 }
```
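
Two details here are easy to miss. `u32::ilog2` panics on zero, so the checked call carries a zero-test branch; `NonZero<u32>::ilog2` has no zero case, so asserting non-zeroness once via `new_unchecked` lets the compiler drop that branch. Likewise, `wrapping_add` replaces `saturating_add` in `check_and_segment`, with the debug assertion documenting that the threshold cannot actually overflow. A standalone sketch of the `size_bits` computation under these assumptions (the free function here is illustrative, not the crate's API):

```rust
use std::num::NonZero;

/// Log2 of `size`, which callers guarantee is a non-zero power of two.
#[inline(always)]
fn size_bits(size: u32) -> u32 {
    debug_assert!(size > 0, "size must be greater than 0, got {}", size);
    debug_assert!(size.is_power_of_two(), "size must be a power of 2, got {}", size);
    // SAFETY: size is non-zero per the invariant above. NonZero::ilog2 has
    // no zero branch, unlike u32::ilog2, which panics on 0.
    unsafe { NonZero::new_unchecked(size).ilog2() }
}

fn main() {
    assert_eq!(size_bits(1), 0);
    assert_eq!(size_bits(8), 3);
}
```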

crates/vm/src/arch/execution_mode/metered/memory_ctx.rs

Lines changed: 106 additions & 31 deletions
```diff
@@ -17,52 +17,82 @@ impl BitSet {
 
     #[inline(always)]
     pub fn insert(&mut self, index: usize) -> bool {
-        let word_index = index / 64;
-        let bit_index = index % 64;
+        let word_index = index >> 6;
+        let bit_index = index & 63;
         let mask = 1u64 << bit_index;
 
-        let was_set = (self.words[word_index] & mask) != 0;
-        self.words[word_index] |= mask;
+        debug_assert!(word_index < self.words.len(), "BitSet index out of bounds");
+
+        // SAFETY: word_index is derived from a memory address that is bounds-checked
+        // during memory access. The bitset is sized to accommodate all valid
+        // memory addresses, so word_index is always within bounds.
+        let word = unsafe { self.words.get_unchecked_mut(word_index) };
+        let was_set = (*word & mask) != 0;
+        *word |= mask;
         !was_set
     }
 
     /// Set all bits within [start, end) to 1, return the number of flipped bits.
+    /// Assumes start < end and end <= self.words.len() * 64.
     #[inline(always)]
     pub fn insert_range(&mut self, start: usize, end: usize) -> usize {
         debug_assert!(start < end);
+        debug_assert!(end <= self.words.len() * 64, "BitSet range out of bounds");
+
         let mut ret = 0;
-        let start_word_index = start / u64::BITS as usize;
-        let end_word_index = (end - 1) / u64::BITS as usize;
-        let start_bit = start as u32 % u64::BITS;
+        let start_word_index = start >> 6;
+        let end_word_index = (end - 1) >> 6;
+        let start_bit = (start & 63) as u32;
+
         if start_word_index == end_word_index {
-            let end_bit = (end - 1) as u32 % u64::BITS + 1;
+            let end_bit = ((end - 1) & 63) as u32 + 1;
             let mask_bits = end_bit - start_bit;
-            let mask = (u64::MAX >> (u64::BITS - mask_bits)) << start_bit;
-            ret += mask_bits - (self.words[start_word_index] & mask).count_ones();
-            self.words[start_word_index] |= mask;
+            let mask = (u64::MAX >> (64 - mask_bits)) << start_bit;
+            // SAFETY: Caller ensures start < end and end <= self.words.len() * 64,
+            // so start_word_index < self.words.len()
+            let word = unsafe { self.words.get_unchecked_mut(start_word_index) };
+            ret += mask_bits - (*word & mask).count_ones();
+            *word |= mask;
         } else {
-            let end_bit = end as u32 % u64::BITS;
-            let mask_bits = u64::BITS - start_bit;
+            let end_bit = (end & 63) as u32;
+            let mask_bits = 64 - start_bit;
             let mask = u64::MAX << start_bit;
-            ret += mask_bits - (self.words[start_word_index] & mask).count_ones();
-            self.words[start_word_index] |= mask;
+            // SAFETY: Caller ensures start < end and end <= self.words.len() * 64,
+            // so start_word_index < self.words.len()
+            let start_word = unsafe { self.words.get_unchecked_mut(start_word_index) };
+            ret += mask_bits - (*start_word & mask).count_ones();
+            *start_word |= mask;
+
             let mask_bits = end_bit;
-            let (mask, _) = u64::MAX.overflowing_shr(u64::BITS - end_bit);
-            ret += mask_bits - (self.words[end_word_index] & mask).count_ones();
-            self.words[end_word_index] |= mask;
+            let mask = if end_bit == 0 {
+                0
+            } else {
+                u64::MAX >> (64 - end_bit)
+            };
+            // SAFETY: Caller ensures end <= self.words.len() * 64, so
+            // end_word_index < self.words.len()
+            let end_word = unsafe { self.words.get_unchecked_mut(end_word_index) };
+            ret += mask_bits - (*end_word & mask).count_ones();
+            *end_word |= mask;
         }
+
         if start_word_index + 1 < end_word_index {
             for i in (start_word_index + 1)..end_word_index {
-                ret += self.words[i].count_zeros();
-                self.words[i] = u64::MAX;
+                // SAFETY: Caller ensures proper start and end, so i is within bounds
+                // of self.words.len()
+                let word = unsafe { self.words.get_unchecked_mut(i) };
+                ret += word.count_zeros();
+                *word = u64::MAX;
             }
         }
         ret as usize
     }
 
+    #[inline(always)]
     pub fn clear(&mut self) {
-        for item in self.words.iter_mut() {
-            *item = 0;
+        // SAFETY: words is valid for self.words.len() elements
+        unsafe {
+            std::ptr::write_bytes(self.words.as_mut_ptr(), 0, self.words.len());
         }
     }
 }
@@ -132,6 +162,7 @@ impl<const PAGE_BITS: usize> MemoryCtx<PAGE_BITS> {
             addr_space_access_count: vec![0; (1 << memory_dimensions.addr_space_height) + 1],
         }
     }
+
     #[inline(always)]
     pub fn clear(&mut self) {
         self.page_indices.clear();
@@ -147,6 +178,8 @@ impl<const PAGE_BITS: usize> MemoryCtx<PAGE_BITS> {
         ptr: u32,
         size: u32,
     ) {
+        debug_assert!((address_space as usize) < self.addr_space_access_count.len());
+
         let num_blocks = (size + self.chunk - 1) >> self.chunk_bits;
         let start_chunk_id = ptr >> self.chunk_bits;
         let start_block_id = if self.chunk == 1 {
@@ -159,10 +192,17 @@ impl<const PAGE_BITS: usize> MemoryCtx<PAGE_BITS> {
         let end_block_id = start_block_id + num_blocks;
         let start_page_id = start_block_id >> PAGE_BITS;
         let end_page_id = ((end_block_id - 1) >> PAGE_BITS) + 1;
+
         for page_id in start_page_id..end_page_id {
             if self.page_indices.insert(page_id as usize) {
                 self.page_access_count += 1;
-                self.addr_space_access_count[address_space as usize] += 1;
+                // SAFETY: address_space passed is usually a hardcoded constant or derived from an
+                // Instruction where it is bounds checked before passing
+                unsafe {
+                    *self
+                        .addr_space_access_count
+                        .get_unchecked_mut(address_space as usize) += 1;
+                }
             }
         }
     }
@@ -185,38 +225,68 @@ impl<const PAGE_BITS: usize> MemoryCtx<PAGE_BITS> {
         size_bits: u32,
         num: u32,
     ) {
-        let align_bits = self.as_byte_alignment_bits[address_space as usize];
+        debug_assert!((address_space as usize) < self.as_byte_alignment_bits.len());
+
+        // SAFETY: address_space passed is usually a hardcoded constant or derived from an
+        // Instruction where it is bounds checked before passing
+        let align_bits = unsafe {
+            *self
+                .as_byte_alignment_bits
+                .get_unchecked(address_space as usize)
+        };
         debug_assert!(
             align_bits as u32 <= size_bits,
             "align_bits ({}) must be <= size_bits ({})",
             align_bits,
             size_bits
         );
+
         for adapter_bits in (align_bits as u32 + 1..=size_bits).rev() {
             let adapter_idx = self.adapter_offset + adapter_bits as usize - 1;
-            trace_heights[adapter_idx] += num << (size_bits - adapter_bits + 1);
+            debug_assert!(adapter_idx < trace_heights.len());
+            // SAFETY: trace_heights is initialized taking access adapters into account
+            unsafe {
+                *trace_heights.get_unchecked_mut(adapter_idx) +=
+                    num << (size_bits - adapter_bits + 1);
+            }
         }
     }
 
     /// Resolve all lazy updates of each memory access for memory adapters/poseidon2/merkle chip.
     #[inline(always)]
     pub(crate) fn lazy_update_boundary_heights(&mut self, trace_heights: &mut [u32]) {
+        debug_assert!(self.boundary_idx < trace_heights.len());
+
         // On page fault, assume we add all leaves in a page
         let leaves = (self.page_access_count << PAGE_BITS) as u32;
-        trace_heights[self.boundary_idx] += leaves;
+        // SAFETY: boundary_idx is a compile time constant within bounds
+        unsafe {
+            *trace_heights.get_unchecked_mut(self.boundary_idx) += leaves;
+        }
 
         if let Some(merkle_tree_idx) = self.merkle_tree_index {
+            debug_assert!(merkle_tree_idx < trace_heights.len());
+            debug_assert!(trace_heights.len() >= 2);
+
             let poseidon2_idx = trace_heights.len() - 2;
-            trace_heights[poseidon2_idx] += leaves * 2;
+            // SAFETY: poseidon2_idx is trace_heights.len() - 2, guaranteed to be in bounds
+            unsafe {
+                *trace_heights.get_unchecked_mut(poseidon2_idx) += leaves * 2;
+            }
 
             let merkle_height = self.memory_dimensions.overall_height();
             let nodes = (((1 << PAGE_BITS) - 1) + (merkle_height - PAGE_BITS)) as u32;
-            trace_heights[poseidon2_idx] += nodes * 2;
-            trace_heights[merkle_tree_idx] += nodes * 2;
+            // SAFETY: merkle_tree_idx is guaranteed to be in bounds
+            unsafe {
+                *trace_heights.get_unchecked_mut(poseidon2_idx) += nodes * 2;
+                *trace_heights.get_unchecked_mut(merkle_tree_idx) += nodes * 2;
+            }
         }
         self.page_access_count = 0;
+
        for address_space in 0..self.addr_space_access_count.len() {
-            let x = self.addr_space_access_count[address_space];
+            // SAFETY: address_space is from 0 to len(), guaranteed to be in bounds
+            let x = unsafe { *self.addr_space_access_count.get_unchecked(address_space) };
             if x > 0 {
                 // After finalize, we'll need to read it in chunk-sized units for the merkle chip
                 self.update_adapter_heights_batch(
@@ -225,7 +295,12 @@ impl<const PAGE_BITS: usize> MemoryCtx<PAGE_BITS> {
                     self.chunk_bits,
                     (x << PAGE_BITS) as u32,
                 );
-                self.addr_space_access_count[address_space] = 0;
+                // SAFETY: address_space is from 0 to len(), guaranteed to be in bounds
+                unsafe {
+                    *self
+                        .addr_space_access_count
+                        .get_unchecked_mut(address_space) = 0;
+                }
             }
         }
     }
```
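
In `BitSet`, the same unchecked-access pattern is combined with explicit shift-and-mask indexing: `index >> 6` and `index & 63` replace `/ 64` and `% 64` (equivalent for unsigned integers, just explicit in the source), and `clear` becomes a single `write_bytes` memset rather than a per-word loop. A self-contained sketch of the `insert` fast path under those assumptions:

```rust
/// Minimal bitset sketch mirroring the patched `insert` fast path.
struct BitSet {
    words: Vec<u64>,
}

impl BitSet {
    fn new(bits: usize) -> Self {
        Self { words: vec![0; bits.div_ceil(64)] }
    }

    /// Sets bit `index`; returns true if it was previously unset.
    /// Callers must guarantee `index < self.words.len() * 64`.
    #[inline(always)]
    fn insert(&mut self, index: usize) -> bool {
        let word_index = index >> 6; // index / 64
        let mask = 1u64 << (index & 63); // bit within the word, index % 64
        debug_assert!(word_index < self.words.len(), "index out of bounds");
        // SAFETY: the caller's invariant keeps word_index in bounds.
        let word = unsafe { self.words.get_unchecked_mut(word_index) };
        let was_set = (*word & mask) != 0;
        *word |= mask;
        !was_set
    }
}

fn main() {
    let mut set = BitSet::new(128);
    assert!(set.insert(70)); // first insert flips the bit
    assert!(!set.insert(70)); // second insert is a no-op
}
```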
