Skip to content

Commit 6402afe

Browse files
committed
feat: update gpu tracegen for boundary chip
1 parent 439f596 commit 6402afe

File tree

7 files changed, with 200 additions and 68 deletions.

crates/vm/cuda/src/system/boundary.cu

Lines changed: 23 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -6,12 +6,13 @@
66
#include <cassert>
77

88
inline constexpr size_t PERSISTENT_CHUNK = 8;
9+
inline constexpr size_t BLOCKS_PER_CHUNK = 2;
910
inline constexpr size_t VOLATILE_CHUNK = 1;
1011

11-
template <size_t CHUNK> struct BoundaryRecord {
12+
template <size_t CHUNK, size_t BLOCKS> struct BoundaryRecord {
1213
uint32_t address_space;
1314
uint32_t ptr;
14-
uint32_t timestamp;
15+
uint32_t timestamps[BLOCKS];
1516
uint32_t values[CHUNK];
1617
};
1718

@@ -21,7 +22,7 @@ template <typename T> struct PersistentBoundaryCols {
2122
T leaf_label;
2223
T values[PERSISTENT_CHUNK];
2324
T hash[PERSISTENT_CHUNK];
24-
T timestamp;
25+
T timestamps[BLOCKS_PER_CHUNK];
2526
};
2627

2728
inline constexpr size_t ADDR_ELTS = 2;
@@ -42,7 +43,7 @@ __global__ void cukernel_persistent_boundary_tracegen(
4243
size_t height,
4344
size_t width,
4445
uint8_t const *const *initial_mem,
45-
BoundaryRecord<PERSISTENT_CHUNK> *records,
46+
BoundaryRecord<PERSISTENT_CHUNK, BLOCKS_PER_CHUNK> *records,
4647
size_t num_records,
4748
FpArray<16> *poseidon2_buffer,
4849
uint32_t *poseidon2_buffer_idx,
@@ -53,7 +54,7 @@ __global__ void cukernel_persistent_boundary_tracegen(
5354
RowSlice row = RowSlice(trace + row_idx, height);
5455

5556
if (record_idx < num_records) {
56-
BoundaryRecord<PERSISTENT_CHUNK> record = records[record_idx];
57+
BoundaryRecord<PERSISTENT_CHUNK, BLOCKS_PER_CHUNK> record = records[record_idx];
5758
Poseidon2Buffer poseidon2(poseidon2_buffer, poseidon2_buffer_idx, poseidon2_capacity);
5859
COL_WRITE_VALUE(row, PersistentBoundaryCols, address_space, record.address_space);
5960
COL_WRITE_VALUE(row, PersistentBoundaryCols, leaf_label, record.ptr / PERSISTENT_CHUNK);
@@ -77,24 +78,32 @@ __global__ void cukernel_persistent_boundary_tracegen(
7778
}
7879
FpArray<8> init_hash = poseidon2.hash_and_record(init_values);
7980
COL_WRITE_VALUE(row, PersistentBoundaryCols, expand_direction, Fp::one());
80-
COL_WRITE_VALUE(row, PersistentBoundaryCols, timestamp, Fp::zero());
8181
COL_WRITE_ARRAY(
8282
row, PersistentBoundaryCols, values, reinterpret_cast<Fp const *>(init_values.v)
8383
);
8484
COL_WRITE_ARRAY(
8585
row, PersistentBoundaryCols, hash, reinterpret_cast<Fp const *>(init_hash.v)
8686
);
87+
Fp ts_values[BLOCKS_PER_CHUNK];
88+
for (int i = 0; i < BLOCKS_PER_CHUNK; ++i) {
89+
ts_values[i] = Fp::zero();
90+
}
91+
COL_WRITE_ARRAY(row, PersistentBoundaryCols, timestamps, ts_values);
8792
} else {
8893
FpArray<8> final_values = FpArray<8>::from_raw_array(record.values);
8994
FpArray<8> final_hash = poseidon2.hash_and_record(final_values);
9095
COL_WRITE_VALUE(row, PersistentBoundaryCols, expand_direction, Fp::neg_one());
91-
COL_WRITE_VALUE(row, PersistentBoundaryCols, timestamp, record.timestamp);
9296
COL_WRITE_ARRAY(
9397
row, PersistentBoundaryCols, values, reinterpret_cast<Fp const *>(final_values.v)
9498
);
9599
COL_WRITE_ARRAY(
96100
row, PersistentBoundaryCols, hash, reinterpret_cast<Fp const *>(final_hash.v)
97101
);
102+
Fp ts_values[BLOCKS_PER_CHUNK];
103+
for (int i = 0; i < BLOCKS_PER_CHUNK; ++i) {
104+
ts_values[i] = Fp(record.timestamps[i]);
105+
}
106+
COL_WRITE_ARRAY(row, PersistentBoundaryCols, timestamps, ts_values);
98107
}
99108
} else {
100109
row.fill_zero(0, width);
@@ -105,7 +114,7 @@ __global__ void cukernel_volatile_boundary_tracegen(
105114
Fp *trace,
106115
size_t height,
107116
size_t width,
108-
BoundaryRecord<VOLATILE_CHUNK> const *records,
117+
BoundaryRecord<VOLATILE_CHUNK, 1> const *records,
109118
size_t num_records,
110119
uint32_t *range_checker,
111120
size_t range_checker_num_bins,
@@ -122,7 +131,7 @@ __global__ void cukernel_volatile_boundary_tracegen(
122131
// For the sake of always filling `addr_lt_aux`
123132
row.fill_zero(0, width);
124133
}
125-
BoundaryRecord<VOLATILE_CHUNK> record = records[idx];
134+
BoundaryRecord<VOLATILE_CHUNK, 1> record = records[idx];
126135
rc.decompose(
127136
record.address_space,
128137
as_max_bits,
@@ -137,11 +146,11 @@ __global__ void cukernel_volatile_boundary_tracegen(
137146
);
138147
COL_WRITE_VALUE(row, VolatileBoundaryCols, initial_data, Fp::zero());
139148
COL_WRITE_VALUE(row, VolatileBoundaryCols, final_data, record.values[0]);
140-
COL_WRITE_VALUE(row, VolatileBoundaryCols, final_timestamp, record.timestamp);
149+
COL_WRITE_VALUE(row, VolatileBoundaryCols, final_timestamp, record.timestamps[0]);
141150
COL_WRITE_VALUE(row, VolatileBoundaryCols, is_valid, Fp::one());
142151

143152
if (idx != num_records - 1) {
144-
BoundaryRecord<VOLATILE_CHUNK> next_record = records[idx + 1];
153+
BoundaryRecord<VOLATILE_CHUNK, 1> next_record = records[idx + 1];
145154
uint32_t curr[ADDR_ELTS] = {record.address_space, record.ptr};
146155
uint32_t next[ADDR_ELTS] = {next_record.address_space, next_record.ptr};
147156
IsLessThanArray::generate_subrow(
@@ -189,8 +198,8 @@ extern "C" int _persistent_boundary_tracegen(
189198
size_t poseidon2_capacity
190199
) {
191200
auto [grid, block] = kernel_launch_params(height);
192-
BoundaryRecord<PERSISTENT_CHUNK> *d_records =
193-
reinterpret_cast<BoundaryRecord<PERSISTENT_CHUNK> *>(d_raw_records);
201+
BoundaryRecord<PERSISTENT_CHUNK, BLOCKS_PER_CHUNK> *d_records =
202+
reinterpret_cast<BoundaryRecord<PERSISTENT_CHUNK, BLOCKS_PER_CHUNK> *>(d_raw_records);
194203
FpArray<16> *d_poseidon2_buffer = reinterpret_cast<FpArray<16> *>(d_poseidon2_raw_buffer);
195204
cukernel_persistent_boundary_tracegen<<<grid, block>>>(
196205
d_trace,
@@ -218,7 +227,7 @@ extern "C" int _volatile_boundary_tracegen(
218227
size_t ptr_max_bits
219228
) {
220229
auto [grid, block] = kernel_launch_params(height, 512);
221-
auto d_records = reinterpret_cast<BoundaryRecord<VOLATILE_CHUNK> const *>(d_raw_records);
230+
auto d_records = reinterpret_cast<BoundaryRecord<VOLATILE_CHUNK, 1> const *>(d_raw_records);
222231
cukernel_volatile_boundary_tracegen<<<grid, block>>>(
223232
d_trace,
224233
height,

crates/vm/src/arch/testing/cuda.rs

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -54,10 +54,10 @@ use crate::{
5454
POSEIDON2_DIRECT_BUS, READ_INSTRUCTION_BUS,
5555
},
5656
Arena, DenseRecordArena, ExecutionBridge, ExecutionBus, ExecutionState, MatrixRecordArena,
57-
MemoryConfig, PreflightExecutor, Streams, VmStateMut,
57+
MemoryConfig, PreflightExecutor, Streams, VmStateMut, CONST_BLOCK_SIZE,
5858
},
5959
system::{
60-
cuda::{poseidon2::Poseidon2PeripheryChipGPU, DIGEST_WIDTH},
60+
cuda::poseidon2::Poseidon2PeripheryChipGPU,
6161
memory::{
6262
offline_checker::{MemoryBridge, MemoryBus},
6363
MemoryAirInventory, SharedMemoryHelper,
@@ -314,7 +314,7 @@ impl GpuChipTestBuilder {
314314
)));
315315
Self {
316316
memory: DeviceMemoryTester::persistent(
317-
default_tracing_memory(&mem_config, DIGEST_WIDTH),
317+
default_tracing_memory(&mem_config, CONST_BLOCK_SIZE),
318318
mem_bus,
319319
mem_config,
320320
range_checker.clone(),

crates/vm/src/system/cuda/boundary.rs

Lines changed: 23 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
use std::sync::Arc;
22

33
use openvm_circuit::{
4+
arch::CONST_BLOCK_SIZE,
45
system::memory::{
56
persistent::PersistentBoundaryCols, volatile::VolatileBoundaryCols,
67
TimestampedEquipartition, TimestampedValues,
@@ -19,7 +20,7 @@ use openvm_stark_backend::{
1920
Chip,
2021
};
2122

22-
use super::{merkle_tree::TIMESTAMPED_BLOCK_WIDTH, poseidon2::SharedBuffer};
23+
use super::{poseidon2::SharedBuffer, DIGEST_WIDTH};
2324
use crate::cuda_abi::boundary::{persistent_boundary_tracegen, volatile_boundary_tracegen};
2425

2526
pub struct PersistentBoundary {
@@ -28,7 +29,7 @@ pub struct PersistentBoundary {
2829
/// This struct cannot own the device memory, hence we take extra care not to use memory we
2930
/// don't own. TODO: use `Arc<DeviceBuffer>` instead?
3031
pub initial_leaves: Vec<*const std::ffi::c_void>,
31-
pub touched_blocks: Option<DeviceBuffer<u32>>,
32+
pub records: Option<DeviceBuffer<u32>>,
3233
}
3334

3435
pub struct VolatileBoundary {
@@ -49,13 +50,24 @@ pub struct BoundaryChipGPU {
4950
pub trace_width: Option<usize>,
5051
}
5152

53+
const BLOCKS_PER_CHUNK: usize = DIGEST_WIDTH / CONST_BLOCK_SIZE;
54+
55+
#[repr(C)]
56+
#[derive(Clone, Copy)]
57+
pub struct PersistentBoundaryRecord {
58+
pub address_space: u32,
59+
pub ptr: u32,
60+
pub timestamps: [u32; BLOCKS_PER_CHUNK],
61+
pub values: [F; DIGEST_WIDTH],
62+
}
63+
5264
impl BoundaryChipGPU {
5365
pub fn persistent(poseidon2_buffer: SharedBuffer<F>) -> Self {
5466
Self {
5567
fields: BoundaryFields::Persistent(PersistentBoundary {
5668
poseidon2_buffer,
5769
initial_leaves: Vec::new(),
58-
touched_blocks: None,
70+
records: None,
5971
}),
6072
num_records: None,
6173
trace_width: None,
@@ -106,14 +118,18 @@ impl BoundaryChipGPU {
106118

107119
pub fn finalize_records_persistent<const CHUNK: usize>(
108120
&mut self,
109-
touched_blocks: DeviceBuffer<u32>,
121+
records: Vec<PersistentBoundaryRecord>,
110122
) {
111123
match &mut self.fields {
112124
BoundaryFields::Volatile(_) => panic!("call `finalize_records_volatile`"),
113125
BoundaryFields::Persistent(fields) => {
114-
self.num_records = Some(touched_blocks.len() / TIMESTAMPED_BLOCK_WIDTH);
126+
self.num_records = Some(records.len());
115127
self.trace_width = Some(PersistentBoundaryCols::<F, CHUNK>::width());
116-
fields.touched_blocks = Some(touched_blocks);
128+
fields.records = Some(if records.is_empty() {
129+
DeviceBuffer::new()
130+
} else {
131+
records.to_device().unwrap().as_buffer::<u32>()
132+
});
117133
}
118134
}
119135
}
@@ -144,7 +160,7 @@ impl<RA> Chip<RA, GpuBackend> for BoundaryChipGPU {
144160
trace.height(),
145161
trace.width(),
146162
&mem_ptrs,
147-
boundary.touched_blocks.as_ref().unwrap(),
163+
boundary.records.as_ref().unwrap(),
148164
num_records,
149165
&boundary.poseidon2_buffer.buffer,
150166
&boundary.poseidon2_buffer.idx,

crates/vm/src/system/cuda/memory.rs

Lines changed: 105 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
use std::sync::Arc;
1+
use std::{collections::BTreeMap, sync::Arc};
22

33
use openvm_circuit::{
44
arch::{
@@ -12,7 +12,7 @@ use openvm_circuit::{
1212
use openvm_circuit_primitives::var_range::VariableRangeCheckerChipGPU;
1313
use openvm_cuda_backend::{prover_backend::GpuBackend, types::F};
1414
use openvm_cuda_common::{
15-
copy::{cuda_memcpy, MemCopyD2D, MemCopyH2D},
15+
copy::{cuda_memcpy, MemCopyH2D},
1616
d_buffer::DeviceBuffer,
1717
memory_manager::MemTracker,
1818
};
@@ -22,8 +22,8 @@ use openvm_stark_backend::{
2222

2323
use super::{
2424
access_adapters::AccessAdapterInventoryGPU,
25-
boundary::{BoundaryChipGPU, BoundaryFields},
26-
merkle_tree::{MemoryMerkleTree, TIMESTAMPED_BLOCK_WIDTH},
25+
boundary::{BoundaryChipGPU, BoundaryFields, PersistentBoundaryRecord},
26+
merkle_tree::{MemoryMerkleTree, MERKLE_TOUCHED_BLOCK_WIDTH, TIMESTAMPED_BLOCK_WIDTH},
2727
Poseidon2PeripheryChipGPU, DIGEST_WIDTH,
2828
};
2929

@@ -151,7 +151,34 @@ impl MemoryInventoryGPU {
151151
}
152152

153153
mem.tracing_info("boundary finalize");
154-
let (touched_memory, empty) = if partition.is_empty() {
154+
let read_chunk_from_device_raw =
155+
|addr_space: u32, chunk_ptr: u32| -> [F; DIGEST_WIDTH] {
156+
let mut res = [F::ZERO; DIGEST_WIDTH];
157+
let addr_space_idx = addr_space as usize;
158+
let d_mem = &persistent.initial_memory[addr_space_idx];
159+
if d_mem.is_empty() {
160+
return res;
161+
}
162+
let layout = &persistent.merkle_tree.mem_config().addr_spaces[addr_space_idx]
163+
.layout;
164+
let one_cell_size = layout.size();
165+
let offset = chunk_ptr as usize * one_cell_size;
166+
let mut values = vec![0u8; one_cell_size * DIGEST_WIDTH];
167+
unsafe {
168+
cuda_memcpy::<true, false>(
169+
values.as_mut_ptr() as *mut std::ffi::c_void,
170+
d_mem.as_ptr().add(offset) as *const std::ffi::c_void,
171+
values.len(),
172+
)
173+
.unwrap();
174+
for i in 0..DIGEST_WIDTH {
175+
res[i] = layout.to_field::<F>(&values[i * one_cell_size..]);
176+
}
177+
}
178+
res
179+
};
180+
let partition_is_empty = partition.is_empty();
181+
let (touched_memory, empty) = if partition_is_empty {
155182
let leftmost_values = 'left: {
156183
let mut res = [F::ZERO; CONST_BLOCK_SIZE];
157184
if persistent.initial_memory[ADDR_SPACE_OFFSET as usize].is_empty() {
@@ -190,24 +217,87 @@ impl MemoryInventoryGPU {
190217
} else {
191218
(partition, false)
192219
};
220+
let merkle_touched_memory = {
221+
let mut chunk_map: BTreeMap<(u32, u32), (u32, [F; DIGEST_WIDTH])> =
222+
BTreeMap::new();
223+
for &((addr_space, ptr), ts_values) in touched_memory.iter() {
224+
let chunk_ptr = (ptr / DIGEST_WIDTH as u32) * DIGEST_WIDTH as u32;
225+
let block_idx_in_chunk =
226+
((ptr % DIGEST_WIDTH as u32) / CONST_BLOCK_SIZE as u32) as usize;
227+
let entry = chunk_map.entry((addr_space, chunk_ptr)).or_insert_with(|| {
228+
(
229+
0u32,
230+
read_chunk_from_device_raw(addr_space, chunk_ptr),
231+
)
232+
});
233+
entry.0 = entry.0.max(ts_values.timestamp);
234+
for (i, val) in ts_values.values.iter().copied().enumerate() {
235+
entry.1[block_idx_in_chunk * CONST_BLOCK_SIZE + i] = val;
236+
}
237+
}
238+
chunk_map
239+
.into_iter()
240+
.map(|((addr_space, ptr), (timestamp, values))| {
241+
(
242+
(addr_space, ptr),
243+
TimestampedValues { timestamp, values },
244+
)
245+
})
246+
.collect::<Vec<_>>()
247+
};
248+
let boundary_records = if partition_is_empty {
249+
Vec::new()
250+
} else {
251+
let mut chunk_map: BTreeMap<
252+
(u32, u32),
253+
([u32; DIGEST_WIDTH / CONST_BLOCK_SIZE], [F; DIGEST_WIDTH]),
254+
> = BTreeMap::new();
255+
for &((addr_space, ptr), ts_values) in touched_memory.iter() {
256+
let chunk_ptr = (ptr / DIGEST_WIDTH as u32) * DIGEST_WIDTH as u32;
257+
let block_idx_in_chunk =
258+
((ptr % DIGEST_WIDTH as u32) / CONST_BLOCK_SIZE as u32) as usize;
259+
let entry = chunk_map.entry((addr_space, chunk_ptr)).or_insert_with(|| {
260+
(
261+
[0u32; DIGEST_WIDTH / CONST_BLOCK_SIZE],
262+
read_chunk_from_device_raw(addr_space, chunk_ptr),
263+
)
264+
});
265+
entry.0[block_idx_in_chunk] = ts_values.timestamp;
266+
for (i, val) in ts_values.values.iter().copied().enumerate() {
267+
entry.1[block_idx_in_chunk * CONST_BLOCK_SIZE + i] = val;
268+
}
269+
}
270+
chunk_map
271+
.into_iter()
272+
.map(|((addr_space, ptr), (timestamps, values))| {
273+
PersistentBoundaryRecord {
274+
address_space: addr_space,
275+
ptr,
276+
timestamps,
277+
values,
278+
}
279+
})
280+
.collect::<Vec<_>>()
281+
};
193282
debug_assert_eq!(
194283
size_of_val(&touched_memory[0]),
195284
TIMESTAMPED_BLOCK_WIDTH * size_of::<u32>()
196285
);
197-
let d_touched_memory = touched_memory.to_device().unwrap().as_buffer::<u32>();
198-
if empty {
199-
self.boundary
200-
.finalize_records_persistent::<DIGEST_WIDTH>(DeviceBuffer::new());
201-
} else {
202-
self.boundary.finalize_records_persistent::<DIGEST_WIDTH>(
203-
d_touched_memory.device_copy().unwrap().as_buffer::<u32>(),
204-
); // TODO do not copy
205-
}
286+
let d_merkle_touched_memory = merkle_touched_memory
287+
.to_device()
288+
.unwrap()
289+
.as_buffer::<u32>();
290+
debug_assert_eq!(
291+
size_of_val(&merkle_touched_memory[0]),
292+
MERKLE_TOUCHED_BLOCK_WIDTH * size_of::<u32>()
293+
);
294+
self.boundary
295+
.finalize_records_persistent::<DIGEST_WIDTH>(boundary_records);
206296
mem.tracing_info("merkle update");
207297
persistent.merkle_tree.finalize();
208298
let merkle_tree_ctx = persistent.merkle_tree.update_with_touched_blocks(
209299
unpadded_merkle_height,
210-
&d_touched_memory,
300+
&d_merkle_touched_memory,
211301
empty,
212302
);
213303
Some(merkle_tree_ctx)

0 commit comments

Comments
 (0)