Skip to content

Commit 3b3f502

Browse files
authored
feat(gpu_prover): coset and tree recomputations (#120)
## What ❔ This PR implements the re-computation capability for cosets and trees inside the proving workflow. ## Why ❔ Re-computing cosets and/or trees for prover stages, instead of keeping them resident, allows us to trade VRAM usage for performance. With this capability we are able to prove the reduced risc-v 2^23 circuit on GPUs with 24GB VRAM with a small performance penalty (estimated to be on the order of 10% on the L4 GPU). With this feature we can switch our recursion strategy to exclusive use of the 2^23 sized reduced risc-v circuit and therefore improve the reduction performance significantly and reduce recursion complexity. ## Is this a breaking change? - [ ] Yes - [x] No ## Checklist - [x] PR title corresponds to the body of PR (we generate changelog entries from PRs). - [x] Tests for the changes have been added / updated. - [x] Documentation comments have been added / updated. - [x] Code has been formatted.
1 parent b8446a1 commit 3b3f502

File tree

17 files changed

+956
-383
lines changed

17 files changed

+956
-383
lines changed

circuit_defs/prover_examples/src/gpu.rs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,8 +256,14 @@ pub fn gpu_prove_image_execution_for_machine_with_gpu_tracers<
256256
let log_domain_size = trace_len.trailing_zeros();
257257
let log_tree_cap_size =
258258
OPTIMAL_FOLDING_PROPERTIES[log_domain_size as usize].total_caps_size_log2 as u32;
259-
let mut setup =
260-
SetupPrecomputations::new(circuit, log_lde_factor, log_tree_cap_size, prover_context)?;
259+
let mut setup = SetupPrecomputations::new(
260+
circuit,
261+
log_lde_factor,
262+
log_tree_cap_size,
263+
false,
264+
false,
265+
prover_context,
266+
)?;
261267
setup.schedule_transfer(Arc::new(setup_evaluations), prover_context)?;
262268
setup
263269
};
@@ -302,6 +308,8 @@ pub fn gpu_prove_image_execution_for_machine_with_gpu_tracers<
302308
NUM_QUERIES,
303309
POW_BITS,
304310
None,
311+
false,
312+
false,
305313
prover_context,
306314
)?;
307315
job.finish()?
@@ -372,6 +380,8 @@ pub fn gpu_prove_image_execution_for_machine_with_gpu_tracers<
372380
circuit,
373381
log_lde_factor,
374382
log_tree_cap_size,
383+
false,
384+
false,
375385
prover_context,
376386
)?;
377387
setup.schedule_transfer(Arc::new(setup_evaluations), prover_context)?;
@@ -407,6 +417,8 @@ pub fn gpu_prove_image_execution_for_machine_with_gpu_tracers<
407417
NUM_QUERIES,
408418
POW_BITS,
409419
None,
420+
false,
421+
false,
410422
prover_context,
411423
)?;
412424
job.finish()?

gpu_prover/native/ntt/natural_evals_to_bitrev_Z.cu

Lines changed: 49 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ namespace airbender::ntt {
88
// register array accesses dynamic and cause spilling. But, bizarrely, it doesn't: it has the opposite effect and
99
// prevents spilling.
1010

11-
template <unsigned LOG_VALS_PER_THREAD, bool evals_are_coset>
11+
template <unsigned LOG_VALS_PER_THREAD, bool evals_are_coset, bool evals_are_compressed = false>
1212
DEVICE_FORCEINLINE void evals_to_Z_final_stages_warp(vectorized_e2_matrix_getter<ld_modifier::cg> gmem_in,
1313
vectorized_e2_matrix_setter<st_modifier::cg> gmem_out, const unsigned start_stage,
1414
const unsigned stages_this_launch, const unsigned log_n, const unsigned num_Z_cols,
@@ -100,8 +100,13 @@ DEVICE_FORCEINLINE void evals_to_Z_final_stages_warp(vectorized_e2_matrix_getter
100100
const unsigned mem_idx = gmem_offset + 64 * i + 2 * lane_id;
101101
const unsigned idx0 = bitrev(mem_idx, log_n);
102102
const unsigned idx1 = bitrev(mem_idx + 1, log_n);
103-
vals[2 * i] = lde_scale<true>(vals[2 * i], idx0, 1, 1, log_n);
104-
vals[2 * i + 1] = lde_scale<true>(vals[2 * i + 1], idx1, 1, 1, log_n);
103+
if (evals_are_compressed) {
104+
vals[2 * i] = lde_scale_and_shift<true>(vals[2 * i], idx0, 1, 1, log_n);
105+
vals[2 * i + 1] = lde_scale_and_shift<true>(vals[2 * i + 1], idx1, 1, 1, log_n);
106+
} else {
107+
vals[2 * i] = lde_scale<true>(vals[2 * i], idx0, 1, 1, log_n);
108+
vals[2 * i + 1] = lde_scale<true>(vals[2 * i + 1], idx1, 1, 1, log_n);
109+
}
105110
}
106111
}
107112

@@ -144,7 +149,23 @@ EXTERN __launch_bounds__(128, 8) __global__
144149
evals_to_Z_final_stages_warp<2, true>(gmem_in, gmem_out, start_stage, stages_this_launch, log_n, num_Z_cols, grid_offset);
145150
}
146151

147-
template <unsigned LOG_VALS_PER_THREAD, bool evals_are_coset>
152+
EXTERN __launch_bounds__(128, 8) __global__
153+
void ab_compressed_coset_evals_to_Z_final_8_stages_warp(vectorized_e2_matrix_getter<ld_modifier::cg> gmem_in,
154+
vectorized_e2_matrix_setter<st_modifier::cg> gmem_out, const unsigned start_stage,
155+
const unsigned stages_this_launch, const unsigned log_n, const unsigned num_Z_cols,
156+
const unsigned grid_offset) {
157+
evals_to_Z_final_stages_warp<3, true, true>(gmem_in, gmem_out, start_stage, stages_this_launch, log_n, num_Z_cols, grid_offset);
158+
}
159+
160+
EXTERN __launch_bounds__(128, 8) __global__
161+
void ab_compressed_coset_evals_to_Z_final_7_stages_warp(vectorized_e2_matrix_getter<ld_modifier::cg> gmem_in,
162+
vectorized_e2_matrix_setter<st_modifier::cg> gmem_out, const unsigned start_stage,
163+
const unsigned stages_this_launch, const unsigned log_n, const unsigned num_Z_cols,
164+
const unsigned grid_offset) {
165+
evals_to_Z_final_stages_warp<2, true, true>(gmem_in, gmem_out, start_stage, stages_this_launch, log_n, num_Z_cols, grid_offset);
166+
}
167+
168+
template <unsigned LOG_VALS_PER_THREAD, bool evals_are_coset, bool evals_are_compressed = false>
148169
DEVICE_FORCEINLINE void evals_to_Z_final_stages_block(vectorized_e2_matrix_getter<ld_modifier::cg> gmem_in,
149170
vectorized_e2_matrix_setter<st_modifier::cg> gmem_out, const unsigned start_stage,
150171
const unsigned stages_this_launch, const unsigned log_n, const unsigned num_Z_cols,
@@ -326,8 +347,13 @@ DEVICE_FORCEINLINE void evals_to_Z_final_stages_block(vectorized_e2_matrix_gette
326347
const unsigned mem_idx = gmem_offset + 64 * i + 2 * lane_id;
327348
const unsigned idx0 = bitrev(mem_idx, log_n);
328349
const unsigned idx1 = bitrev(mem_idx + 1, log_n);
329-
vals[2 * i] = lde_scale<true>(vals[2 * i], idx0, 1, 1, log_n);
330-
vals[2 * i + 1] = lde_scale<true>(vals[2 * i + 1], idx1, 1, 1, log_n);
350+
if (evals_are_compressed) {
351+
vals[2 * i] = lde_scale_and_shift<true>(vals[2 * i], idx0, 1, 1, log_n);
352+
vals[2 * i + 1] = lde_scale_and_shift<true>(vals[2 * i + 1], idx1, 1, 1, log_n);
353+
} else {
354+
vals[2 * i] = lde_scale<true>(vals[2 * i], idx0, 1, 1, log_n);
355+
vals[2 * i + 1] = lde_scale<true>(vals[2 * i + 1], idx1, 1, 1, log_n);
356+
}
331357
}
332358
}
333359

@@ -356,6 +382,14 @@ EXTERN __launch_bounds__(512, 2) __global__
356382
evals_to_Z_final_stages_block<3, true>(gmem_in, gmem_out, start_stage, stages_this_launch, log_n, num_Z_cols, grid_offset);
357383
}
358384

385+
EXTERN __launch_bounds__(512, 2) __global__
386+
void ab_compressed_coset_evals_to_Z_final_9_to_12_stages_block(vectorized_e2_matrix_getter<ld_modifier::cg> gmem_in,
387+
vectorized_e2_matrix_setter<st_modifier::cg> gmem_out, const unsigned start_stage,
388+
const unsigned stages_this_launch, const unsigned log_n, const unsigned num_Z_cols,
389+
const unsigned grid_offset) {
390+
evals_to_Z_final_stages_block<3, true, true>(gmem_in, gmem_out, start_stage, stages_this_launch, log_n, num_Z_cols, grid_offset);
391+
}
392+
359393
// This kernel basically reverses the pattern of the b2n_noninitial_stages_block kernel.
360394
template <unsigned LOG_VALS_PER_THREAD>
361395
DEVICE_FORCEINLINE void evals_to_Z_nonfinal_stages_block(vectorized_e2_matrix_getter<ld_modifier::cg> gmem_in,
@@ -559,7 +593,8 @@ EXTERN __launch_bounds__(512, 2) __global__
559593
// Simple, non-optimized kernel used for log_n < 16, to unblock debugging small proofs.
560594
EXTERN __launch_bounds__(512, 2) __global__
561595
void ab_evals_to_Z_one_stage(vectorized_e2_matrix_getter<ld_modifier::cg> gmem_in, vectorized_e2_matrix_setter<st_modifier::cg> gmem_out,
562-
const unsigned start_stage, const unsigned log_n, const unsigned blocks_per_ntt, const bool evals_are_coset) {
596+
const unsigned start_stage, const unsigned log_n, const unsigned blocks_per_ntt, const bool evals_are_coset,
597+
const bool evals_are_compressed) {
563598
const unsigned col_pair = blockIdx.x / blocks_per_ntt;
564599
const unsigned bid_in_ntt = blockIdx.x % blocks_per_ntt;
565600
const unsigned tid_in_ntt = threadIdx.x + bid_in_ntt * blockDim.x;
@@ -585,8 +620,13 @@ EXTERN __launch_bounds__(512, 2) __global__
585620
a = e2f::mul(a, ab_inv_sizes[log_n]);
586621
b = e2f::mul(b, ab_inv_sizes[log_n]);
587622
if (evals_are_coset) {
588-
a = lde_scale<true>(a, bitrev(a_idx, log_n), 1, 1, log_n);
589-
b = lde_scale<true>(b, bitrev(b_idx, log_n), 1, 1, log_n);
623+
if (evals_are_compressed) {
624+
a = lde_scale_and_shift<true>(a, bitrev(a_idx, log_n), 1, 1, log_n);
625+
b = lde_scale_and_shift<true>(b, bitrev(b_idx, log_n), 1, 1, log_n);
626+
} else {
627+
a = lde_scale<true>(a, bitrev(a_idx, log_n), 1, 1, log_n);
628+
b = lde_scale<true>(b, bitrev(b_idx, log_n), 1, 1, log_n);
629+
}
590630
}
591631
}
592632

gpu_prover/native/ntt/ntt.cuh

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,20 +151,22 @@ DEVICE_FORCEINLINE void load_noninitial_twiddles_warp(e2f *twiddle_cache, const
151151
}
152152

153153
// Assumes coset_idx > 0
154+
template <bool inverse = false>
154155
DEVICE_FORCEINLINE e2f get_lde_scale_and_shift_factor(const unsigned k, const unsigned log_extension_degree, const unsigned coset_idx, const unsigned log_n) {
155156
// following the notation of https://eprint.iacr.org/2023/824.pdf Section 4
156157
const unsigned tau_power_of_w = coset_idx << (CIRCLE_GROUP_LOG_ORDER - log_n - log_extension_degree);
157158
const unsigned H_over_two = 1u << (log_n - 1);
158159
const unsigned power_of_w = k >= H_over_two ? tau_power_of_w * (k - H_over_two) : (1u << CIRCLE_GROUP_LOG_ORDER) - tau_power_of_w * (H_over_two - k);
159-
return get_power_of_w(power_of_w, false);
160+
return get_power_of_w(power_of_w, inverse);
160161
}
161162

163+
template <bool inverse = false>
162164
DEVICE_FORCEINLINE e2f lde_scale_and_shift(const e2f Zk, const unsigned k, const unsigned log_extension_degree, const unsigned coset_idx,
163165
const unsigned log_n) {
164166
// Assumes the 0th coset is the main domain, as in zksync_airbender
165167
if (coset_idx == 0)
166168
return Zk;
167-
const auto gauged_shift_factor = get_lde_scale_and_shift_factor(k, log_extension_degree, coset_idx, log_n);
169+
const auto gauged_shift_factor = get_lde_scale_and_shift_factor<inverse>(k, log_extension_degree, coset_idx, log_n);
168170
return e2f::mul(Zk, gauged_shift_factor);
169171
}
170172

gpu_prover/src/execution/gpu_worker.rs

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use super::messages::WorkerResult;
22
use super::precomputations::CircuitPrecomputations;
33
use crate::allocator::host::ConcurrentStaticHostAllocator;
4-
use crate::circuit_type::CircuitType;
4+
use crate::circuit_type::{CircuitType, MainCircuitType};
55
use crate::cudart::device::set_device;
66
use crate::cudart::result::CudaResult;
77
use crate::prover::context::{ProverContext, ProverContextConfig};
@@ -121,6 +121,16 @@ const fn get_tree_cap_size(log_domain_size: u32) -> u32 {
121121
OPTIMAL_FOLDING_PROPERTIES[log_domain_size as usize].total_caps_size_log2 as u32
122122
}
123123

124+
fn get_recompute_trees(circuit_type: CircuitType, context: &ProverContext) -> bool {
125+
match circuit_type {
126+
CircuitType::Main(main) => match main {
127+
MainCircuitType::ReducedRiscVLog23Machine => (context.get_mem_size() >> 30) < 28, // less than 28GB
128+
_ => false,
129+
},
130+
_ => false,
131+
}
132+
}
133+
124134
#[derive(Clone)]
125135
struct SetupHolder<'a> {
126136
pub setup: Rc<RefCell<SetupPrecomputations<'a>>>,
@@ -164,10 +174,13 @@ fn gpu_worker(
164174
assert!(domain_size.is_power_of_two());
165175
let log_domain_size = domain_size.trailing_zeros();
166176
let log_tree_cap_size = get_tree_cap_size(log_domain_size);
177+
let recompute_trees = get_recompute_trees(circuit_type, &context);
167178
let mut setup = SetupPrecomputations::new(
168179
&precomputations.compiled_circuit,
169180
log_lde_factor,
170181
log_tree_cap_size,
182+
false,
183+
recompute_trees,
171184
&context,
172185
)?;
173186
match circuit_type {
@@ -185,7 +198,6 @@ fn gpu_worker(
185198
),
186199
}
187200
setup.ensure_commitment_produced(&context)?;
188-
setup.trace_holder.produce_tree_caps(&context)?;
189201
context.get_exec_stream().synchronize()?;
190202
if matches!(circuit_type, CircuitType::Main(_)) {
191203
let accessors = setup.trace_holder.get_tree_caps_accessors();
@@ -275,7 +287,8 @@ fn gpu_worker(
275287
}
276288
GpuWorkRequest::Proof(request) => {
277289
let batch_id = request.batch_id;
278-
match request.circuit_type {
290+
let circuit_type = request.circuit_type;
291+
match circuit_type {
279292
CircuitType::Main(main) => trace!(
280293
"BATCH[{batch_id}] GPU_WORKER[{device_id}] producing proof for main circuit {:?} chunk {}",
281294
main,
@@ -309,14 +322,15 @@ fn gpu_worker(
309322
aux_boundary_values,
310323
};
311324
let setup = setup.unwrap();
312-
let circuit_sequence = match request.circuit_type {
325+
let circuit_sequence = match circuit_type {
313326
CircuitType::Main(_) => request.circuit_sequence,
314327
CircuitType::Delegation(_) => 0,
315328
};
316-
let delegation_processing_type = match request.circuit_type {
329+
let delegation_processing_type = match circuit_type {
317330
CircuitType::Main(_) => None,
318331
CircuitType::Delegation(delegation) => Some(delegation as u16),
319332
};
333+
let recompute_trees = get_recompute_trees(circuit_type, &context);
320334
let job = prove(
321335
precomputations.compiled_circuit.clone(),
322336
external_values,
@@ -329,6 +343,8 @@ fn gpu_worker(
329343
NUM_QUERIES,
330344
POW_BITS,
331345
None,
346+
false,
347+
recompute_trees,
332348
&context,
333349
)?;
334350
JobType::Proof(job)

gpu_prover/src/ntt/mod.rs

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,7 @@ cuda_kernel!(
240240
log_n: u32,
241241
blocks_per_ntt: u32,
242242
evals_are_coset: bool,
243+
evals_are_compressed: bool,
243244
);
244245

245246
one_stage_kernel!(ab_evals_to_Z_one_stage);
@@ -263,6 +264,9 @@ n2b_multi_stage_kernel!(ab_main_domain_evals_to_Z_final_9_to_12_stages_block);
263264
n2b_multi_stage_kernel!(ab_coset_evals_to_Z_final_7_stages_warp);
264265
n2b_multi_stage_kernel!(ab_coset_evals_to_Z_final_8_stages_warp);
265266
n2b_multi_stage_kernel!(ab_coset_evals_to_Z_final_9_to_12_stages_block);
267+
n2b_multi_stage_kernel!(ab_compressed_coset_evals_to_Z_final_7_stages_warp);
268+
n2b_multi_stage_kernel!(ab_compressed_coset_evals_to_Z_final_8_stages_warp);
269+
n2b_multi_stage_kernel!(ab_compressed_coset_evals_to_Z_final_9_to_12_stages_block);
266270

267271
#[allow(clippy::too_many_arguments)]
268272
fn natural_evals_to_bitrev_Z(
@@ -271,6 +275,7 @@ fn natural_evals_to_bitrev_Z(
271275
log_n: usize,
272276
num_bf_cols: usize,
273277
evals_are_coset: bool,
278+
evals_are_compressed: bool,
274279
stream: &CudaStream,
275280
) -> CudaResult<()> {
276281
assert!(log_n >= 1);
@@ -282,6 +287,9 @@ fn natural_evals_to_bitrev_Z(
282287
assert_eq!(inputs_matrix.cols(), num_bf_cols);
283288
assert_eq!(outputs_matrix.rows(), n);
284289
assert_eq!(outputs_matrix.cols(), num_bf_cols);
290+
if !evals_are_coset {
291+
assert!(!evals_are_compressed);
292+
}
285293

286294
let inputs_matrix = inputs_matrix.as_ptr_and_stride();
287295
let outputs_matrix_const = outputs_matrix.as_ptr_and_stride();
@@ -302,6 +310,7 @@ fn natural_evals_to_bitrev_Z(
302310
log_n as u32,
303311
blocks_per_ntt as u32,
304312
evals_are_coset,
313+
evals_are_compressed,
305314
);
306315
kernel_function.launch(&config, &args)?;
307316
for stage in 1..log_n {
@@ -312,6 +321,7 @@ fn natural_evals_to_bitrev_Z(
312321
log_n as u32,
313322
blocks_per_ntt as u32,
314323
evals_are_coset,
324+
evals_are_compressed,
315325
);
316326
kernel_function.launch(&config, &args)?;
317327
}
@@ -330,7 +340,11 @@ fn natural_evals_to_bitrev_Z(
330340
match kern {
331341
FINAL_7_WARP => (
332342
if evals_are_coset {
333-
ab_coset_evals_to_Z_final_7_stages_warp
343+
if evals_are_compressed {
344+
ab_compressed_coset_evals_to_Z_final_7_stages_warp
345+
} else {
346+
ab_coset_evals_to_Z_final_7_stages_warp
347+
}
334348
} else {
335349
ab_main_domain_evals_to_Z_final_7_stages_warp
336350
},
@@ -339,7 +353,11 @@ fn natural_evals_to_bitrev_Z(
339353
),
340354
FINAL_8_WARP => (
341355
if evals_are_coset {
342-
ab_coset_evals_to_Z_final_8_stages_warp
356+
if evals_are_compressed {
357+
ab_compressed_coset_evals_to_Z_final_8_stages_warp
358+
} else {
359+
ab_coset_evals_to_Z_final_8_stages_warp
360+
}
343361
} else {
344362
ab_main_domain_evals_to_Z_final_8_stages_warp
345363
},
@@ -348,7 +366,11 @@ fn natural_evals_to_bitrev_Z(
348366
),
349367
FINAL_9_TO_12_BLOCK => (
350368
if evals_are_coset {
351-
ab_coset_evals_to_Z_final_9_to_12_stages_block
369+
if evals_are_compressed {
370+
ab_compressed_coset_evals_to_Z_final_9_to_12_stages_block
371+
} else {
372+
ab_coset_evals_to_Z_final_9_to_12_stages_block
373+
}
352374
} else {
353375
ab_main_domain_evals_to_Z_final_9_to_12_stages_block
354376
},
@@ -403,6 +425,7 @@ pub fn natural_trace_main_evals_to_bitrev_Z(
403425
log_n,
404426
num_bf_cols,
405427
false,
428+
false,
406429
stream,
407430
)
408431
}
@@ -421,6 +444,26 @@ pub fn natural_composition_coset_evals_to_bitrev_Z(
421444
log_n,
422445
num_bf_cols,
423446
true,
447+
false,
448+
stream,
449+
)
450+
}
451+
452+
#[allow(clippy::too_many_arguments)]
453+
pub fn natural_compressed_coset_evals_to_bitrev_Z(
454+
inputs_matrix: &(impl DeviceMatrixChunkImpl<BF> + ?Sized),
455+
outputs_matrix: &mut (impl DeviceMatrixChunkMutImpl<BF> + ?Sized),
456+
log_n: usize,
457+
num_bf_cols: usize,
458+
stream: &CudaStream,
459+
) -> CudaResult<()> {
460+
natural_evals_to_bitrev_Z(
461+
inputs_matrix,
462+
outputs_matrix,
463+
log_n,
464+
num_bf_cols,
465+
true,
466+
true,
424467
stream,
425468
)
426469
}

0 commit comments

Comments
 (0)