Skip to content

Commit 225bb16

Browse files
feat: typeless AddressMap with typed APIs (#1559)
Note: this PR is not targeting `main`. I've used `TODO` and `TEMP` to mark places in the code that will need to be cleaned up before merging to `main`.

Beginning the refactor of online memory to allow different host types in different address spaces. This is going to touch a lot of APIs. The focus is on stabilizing APIs - currently this PR will not improve performance.

Not all tests will pass, because I have intentionally disabled some logging required for trace generation. Only execution tests will pass (or run the execute benchmark).

In future PR(s):
- [ ] make `Memory` trait for execution read/write API
- [ ] better handling of type conversions for memory image
- [ ] replace the underlying memory implementation with other implementations like mmap

Towards INT-3743

Even with wasteful conversions, execution is faster:
- Before: https://github.com/openvm-org/openvm/actions/runs/14318675080
- After: https://github.com/openvm-org/openvm/actions/runs/14371335248?pr=1559
1 parent a9f68e0 commit 225bb16

File tree

48 files changed

+806
-709
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+806
-709
lines changed

crates/toolchain/instructions/src/exe.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@ use serde::{Deserialize, Serialize};
55

66
use crate::program::Program;
77

8-
/// Memory image is a map from (address space, address) to word.
9-
pub type MemoryImage<F> = BTreeMap<(u32, u32), F>;
8+
// TODO[jpw]: delete this
9+
/// Memory image is a map from (address space, address * size_of<CellType>) to u8.
10+
pub type SparseMemoryImage = BTreeMap<(u32, u32), u8>;
1011
/// Stores the starting address, end address, and name of a set of functions.
1112
pub type FnBounds = BTreeMap<u32, FnBound>;
1213

@@ -22,7 +23,7 @@ pub struct VmExe<F> {
2223
/// Start address of pc.
2324
pub pc_start: u32,
2425
/// Initial memory image.
25-
pub init_memory: MemoryImage<F>,
26+
pub init_memory: SparseMemoryImage,
2627
/// Starting + ending bounds for each function.
2728
pub fn_bounds: FnBounds,
2829
}
@@ -40,7 +41,7 @@ impl<F> VmExe<F> {
4041
self.pc_start = pc_start;
4142
self
4243
}
43-
pub fn with_init_memory(mut self, init_memory: MemoryImage<F>) -> Self {
44+
pub fn with_init_memory(mut self, init_memory: SparseMemoryImage) -> Self {
4445
self.init_memory = init_memory;
4546
self
4647
}

crates/toolchain/transpiler/src/util.rs

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use std::collections::BTreeMap;
22

33
use openvm_instructions::{
4-
exe::MemoryImage,
4+
exe::SparseMemoryImage,
55
instruction::Instruction,
66
riscv::{RV32_MEMORY_AS, RV32_REGISTER_NUM_LIMBS},
77
utils::isize_to_field,
@@ -163,17 +163,14 @@ pub fn nop<F: PrimeField32>() -> Instruction<F> {
163163
}
164164
}
165165

166-
/// Converts our memory image (u32 -> [u8; 4]) into Vm memory image ((as, address) -> word)
167-
pub fn elf_memory_image_to_openvm_memory_image<F: PrimeField32>(
166+
/// Converts our memory image (u32 -> [u8; 4]) into Vm memory image ((as=2, address) -> byte)
167+
pub fn elf_memory_image_to_openvm_memory_image(
168168
memory_image: BTreeMap<u32, u32>,
169-
) -> MemoryImage<F> {
170-
let mut result = MemoryImage::new();
169+
) -> SparseMemoryImage {
170+
let mut result = SparseMemoryImage::new();
171171
for (addr, word) in memory_image {
172172
for (i, byte) in word.to_le_bytes().into_iter().enumerate() {
173-
result.insert(
174-
(RV32_MEMORY_AS, addr + i as u32),
175-
F::from_canonical_u8(byte),
176-
);
173+
result.insert((RV32_MEMORY_AS, addr + i as u32), byte);
177174
}
178175
}
179176
result

crates/vm/src/arch/extensions.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -786,7 +786,7 @@ impl<F: PrimeField32, E, P> VmChipComplex<F, E, P> {
786786
self.base.program_chip.set_program(program);
787787
}
788788

789-
pub(crate) fn set_initial_memory(&mut self, memory: MemoryImage<F>) {
789+
pub(crate) fn set_initial_memory(&mut self, memory: MemoryImage) {
790790
self.base.memory_controller.set_initial_memory(memory);
791791
}
792792

crates/vm/src/arch/segment.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ where
145145
{
146146
pub chip_complex: VmChipComplex<F, VC::Executor, VC::Periphery>,
147147
/// Memory image after segment was executed. Not used in trace generation.
148-
pub final_memory: Option<MemoryImage<F>>,
148+
pub final_memory: Option<MemoryImage>,
149149

150150
pub since_last_segment_check: usize,
151151
pub trace_height_constraints: Vec<LinearConstraint>,
@@ -168,7 +168,7 @@ impl<F: PrimeField32, VC: VmConfig<F>> ExecutionSegment<F, VC> {
168168
config: &VC,
169169
program: Program<F>,
170170
init_streams: Streams<F>,
171-
initial_memory: Option<MemoryImage<F>>,
171+
initial_memory: Option<MemoryImage>,
172172
trace_height_constraints: Vec<LinearConstraint>,
173173
#[allow(unused_variables)] fn_bounds: FnBounds,
174174
) -> Self {

crates/vm/src/arch/vm.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ pub enum GenerationError {
4747
}
4848

4949
/// VM memory state for continuations.
50-
pub type VmMemoryState<F> = MemoryImage<F>;
5150
5251
#[derive(Clone, Default, Debug)]
5352
pub struct Streams<F> {
@@ -95,19 +94,19 @@ pub enum ExitCode {
9594
pub struct VmExecutorResult<SC: StarkGenericConfig> {
9695
pub per_segment: Vec<ProofInput<SC>>,
9796
/// When the VM is running in persistent mode, public values are stored in a special memory space.
98-
pub final_memory: Option<VmMemoryState<Val<SC>>>,
97+
pub final_memory: Option<MemoryImage>,
9998
}
10099

101100
pub struct VmExecutorNextSegmentState<F: PrimeField32> {
102-
pub memory: MemoryImage<F>,
101+
pub memory: MemoryImage,
103102
pub input: Streams<F>,
104103
pub pc: u32,
105104
#[cfg(feature = "bench-metrics")]
106105
pub metrics: VmMetrics,
107106
}
108107

109108
impl<F: PrimeField32> VmExecutorNextSegmentState<F> {
110-
pub fn new(memory: MemoryImage<F>, input: impl Into<Streams<F>>, pc: u32) -> Self {
109+
pub fn new(memory: MemoryImage, input: impl Into<Streams<F>>, pc: u32) -> Self {
111110
Self {
112111
memory,
113112
input: input.into(),
@@ -170,12 +169,13 @@ where
170169
let mem_config = self.config.system().memory_config;
171170
let exe = exe.into();
172171
let mut segment_results = vec![];
173-
let memory = AddressMap::from_iter(
172+
let memory = AddressMap::from_sparse(
174173
mem_config.as_offset,
175174
1 << mem_config.as_height,
176175
1 << mem_config.pointer_max_bits,
177176
exe.init_memory.clone(),
178177
);
178+
179179
let pc = exe.pc_start;
180180
let mut state = VmExecutorNextSegmentState::new(memory, input, pc);
181181
let mut segment_idx = 0;
@@ -271,7 +271,7 @@ where
271271
&self,
272272
exe: impl Into<VmExe<F>>,
273273
input: impl Into<Streams<F>>,
274-
) -> Result<Option<VmMemoryState<F>>, ExecutionError> {
274+
) -> Result<Option<MemoryImage>, ExecutionError> {
275275
let mut last = None;
276276
self.execute_and_then(
277277
exe,
@@ -580,7 +580,7 @@ where
580580
&self,
581581
exe: impl Into<VmExe<F>>,
582582
input: impl Into<Streams<F>>,
583-
) -> Result<Option<VmMemoryState<F>>, ExecutionError> {
583+
) -> Result<Option<MemoryImage>, ExecutionError> {
584584
self.executor.execute(exe, input)
585585
}
586586

crates/vm/src/system/memory/controller/interface.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ pub enum MemoryInterface<F> {
1313
Persistent {
1414
boundary_chip: PersistentBoundaryChip<F, CHUNK>,
1515
merkle_chip: MemoryMerkleChip<CHUNK, F>,
16-
initial_memory: MemoryImage<F>,
16+
initial_memory: MemoryImage,
1717
},
1818
}
1919

crates/vm/src/system/memory/controller/mod.rs

Lines changed: 52 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ use std::{
33
collections::BTreeMap,
44
iter,
55
marker::PhantomData,
6-
mem,
76
sync::{Arc, Mutex},
87
};
98

@@ -62,7 +61,7 @@ pub const BOUNDARY_AIR_OFFSET: usize = 0;
6261
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
6362
pub struct RecordId(pub usize);
6463

65-
pub type MemoryImage<F> = AddressMap<F, PAGE_SIZE>;
64+
pub type MemoryImage = AddressMap<PAGE_SIZE>;
6665

6766
#[repr(C)]
6867
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
@@ -98,7 +97,7 @@ pub struct MemoryController<F> {
9897
// Store separately to avoid smart pointer reference each time
9998
range_checker_bus: VariableRangeCheckerBus,
10099
// addr_space -> Memory data structure
101-
memory: Memory<F>,
100+
memory: Memory,
102101
/// A reference to the `OfflineMemory`. Will be populated after `finalize()`.
103102
offline_memory: Arc<Mutex<OfflineMemory<F>>>,
104103
pub access_adapters: AccessAdapterInventory<F>,
@@ -314,7 +313,7 @@ impl<F: PrimeField32> MemoryController<F> {
314313
}
315314
}
316315

317-
pub fn memory_image(&self) -> &MemoryImage<F> {
316+
pub fn memory_image(&self) -> &MemoryImage {
318317
&self.memory.data
319318
}
320319

@@ -344,7 +343,7 @@ impl<F: PrimeField32> MemoryController<F> {
344343
}
345344
}
346345

347-
pub fn set_initial_memory(&mut self, memory: MemoryImage<F>) {
346+
pub fn set_initial_memory(&mut self, memory: MemoryImage) {
348347
if self.timestamp() > INITIAL_TIMESTAMP + 1 {
349348
panic!("Cannot set initial memory after first timestamp");
350349
}
@@ -379,58 +378,67 @@ impl<F: PrimeField32> MemoryController<F> {
379378
(record_id, data)
380379
}
381380

382-
pub fn read<const N: usize>(&mut self, address_space: F, pointer: F) -> (RecordId, [F; N]) {
381+
// TEMP[jpw]: Function is safe temporarily for refactoring
382+
/// # Safety
383+
/// The type `T` must be stack-allocated `repr(C)` or `repr(transparent)`, and it must be the exact type used to represent a single
384+
/// memory cell in address space `address_space`. For standard usage, `T` is either `u8` or `F` where `F` is
385+
/// the base field of the ZK backend.
386+
pub fn read<T: Copy, const N: usize>(
387+
&mut self,
388+
address_space: F,
389+
pointer: F,
390+
) -> (RecordId, [T; N]) {
383391
let address_space_u32 = address_space.as_canonical_u32();
384392
let ptr_u32 = pointer.as_canonical_u32();
385393
assert!(
386394
address_space == F::ZERO || ptr_u32 < (1 << self.mem_config.pointer_max_bits),
387395
"memory out of bounds: {ptr_u32:?}",
388396
);
389397

390-
let (record_id, values) = self.memory.read::<N>(address_space_u32, ptr_u32);
398+
let (record_id, values) = unsafe { self.memory.read::<T, N>(address_space_u32, ptr_u32) };
391399

392400
(record_id, values)
393401
}
394402

395403
/// Reads a single cell directly from memory without updating internal state.
396404
///
397405
/// Any value returned is unconstrained.
398-
pub fn unsafe_read_cell(&self, addr_space: F, ptr: F) -> F {
399-
self.unsafe_read::<1>(addr_space, ptr)[0]
406+
pub fn unsafe_read_cell<T: Copy>(&self, addr_space: F, ptr: F) -> T {
407+
self.unsafe_read::<T, 1>(addr_space, ptr)[0]
400408
}
401409

402410
/// Reads a word directly from memory without updating internal state.
403411
///
404412
/// Any value returned is unconstrained.
405-
pub fn unsafe_read<const N: usize>(&self, addr_space: F, ptr: F) -> [F; N] {
413+
pub fn unsafe_read<T: Copy, const N: usize>(&self, addr_space: F, ptr: F) -> [T; N] {
406414
let addr_space = addr_space.as_canonical_u32();
407415
let ptr = ptr.as_canonical_u32();
408-
array::from_fn(|i| self.memory.get(addr_space, ptr + i as u32))
416+
unsafe { array::from_fn(|i| self.memory.get::<T>(addr_space, ptr + i as u32)) }
409417
}
410418

411419
/// Writes `data` to the given cell.
412420
///
413421
/// Returns the `RecordId` and previous data.
414-
pub fn write_cell(&mut self, address_space: F, pointer: F, data: F) -> (RecordId, F) {
415-
let (record_id, [data]) = self.write(address_space, pointer, [data]);
422+
pub fn write_cell<T: Copy>(&mut self, address_space: F, pointer: F, data: T) -> (RecordId, T) {
423+
let (record_id, [data]) = self.write(address_space, pointer, &[data]);
416424
(record_id, data)
417425
}
418426

419-
pub fn write<const N: usize>(
427+
pub fn write<T: Copy, const N: usize>(
420428
&mut self,
421429
address_space: F,
422430
pointer: F,
423-
data: [F; N],
424-
) -> (RecordId, [F; N]) {
425-
assert_ne!(address_space, F::ZERO);
431+
data: &[T; N],
432+
) -> (RecordId, [T; N]) {
433+
debug_assert_ne!(address_space, F::ZERO);
426434
let address_space_u32 = address_space.as_canonical_u32();
427435
let ptr_u32 = pointer.as_canonical_u32();
428436
assert!(
429437
ptr_u32 < (1 << self.mem_config.pointer_max_bits),
430438
"memory out of bounds: {ptr_u32:?}",
431439
);
432440

433-
self.memory.write(address_space_u32, ptr_u32, data)
441+
unsafe { self.memory.write::<T, N>(address_space_u32, ptr_u32, data) }
434442
}
435443

436444
pub fn aux_cols_factory(&self) -> MemoryAuxColsFactory<F> {
@@ -455,26 +463,27 @@ impl<F: PrimeField32> MemoryController<F> {
455463
}
456464

457465
fn replay_access_log(&mut self) {
458-
let log = mem::take(&mut self.memory.log);
459-
if log.is_empty() {
460-
// Online memory logs may be empty, but offline memory may be replayed from external sources.
461-
// In these cases, we skip the calls to replay access logs because `set_log_capacity` would
462-
// panic.
463-
tracing::debug!("skipping replay_access_log");
464-
return;
465-
}
466-
467-
let mut offline_memory = self.offline_memory.lock().unwrap();
468-
offline_memory.set_log_capacity(log.len());
469-
470-
for entry in log {
471-
Self::replay_access(
472-
entry,
473-
&mut offline_memory,
474-
&mut self.interface_chip,
475-
&mut self.access_adapters,
476-
);
477-
}
466+
unimplemented!();
467+
// let log = mem::take(&mut self.memory.log);
468+
// if log.is_empty() {
469+
// // Online memory logs may be empty, but offline memory may be replayed from external sources.
470+
// // In these cases, we skip the calls to replay access logs because `set_log_capacity` would
471+
// // panic.
472+
// tracing::debug!("skipping replay_access_log");
473+
// return;
474+
// }
475+
476+
// let mut offline_memory = self.offline_memory.lock().unwrap();
477+
// offline_memory.set_log_capacity(log.len());
478+
479+
// for entry in log {
480+
// Self::replay_access(
481+
// entry,
482+
// &mut offline_memory,
483+
// &mut self.interface_chip,
484+
// &mut self.access_adapters,
485+
// );
486+
// }
478487
}
479488

480489
/// Low-level API to replay a single memory access log entry and populate the [OfflineMemory], [MemoryInterface], and `AccessAdapterInventory`.
@@ -703,13 +712,13 @@ impl<F: PrimeField32> MemoryController<F> {
703712
pub fn offline_memory(&self) -> Arc<Mutex<OfflineMemory<F>>> {
704713
self.offline_memory.clone()
705714
}
706-
pub fn get_memory_logs(&self) -> &Vec<MemoryLogEntry<F>> {
715+
pub fn get_memory_logs(&self) -> &Vec<MemoryLogEntry<u8>> {
707716
&self.memory.log
708717
}
709-
pub fn set_memory_logs(&mut self, logs: Vec<MemoryLogEntry<F>>) {
718+
pub fn set_memory_logs(&mut self, logs: Vec<MemoryLogEntry<u8>>) {
710719
self.memory.log = logs;
711720
}
712-
pub fn take_memory_logs(&mut self) -> Vec<MemoryLogEntry<F>> {
721+
pub fn take_memory_logs(&mut self) -> Vec<MemoryLogEntry<u8>> {
713722
std::mem::take(&mut self.memory.log)
714723
}
715724
}
@@ -855,9 +864,9 @@ mod tests {
855864

856865
if rng.gen_bool(0.5) {
857866
let data = F::from_canonical_u32(rng.gen_range(0..1 << 30));
858-
memory_controller.write(address_space, pointer, [data]);
867+
memory_controller.write(address_space, pointer, &[data]);
859868
} else {
860-
memory_controller.read::<1>(address_space, pointer);
869+
memory_controller.read::<F, 1>(address_space, pointer);
861870
}
862871
}
863872
assert!(memory_controller

0 commit comments

Comments
 (0)