Skip to content

Commit 849e2ac

Browse files
chore(cuda): update poseidon2 kernel to use memory manager (#2108)
Co-authored-by: stephenh-axiom-xyz <[email protected]>
1 parent 80dc2ba commit 849e2ac

File tree

3 files changed

+93
-33
lines changed

3 files changed

+93
-33
lines changed

crates/vm/cuda/src/system/poseidon2.cu

Lines changed: 32 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -74,34 +74,30 @@ extern "C" int _system_poseidon2_tracegen(
7474
return cudaGetLastError();
7575
}
7676

77-
// Reduces the records, removing duplicates and storing the number of times
78-
// each occurs in d_counts. The number of records after reduction is stored
79-
// into host pointer num_records.
80-
extern "C" int _system_poseidon2_deduplicate_records(
77+
// Prepares d_num_records for use with sort reduce and stores the temporary buffer
78+
// size necessary for both cub functions (i.e. sort and reduce).
79+
extern "C" int _system_poseidon2_deduplicate_records_get_temp_bytes(
8180
Fp *d_records,
8281
uint32_t *d_counts,
83-
size_t *num_records
82+
size_t num_records,
83+
size_t *d_num_records,
84+
size_t *h_temp_bytes_out
8485
) {
85-
auto [grid, block] = kernel_launch_params(*num_records);
86+
auto [grid, block] = kernel_launch_params(num_records);
8687
FpArray<16> *d_records_fp16 = reinterpret_cast<FpArray<16> *>(d_records);
87-
size_t *d_num_records;
8888

8989
// We want to sort and reduce the raw records, keeping track of how many
90-
// each occurs in d_counts. To prepare for reduce we need to a) allocate
91-
// d_num_records, b) fill d_counts with 1s, and c) group keys together
92-
// using sort.
93-
cudaMallocAsync(&d_num_records, sizeof(size_t), cudaStreamPerThread);
94-
cudaMemcpyAsync(
95-
d_num_records, num_records, sizeof(size_t), cudaMemcpyHostToDevice, cudaStreamPerThread
96-
);
97-
fill_buffer<uint32_t><<<grid, block, 0, cudaStreamPerThread>>>(d_counts, 1, *num_records);
90+
// each occurs in d_counts. To prepare for reduce we need to a) fill
91+
// d_counts with 1s, and b) group keys together using sort. Note we do
92+
// b) in the kernel below.
93+
fill_buffer<uint32_t><<<grid, block>>>(d_counts, 1, num_records);
9894

9995
size_t sort_storage_bytes = 0;
10096
cub::DeviceMergeSort::SortKeys(
10197
nullptr,
10298
sort_storage_bytes,
10399
d_records_fp16,
104-
*num_records,
100+
num_records,
105101
Fp16CompareOp(),
106102
cudaStreamPerThread
107103
);
@@ -116,13 +112,27 @@ extern "C" int _system_poseidon2_deduplicate_records(
116112
d_counts,
117113
d_num_records,
118114
std::plus(),
119-
*num_records,
115+
num_records,
120116
cudaStreamPerThread
121117
);
122118

123-
size_t temp_storage_bytes = std::max(sort_storage_bytes, reduce_storage_bytes);
124-
void *d_temp_storage = nullptr;
125-
cudaMallocAsync(&d_temp_storage, temp_storage_bytes, cudaStreamPerThread);
119+
*h_temp_bytes_out = std::max(sort_storage_bytes, reduce_storage_bytes);
120+
return cudaGetLastError();
121+
}
122+
123+
// Reduces the records, removing duplicates and storing the number of times
124+
// each occurs in d_counts. The number of records after reduction is stored
125+
// into host pointer num_records. The value of temp_storage_bytes should be
126+
// computed using _system_poseidon2_deduplicate_records_get_temp_bytes.
127+
extern "C" int _system_poseidon2_deduplicate_records(
128+
Fp *d_records,
129+
uint32_t *d_counts,
130+
size_t num_records,
131+
size_t *d_num_records,
132+
void *d_temp_storage,
133+
size_t temp_storage_bytes
134+
) {
135+
FpArray<16> *d_records_fp16 = reinterpret_cast<FpArray<16> *>(d_records);
126136

127137
// TODO: We currently can't use DeviceRadixSort since each key is 64 bytes
128138
// which causes Fp16Decomposer usage to exceed shared memory. We need to
@@ -131,7 +141,7 @@ extern "C" int _system_poseidon2_deduplicate_records(
131141
d_temp_storage,
132142
temp_storage_bytes,
133143
d_records_fp16,
134-
*num_records,
144+
num_records,
135145
Fp16CompareOp(),
136146
cudaStreamPerThread
137147
);
@@ -148,14 +158,9 @@ extern "C" int _system_poseidon2_deduplicate_records(
148158
d_counts,
149159
d_num_records,
150160
std::plus(),
151-
*num_records,
161+
num_records,
152162
cudaStreamPerThread
153163
);
154164

155-
cudaMemcpyAsync(
156-
num_records, d_num_records, sizeof(size_t), cudaMemcpyDeviceToHost, cudaStreamPerThread
157-
);
158-
cudaFreeAsync(d_num_records, cudaStreamPerThread);
159-
cudaFreeAsync(d_temp_storage, cudaStreamPerThread);
160165
return cudaGetLastError();
161166
}

crates/vm/src/cuda_abi.rs

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,10 +127,21 @@ pub mod poseidon2 {
127127
sbox_regs: usize,
128128
) -> i32;
129129

130+
fn _system_poseidon2_deduplicate_records_get_temp_bytes(
131+
d_records: *mut F,
132+
d_counts: *mut u32,
133+
num_records: usize,
134+
d_num_records: *mut usize,
135+
h_temp_bytes_out: *mut usize,
136+
) -> i32;
137+
130138
fn _system_poseidon2_deduplicate_records(
131139
d_records: *mut F,
132140
d_counts: *mut u32,
133-
num_records: *mut usize,
141+
num_records: usize,
142+
d_num_records: *mut usize,
143+
d_temp_storage: *mut std::ffi::c_void,
144+
temp_storage_bytes: usize,
134145
) -> i32;
135146
}
136147

@@ -154,15 +165,37 @@ pub mod poseidon2 {
154165
))
155166
}
156167

168+
pub unsafe fn deduplicate_records_get_temp_bytes(
169+
d_records: &DeviceBuffer<F>,
170+
d_counts: &DeviceBuffer<u32>,
171+
num_records: usize,
172+
d_num_records: &DeviceBuffer<usize>,
173+
h_temp_bytes_out: &mut usize,
174+
) -> Result<(), CudaError> {
175+
CudaError::from_result(_system_poseidon2_deduplicate_records_get_temp_bytes(
176+
d_records.as_mut_ptr(),
177+
d_counts.as_mut_ptr(),
178+
num_records,
179+
d_num_records.as_mut_ptr(),
180+
h_temp_bytes_out,
181+
))
182+
}
183+
157184
pub unsafe fn deduplicate_records(
158185
d_records: &DeviceBuffer<F>,
159186
d_counts: &DeviceBuffer<u32>,
160-
num_records: &mut usize,
187+
num_records: usize,
188+
d_num_records: &DeviceBuffer<usize>,
189+
d_temp_storage: &DeviceBuffer<u8>,
190+
temp_storage_bytes: usize,
161191
) -> Result<(), CudaError> {
162192
CudaError::from_result(_system_poseidon2_deduplicate_records(
163193
d_records.as_mut_ptr(),
164194
d_counts.as_mut_ptr(),
165-
num_records as *mut usize,
195+
num_records,
196+
d_num_records.as_mut_ptr(),
197+
d_temp_storage.as_mut_raw_ptr(),
198+
temp_storage_bytes,
166199
))
167200
}
168201
}

crates/vm/src/system/cuda/poseidon2.rs

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@ use openvm_circuit::{
66
system::poseidon2::columns::Poseidon2PeripheryCols, utils::next_power_of_two_or_zero,
77
};
88
use openvm_cuda_backend::{base::DeviceMatrix, prelude::F, prover_backend::GpuBackend};
9-
use openvm_cuda_common::{copy::MemCopyD2H, d_buffer::DeviceBuffer};
9+
use openvm_cuda_common::{
10+
copy::{MemCopyD2H, MemCopyH2D},
11+
d_buffer::DeviceBuffer,
12+
};
1013
use openvm_stark_backend::{
1114
prover::{hal::MatrixDimensions, types::AirProvingContext},
1215
Chip,
@@ -60,8 +63,27 @@ impl<RA, const SBOX_REGISTERS: usize> Chip<RA, GpuBackend> for Poseidon2ChipGPU<
6063
let mut num_records = self.idx.to_host().unwrap()[0] as usize;
6164
let counts = DeviceBuffer::<u32>::with_capacity(num_records);
6265
unsafe {
63-
poseidon2::deduplicate_records(&self.records, &counts, &mut num_records)
64-
.expect("Failed to deduplicate records");
66+
let d_num_records = [num_records].to_device().unwrap();
67+
let mut temp_bytes = 0;
68+
poseidon2::deduplicate_records_get_temp_bytes(
69+
&self.records,
70+
&counts,
71+
num_records,
72+
&d_num_records,
73+
&mut temp_bytes,
74+
)
75+
.expect("Failed to get temp bytes");
76+
let d_temp_storage = DeviceBuffer::<u8>::with_capacity(temp_bytes);
77+
poseidon2::deduplicate_records(
78+
&self.records,
79+
&counts,
80+
num_records,
81+
&d_num_records,
82+
&d_temp_storage,
83+
temp_bytes,
84+
)
85+
.expect("Failed to deduplicate records");
86+
num_records = *d_num_records.to_host().unwrap().first().unwrap();
6587
}
6688
#[cfg(feature = "metrics")]
6789
self.current_trace_height

0 commit comments

Comments
 (0)