Skip to content

Commit cb44337

Browse files
chore: Remove Vec from MemoryRecord interface (#1298)
* chore: Remove Vec from MemoryRecord interface This allows us to change out the implementation (e.g., SmallVec) without affecting the interface * Apply suggestions from code review --------- Co-authored-by: Jonathan Wang <[email protected]>
1 parent ba15b51 commit cb44337

File tree

13 files changed

+83
-43
lines changed

13 files changed

+83
-43
lines changed

crates/vm/src/arch/testing/memory/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ where
9898
pointer: record.pointer,
9999
};
100100
row.data
101-
.copy_from_slice(record.prev_data.as_ref().unwrap_or(&record.data));
101+
.copy_from_slice(record.prev_data_slice().unwrap_or(record.data_slice()));
102102
row.timestamp = Val::<SC>::from_canonical_u32(record.prev_timestamp);
103103
row.count = -Val::<SC>::ONE;
104104

@@ -107,7 +107,7 @@ where
107107
address_space: record.address_space,
108108
pointer: record.pointer,
109109
};
110-
row.data.copy_from_slice(&record.data);
110+
row.data.copy_from_slice(record.data_slice());
111111
row.timestamp = Val::<SC>::from_canonical_u32(record.timestamp);
112112
row.count = Val::<SC>::ONE;
113113
}

crates/vm/src/system/memory/controller/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -737,7 +737,7 @@ impl<F: PrimeField32> MemoryAuxColsFactory<F> {
737737
) {
738738
buffer
739739
.prev_data
740-
.copy_from_slice(write.prev_data.as_ref().unwrap());
740+
.copy_from_slice(write.prev_data_slice().unwrap());
741741
self.generate_base_aux(write, &mut buffer.base);
742742
}
743743

@@ -780,7 +780,7 @@ impl<F: PrimeField32> MemoryAuxColsFactory<F> {
780780
&self,
781781
write: &MemoryRecord<F>,
782782
) -> MemoryWriteAuxCols<F, N> {
783-
let prev_data = write.prev_data.clone().unwrap();
783+
let prev_data = write.prev_data_slice().unwrap();
784784
MemoryWriteAuxCols::new(
785785
prev_data.try_into().unwrap(),
786786
F::from_canonical_u32(write.prev_timestamp),

crates/vm/src/system/memory/offline.rs

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,25 @@ pub struct MemoryRecord<T> {
3131
pub pointer: T,
3232
pub timestamp: u32,
3333
pub prev_timestamp: u32,
34-
pub data: Vec<T>,
34+
data: Vec<T>,
3535
/// None if a read.
36-
pub prev_data: Option<Vec<T>>,
36+
prev_data: Option<Vec<T>>,
37+
}
38+
39+
impl<T> MemoryRecord<T> {
40+
pub fn data_slice(&self) -> &[T] {
41+
self.data.as_slice()
42+
}
43+
44+
pub fn prev_data_slice(&self) -> Option<&[T]> {
45+
self.prev_data.as_deref()
46+
}
47+
}
48+
49+
impl<T: Copy> MemoryRecord<T> {
50+
pub fn data_at(&self, index: usize) -> T {
51+
self.data[index]
52+
}
3753
}
3854

3955
pub struct OfflineMemory<F> {

crates/vm/src/system/memory/tests.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -166,30 +166,30 @@ fn generate_trace<F: PrimeField32>(
166166
row.pointer = record.pointer;
167167
row.timestamp = F::from_canonical_u32(record.timestamp);
168168

169-
match (record.data.len(), &record.prev_data) {
169+
match (record.data_slice().len(), &record.prev_data_slice()) {
170170
(1, &None) => {
171171
aux_factory.generate_read_aux(&record, &mut row.read_1_aux);
172-
row.data_1 = record.data.try_into().unwrap();
172+
row.data_1 = record.data_slice().try_into().unwrap();
173173
row.is_read_1 = F::ONE;
174174
}
175175
(1, &Some(_)) => {
176176
aux_factory.generate_write_aux(&record, &mut row.write_1_aux);
177-
row.data_1 = record.data.try_into().unwrap();
177+
row.data_1 = record.data_slice().try_into().unwrap();
178178
row.is_write_1 = F::ONE;
179179
}
180180
(4, &None) => {
181181
aux_factory.generate_read_aux(&record, &mut row.read_4_aux);
182-
row.data_4 = record.data.try_into().unwrap();
182+
row.data_4 = record.data_slice().try_into().unwrap();
183183
row.is_read_4 = F::ONE;
184184
}
185185
(4, &Some(_)) => {
186186
aux_factory.generate_write_aux(&record, &mut row.write_4_aux);
187-
row.data_4 = record.data.try_into().unwrap();
187+
row.data_4 = record.data_slice().try_into().unwrap();
188188
row.is_write_4 = F::ONE;
189189
}
190190
(MAX, &None) => {
191191
aux_factory.generate_read_aux(&record, &mut row.read_max_aux);
192-
row.data_max = record.data.try_into().unwrap();
192+
row.data_max = record.data_slice().try_into().unwrap();
193193
row.is_read_max = F::ONE;
194194
}
195195
_ => panic!("unexpected pattern"),

extensions/keccak256/circuit/src/trace.rs

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,14 @@ where
7070
let len_read = memory.record_by_id(record.len_read);
7171

7272
state = [0u64; 25];
73-
let src_limbs: [_; RV32_REGISTER_NUM_LIMBS - 1] = from_fn(|i| src_read.data[i + 1]);
74-
let len_limbs: [_; RV32_REGISTER_NUM_LIMBS - 1] = from_fn(|i| len_read.data[i + 1]);
73+
let src_limbs: [_; RV32_REGISTER_NUM_LIMBS - 1] = src_read.data_slice()
74+
[1..RV32_REGISTER_NUM_LIMBS]
75+
.try_into()
76+
.unwrap();
77+
let len_limbs: [_; RV32_REGISTER_NUM_LIMBS - 1] = len_read.data_slice()
78+
[1..RV32_REGISTER_NUM_LIMBS]
79+
.try_into()
80+
.unwrap();
7581
let mut instruction = KeccakInstructionCols {
7682
pc: record.pc,
7783
is_enabled: Val::<SC>::ONE,
@@ -80,7 +86,7 @@ where
8086
dst_ptr: dst_read.pointer,
8187
src_ptr: src_read.pointer,
8288
len_ptr: len_read.pointer,
83-
dst: dst_read.data.clone().try_into().unwrap(),
89+
dst: dst_read.data_slice().try_into().unwrap(),
8490
src_limbs,
8591
src: Val::<SC>::from_canonical_usize(record.input_blocks[0].src),
8692
len_limbs,
@@ -175,7 +181,7 @@ where
175181
row_mut
176182
.mem_oc
177183
.partial_block
178-
.copy_from_slice(&partial_read.data[1..]);
184+
.copy_from_slice(&partial_read.data_slice()[1..]);
179185
}
180186
for (i, is_padding) in row_mut.sponge.is_padding_byte.iter_mut().enumerate() {
181187
*is_padding = Val::<SC>::from_bool(i >= block.remaining_len);
@@ -196,7 +202,7 @@ where
196202
.map(|r| {
197203
memory
198204
.record_by_id(*r)
199-
.data
205+
.data_slice()
200206
.last()
201207
.unwrap()
202208
.as_canonical_u32()

extensions/native/circuit/src/fri/mod.rs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
use core::ops::Deref;
22
use std::{
3-
array,
43
borrow::{Borrow, BorrowMut},
54
mem::offset_of,
65
sync::{Arc, Mutex},
@@ -573,10 +572,10 @@ fn record_to_rows<F: PrimeField32>(
573572
let a_ptr_read = memory.record_by_id(record.a_ptr_read);
574573
let b_ptr_read = memory.record_by_id(record.b_ptr_read);
575574

576-
let length = length_read.data[0].as_canonical_u32() as usize;
577-
let alpha: [F; EXT_DEG] = array::from_fn(|i| alpha_read.data[i]);
578-
let a_ptr = a_ptr_read.data[0];
579-
let b_ptr = b_ptr_read.data[0];
575+
let length = length_read.data_at(0).as_canonical_u32() as usize;
576+
let alpha: [F; EXT_DEG] = alpha_read.data_slice().try_into().unwrap();
577+
let a_ptr = a_ptr_read.data_at(0);
578+
let b_ptr = b_ptr_read.data_at(0);
580579

581580
let mut result = [F::ZERO; EXT_DEG];
582581

@@ -597,8 +596,8 @@ fn record_to_rows<F: PrimeField32>(
597596
{
598597
let a_read = memory.record_by_id(a_record_id);
599598
let b_read = memory.record_by_id(b_record_id);
600-
let a = a_read.data[0];
601-
let b: [F; EXT_DEG] = array::from_fn(|i| b_read.data[i]);
599+
let a = a_read.data_at(0);
600+
let b: [F; EXT_DEG] = b_read.data_slice().try_into().unwrap();
602601

603602
let start = i * OVERALL_WIDTH;
604603
let cols: &mut WorkloadCols<F> = slice[start..start + WL_WIDTH].borrow_mut();

extensions/native/circuit/src/poseidon2/trace.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ impl<F: PrimeField32, const SBOX_REGISTERS: usize> NativePoseidon2Chip<F, SBOX_R
114114
&mut specific.read_final_height_or_sibling_array_start,
115115
);
116116
specific.root_is_on_right = F::from_bool(root_is_on_right);
117-
specific.sibling_array_start = read_sibling_array_start.data[0];
117+
specific.sibling_array_start = read_sibling_array_start.data_at(0);
118118
}
119119
fn correct_last_top_level_row(
120120
&self,

extensions/rv32-adapters/src/eq_mod.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ impl<
399399
let rs = read_record.rs.map(|r| memory.record_by_id(r));
400400
for (i, r) in rs.iter().enumerate() {
401401
row_slice.rs_ptr[i] = r.pointer;
402-
row_slice.rs_val[i].copy_from_slice(&r.data);
402+
row_slice.rs_val[i].copy_from_slice(r.data_slice());
403403
aux_cols_factory.generate_read_aux(r, &mut row_slice.rs_read_aux[i]);
404404
for (j, x) in read_record.reads[i].iter().enumerate() {
405405
let read = memory.record_by_id(*x);
@@ -414,7 +414,9 @@ impl<
414414
// Range checks
415415
let need_range_check: [u32; 2] = from_fn(|i| {
416416
if i < NUM_READS {
417-
rs[i].data[RV32_REGISTER_NUM_LIMBS - 1].as_canonical_u32()
417+
rs[i]
418+
.data_at(RV32_REGISTER_NUM_LIMBS - 1)
419+
.as_canonical_u32()
418420
} else {
419421
0
420422
}

extensions/rv32-adapters/src/heap_branch.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ impl<F: PrimeField32, const NUM_READS: usize, const READ_SIZE: usize> VmAdapterC
290290

291291
for (i, rs_read) in rs_reads.iter().enumerate() {
292292
row_slice.rs_ptr[i] = rs_read.pointer;
293-
row_slice.rs_val[i].copy_from_slice(&rs_read.data);
293+
row_slice.rs_val[i].copy_from_slice(rs_read.data_slice());
294294
aux_cols_factory.generate_read_aux(rs_read, &mut row_slice.rs_read_aux[i]);
295295
}
296296

@@ -302,7 +302,11 @@ impl<F: PrimeField32, const NUM_READS: usize, const READ_SIZE: usize> VmAdapterC
302302
// Range checks:
303303
let need_range_check: Vec<u32> = rs_reads
304304
.iter()
305-
.map(|record| record.data[RV32_REGISTER_NUM_LIMBS - 1].as_canonical_u32())
305+
.map(|record| {
306+
record
307+
.data_at(RV32_REGISTER_NUM_LIMBS - 1)
308+
.as_canonical_u32()
309+
})
306310
.chain(once(0)) // in case NUM_READS is odd
307311
.collect();
308312
debug_assert!(self.air.address_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS);

extensions/rv32-adapters/src/vec_heap.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -502,11 +502,11 @@ pub(super) fn vec_heap_generate_trace_row_impl<
502502
.collect::<Vec<_>>();
503503

504504
row_slice.rd_ptr = rd.pointer;
505-
row_slice.rd_val.copy_from_slice(&rd.data);
505+
row_slice.rd_val.copy_from_slice(rd.data_slice());
506506

507507
for (i, r) in rs.iter().enumerate() {
508508
row_slice.rs_ptr[i] = r.pointer;
509-
row_slice.rs_val[i].copy_from_slice(&r.data);
509+
row_slice.rs_val[i].copy_from_slice(r.data_slice());
510510
aux_cols_factory.generate_read_aux(r, &mut row_slice.rs_read_aux[i]);
511511
}
512512

@@ -528,7 +528,11 @@ pub(super) fn vec_heap_generate_trace_row_impl<
528528
let need_range_check: Vec<u32> = rs
529529
.iter()
530530
.chain(std::iter::repeat(&rd).take(2))
531-
.map(|record| record.data[RV32_REGISTER_NUM_LIMBS - 1].as_canonical_u32())
531+
.map(|record| {
532+
record
533+
.data_at(RV32_REGISTER_NUM_LIMBS - 1)
534+
.as_canonical_u32()
535+
})
532536
.collect();
533537
debug_assert!(address_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS);
534538
let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - address_bits;

0 commit comments

Comments
 (0)