@@ -74,34 +74,30 @@ extern "C" int _system_poseidon2_tracegen(
74
74
return cudaGetLastError ();
75
75
}
76
76
77
- // Reduces the records, removing duplicates and storing the number of times
78
- // each occurs in d_counts. The number of records after reduction is stored
79
- // into host pointer num_records.
80
- extern " C" int _system_poseidon2_deduplicate_records (
77
+ // Prepares d_counts for use with sort and reduce, and stores the temporary buffer
78
+ // size necessary for both cub functions (i.e. sort and reduce).
79
+ extern " C" int _system_poseidon2_deduplicate_records_get_temp_bytes (
81
80
Fp *d_records,
82
81
uint32_t *d_counts,
83
- size_t *num_records
82
+ size_t num_records,
83
+ size_t *d_num_records,
84
+ size_t *h_temp_bytes_out
84
85
) {
85
- auto [grid, block] = kernel_launch_params (* num_records);
86
+ auto [grid, block] = kernel_launch_params (num_records);
86
87
FpArray<16 > *d_records_fp16 = reinterpret_cast <FpArray<16 > *>(d_records);
87
- size_t *d_num_records;
88
88
89
89
// We want to sort and reduce the raw records, keeping track of how many
90
- // each occurs in d_counts. To prepare for reduce we need to a) allocate
91
- // d_num_records, b) fill d_counts with 1s, and c) group keys together
92
- // using sort.
93
- cudaMallocAsync (&d_num_records, sizeof (size_t ), cudaStreamPerThread);
94
- cudaMemcpyAsync (
95
- d_num_records, num_records, sizeof (size_t ), cudaMemcpyHostToDevice, cudaStreamPerThread
96
- );
97
- fill_buffer<uint32_t ><<<grid, block, 0 , cudaStreamPerThread>>> (d_counts, 1 , *num_records);
90
+ // each occurs in d_counts. To prepare for reduce we need to a) fill
91
+ // d_counts with 1s, and b) group keys together using sort. Note we do
92
+ // b) in the kernel below.
93
+ fill_buffer<uint32_t ><<<grid, block>>> (d_counts, 1 , num_records);
98
94
99
95
size_t sort_storage_bytes = 0 ;
100
96
cub::DeviceMergeSort::SortKeys (
101
97
nullptr ,
102
98
sort_storage_bytes,
103
99
d_records_fp16,
104
- * num_records,
100
+ num_records,
105
101
Fp16CompareOp (),
106
102
cudaStreamPerThread
107
103
);
@@ -116,13 +112,27 @@ extern "C" int _system_poseidon2_deduplicate_records(
116
112
d_counts,
117
113
d_num_records,
118
114
std::plus (),
119
- * num_records,
115
+ num_records,
120
116
cudaStreamPerThread
121
117
);
122
118
123
- size_t temp_storage_bytes = std::max (sort_storage_bytes, reduce_storage_bytes);
124
- void *d_temp_storage = nullptr ;
125
- cudaMallocAsync (&d_temp_storage, temp_storage_bytes, cudaStreamPerThread);
119
+ *h_temp_bytes_out = std::max (sort_storage_bytes, reduce_storage_bytes);
120
+ return cudaGetLastError ();
121
+ }
122
+
123
+ // Reduces the records, removing duplicates and storing the number of times
124
+ // each occurs in d_counts. The number of records after reduction is stored
125
+ // into device pointer d_num_records. The value of temp_storage_bytes should be
126
+ // computed using _system_poseidon2_deduplicate_records_get_temp_bytes.
127
+ extern " C" int _system_poseidon2_deduplicate_records (
128
+ Fp *d_records,
129
+ uint32_t *d_counts,
130
+ size_t num_records,
131
+ size_t *d_num_records,
132
+ void *d_temp_storage,
133
+ size_t temp_storage_bytes
134
+ ) {
135
+ FpArray<16 > *d_records_fp16 = reinterpret_cast <FpArray<16 > *>(d_records);
126
136
127
137
// TODO: We currently can't use DeviceRadixSort since each key is 64 bytes
128
138
// which causes Fp16Decomposer usage to exceed shared memory. We need to
@@ -131,7 +141,7 @@ extern "C" int _system_poseidon2_deduplicate_records(
131
141
d_temp_storage,
132
142
temp_storage_bytes,
133
143
d_records_fp16,
134
- * num_records,
144
+ num_records,
135
145
Fp16CompareOp (),
136
146
cudaStreamPerThread
137
147
);
@@ -148,14 +158,9 @@ extern "C" int _system_poseidon2_deduplicate_records(
148
158
d_counts,
149
159
d_num_records,
150
160
std::plus (),
151
- * num_records,
161
+ num_records,
152
162
cudaStreamPerThread
153
163
);
154
164
155
- cudaMemcpyAsync (
156
- num_records, d_num_records, sizeof (size_t ), cudaMemcpyDeviceToHost, cudaStreamPerThread
157
- );
158
- cudaFreeAsync (d_num_records, cudaStreamPerThread);
159
- cudaFreeAsync (d_temp_storage, cudaStreamPerThread);
160
165
return cudaGetLastError ();
161
166
}
0 commit comments