Skip to content

Commit 8c8f3b7

Browse files
gaxiomjonathanpwang
andcommitted
feat(cuda): logup zerocheck (round 0) (#156)
Closes INT-5182 - [x] GPU state scaffolding (with CPU fallback) - [x] Interaction metadata plumbing - [x] Fractional sumcheck kernel - [x] Interaction evaluation kernel - [x] round0 impl (iDFT > pad > DFT > eval DAG) - [x] fold_ple - [ ] sumcheck_polys_eval - [ ] fold_mle - [ ] column openings - [x] small heights support --------- Co-authored-by: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com>
1 parent 9dc7f20 commit 8c8f3b7

File tree

9 files changed

+59
-58
lines changed

9 files changed

+59
-58
lines changed

crates/stark-backend-v2/src/prover/cpu_backend.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,16 @@ impl<TS: FiatShamirTranscript> MultiRapProver<CpuBackendV2, TS> for CpuDeviceV2
7171
transcript: &mut TS,
7272
mpk: &DeviceMultiStarkProvingKeyV2<CpuBackendV2>,
7373
ctx: ProvingContextV2<CpuBackendV2>,
74+
common_main_pcs_data: &StackedPcsData<F, Digest>,
7475
) -> ((GkrProof, BatchConstraintProof), Vec<EF>) {
7576
let (gkr_proof, batch_constraint_proof, r) =
76-
prove_zerocheck_and_logup::<_, _, TS, LogupZerocheckCpu>(self, transcript, mpk, ctx);
77+
prove_zerocheck_and_logup::<_, _, TS, LogupZerocheckCpu>(
78+
self,
79+
transcript,
80+
mpk,
81+
ctx,
82+
common_main_pcs_data,
83+
);
7784
((gkr_proof, batch_constraint_proof), r)
7885
}
7986
}

crates/stark-backend-v2/src/prover/hal.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ pub trait MultiRapProver<PB: ProverBackendV2, TS> {
8383
transcript: &mut TS,
8484
mpk: &DeviceMultiStarkProvingKeyV2<PB>,
8585
ctx: ProvingContextV2<PB>,
86+
common_main_pcs_data: &PB::PcsData,
8687
) -> (Self::PartialProof, Self::Artifacts);
8788
}
8889

crates/stark-backend-v2/src/prover/logup_zerocheck/cpu.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ use crate::{
2323
poseidon2::sponge::FiatShamirTranscript,
2424
prover::{
2525
ColMajorMatrix, CpuBackendV2, CpuDeviceV2, DeviceMultiStarkProvingKeyV2,
26-
LogupZerocheckProver, MatrixView, ProvingContextV2,
26+
LogupZerocheckProver, MatrixView, ProverBackendV2, ProvingContextV2,
2727
fractional_sumcheck_gkr::{Frac, FracSumcheckProof, fractional_sumcheck},
2828
logup_zerocheck::EvalHelper,
2929
poly::evals_eq_hypercube,
@@ -58,14 +58,14 @@ pub struct LogupZerocheckCpu<'a> {
5858
pub xi: Vec<EF>,
5959
lambda_pows: Vec<EF>,
6060

61-
eq_xi_per_trace: Vec<ColMajorMatrix<EF>>,
62-
eq_sharp_per_trace: Vec<ColMajorMatrix<EF>>,
61+
pub eq_xi_per_trace: Vec<ColMajorMatrix<EF>>,
62+
pub eq_sharp_per_trace: Vec<ColMajorMatrix<EF>>,
6363
eq_3b_per_trace: Vec<Vec<EF>>,
6464
// TODO[jpw]: delete these
6565
sels_per_trace_base: Vec<ColMajorMatrix<F>>,
6666
// After univariate round 0:
67-
mat_evals_per_trace: Vec<Vec<ColMajorMatrix<EF>>>,
68-
sels_per_trace: Vec<ColMajorMatrix<EF>>,
67+
pub mat_evals_per_trace: Vec<Vec<ColMajorMatrix<EF>>>,
68+
pub sels_per_trace: Vec<ColMajorMatrix<EF>>,
6969
// Stores \hat{f}(\vec r_n) * r_{n+1} .. r_{round-1} for polys f that are "done" in the batch
7070
// sumcheck
7171
zerocheck_tilde_evals: Vec<EF>,
@@ -81,6 +81,7 @@ where
8181
transcript: &mut TS,
8282
pk: &'a DeviceMultiStarkProvingKeyV2<CpuBackendV2>,
8383
ctx: &ProvingContextV2<CpuBackendV2>,
84+
_common_main_data: &'a <CpuBackendV2 as ProverBackendV2>::PcsData,
8485
n_logup: usize,
8586
interactions_layout: StackedLayout,
8687
alpha_logup: EF,
@@ -374,6 +375,7 @@ where
374375
ColMajorMatrix::new(mat, 3)
375376
})
376377
.collect_vec();
378+
377379
// PERF[jpw]: see Gruen, Section 3.2 and 4 on some ways to reduce the degree of the
378380
// univariate polynomial. We know s_0 is supposed to vanish on univariate skip
379381
// domain `D` of size `2^{l_skip}`. Hence `s_0 = Z_D * s'_0(Z)` where `Z_D =

crates/stark-backend-v2/src/prover/logup_zerocheck/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ pub trait LogupZerocheckProver<'a, PB: ProverBackendV2, PD, TS>: Sized {
4949
transcript: &mut TS,
5050
pk: &'a DeviceMultiStarkProvingKeyV2<PB>,
5151
ctx: &ProvingContextV2<PB>,
52+
common_main_pcs_data: &'a PB::PcsData,
5253
n_logup: usize,
5354
interactions_layout: StackedLayout,
5455
alpha_logup: PB::Challenge,
@@ -91,6 +92,7 @@ pub fn prove_zerocheck_and_logup<'a, PB, PD, TS, LZP>(
9192
transcript: &mut TS,
9293
mpk: &'a DeviceMultiStarkProvingKeyV2<PB>,
9394
ctx: ProvingContextV2<PB>,
95+
common_main_pcs_data: &'a PB::PcsData,
9496
) -> (GkrProof, BatchConstraintProof, Vec<PB::Challenge>)
9597
where
9698
PB: ProverBackendV2<Val = F, Challenge = EF>,
@@ -139,6 +141,7 @@ where
139141
transcript,
140142
mpk,
141143
&ctx,
144+
common_main_pcs_data,
142145
n_logup,
143146
interactions_layout,
144147
alpha_logup,

crates/stark-backend-v2/src/prover/logup_zerocheck/single.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use crate::prover::{
1515
};
1616

1717
/// For a single AIR
18-
pub(super) struct EvalHelper<'a, F> {
18+
pub struct EvalHelper<'a, F> {
1919
/// AIR constraints
2020
pub constraints_dag: &'a SymbolicExpressionDag<F>,
2121
/// Interactions

crates/stark-backend-v2/src/prover/mod.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,9 @@ where
146146
})
147147
.collect();
148148

149-
let (constraints_proof, r) = self.device.prove_rap_constraints(transcript, mpk, ctx);
149+
let (constraints_proof, r) =
150+
self.device
151+
.prove_rap_constraints(transcript, mpk, ctx, &common_main_pcs_data);
150152

151153
let opening_proof = self.device.prove_openings(
152154
transcript,

crates/stark-backend-v2/src/prover/stacked_pcs.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,11 @@ use crate::{
1313
prover::{ColMajorMatrix, MatrixView, StridedColMajorMatrixView, col_maj_idx, poly::Ple},
1414
};
1515

16-
#[derive(Clone, Serialize, Deserialize, Debug)]
16+
#[derive(Clone, Serialize, Deserialize, Debug, CopyGetters)]
1717
pub struct StackedLayout {
1818
/// The minimum log2 height of a stacked slice. When stacking columns with smaller height, the
1919
/// column is expanded to `2^l_skip` by striding.
20+
#[getset(get_copy = "pub")]
2021
l_skip: usize,
2122
/// The columns of the unstacked matrices in sorted order. Each entry `(matrix index, column
2223
/// index, coordinate)` contains the pointer `(matrix index, column index)` to a column of the

crates/stark-backend-v2/src/test_utils.rs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ use crate::{
2727
proof::Proof,
2828
prover::{
2929
AirProvingContextV2, ColMajorMatrix, CommittedTraceDataV2, CpuBackendV2,
30-
DeviceDataTransporterV2, ProverBackendV2, ProvingContextV2, stacked_pcs::stacked_commit,
30+
DeviceDataTransporterV2, DeviceMultiStarkProvingKeyV2, MultiRapProver, ProverBackendV2,
31+
ProvingContextV2, TraceCommitterV2, stacked_pcs::stacked_commit,
3132
},
3233
};
3334

@@ -66,6 +67,25 @@ where
6667
ProvingContextV2::new(per_trace)
6768
}
6869

70+
pub fn prove_up_to_batch_constraints<E: StarkEngineV2>(
71+
engine: &E,
72+
transcript: &mut E::TS,
73+
pk: &DeviceMultiStarkProvingKeyV2<E::PB>,
74+
ctx: ProvingContextV2<E::PB>,
75+
) -> (
76+
<E::PD as MultiRapProver<E::PB, E::TS>>::PartialProof,
77+
<E::PD as MultiRapProver<E::PB, E::TS>>::Artifacts,
78+
) {
79+
let (_, common_main_pcs_data) = engine.device().commit(
80+
&ctx.common_main_traces()
81+
.map(|(_, trace)| trace)
82+
.collect_vec(),
83+
);
84+
engine
85+
.device()
86+
.prove_rap_constraints(transcript, pk, ctx, &common_main_pcs_data)
87+
}
88+
6989
fn get_fib_number(mut a: u32, mut b: u32, n: usize) -> u32 {
7090
for _ in 0..n - 1 {
7191
let c = (a + b) % BabyBear::ORDER_U32;

crates/stark-backend-v2/src/tests.rs

Lines changed: 13 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use crate::{
2020
},
2121
test_utils::{
2222
CachedFixture11, FibFixture, InteractionsFixture11, PreprocessedFibFixture, TestFixture,
23-
test_engine_small,
23+
prove_up_to_batch_constraints, test_engine_small,
2424
},
2525
verifier::{
2626
batch_constraints::{BatchConstraintError, verify_zerocheck_and_logup},
@@ -159,9 +159,8 @@ fn test_batch_sumcheck_zero_interactions(
159159

160160
let pvs = vec![ctx.per_trace[0].1.public_values.clone()];
161161
let ((gkr_proof, batch_proof), _) =
162-
engine
163-
.device()
164-
.prove_rap_constraints(&mut prover_sponge, &pk, ctx);
162+
prove_up_to_batch_constraints(&engine, &mut prover_sponge, &pk, ctx);
163+
165164
let r = verify_zerocheck_and_logup(
166165
&mut verifier_sponge,
167166
&vk.inner,
@@ -211,8 +210,12 @@ fn test_stacked_opening_reduction(log_trace_degree: usize) -> Result<(), Stacked
211210

212211
let device = engine.device();
213212
// We need batch_proof to obtain the column openings
214-
let ((_, batch_proof), r) =
215-
device.prove_rap_constraints(&mut DuplexSponge::default(), &pk, ctx);
213+
let ((_, batch_proof), r) = device.prove_rap_constraints(
214+
&mut DuplexSponge::default(),
215+
&pk,
216+
ctx,
217+
&common_main_pcs_data,
218+
);
216219

217220
let (stacking_proof, _) = prove_stacked_opening_reduction::<_, _, _, StackedReductionCpu>(
218221
device,
@@ -237,7 +240,6 @@ fn test_stacked_opening_reduction(log_trace_degree: usize) -> Result<(), Stacked
237240
assert_eq!(u_prism.len(), params.n_stack + 1);
238241
Ok(())
239242
}
240-
241243
#[test_case(3)]
242244
#[test_case(2 ; "when fib log_height equals l_skip")]
243245
#[test_case(1 ; "when fib log_height less than l_skip")]
@@ -299,41 +301,8 @@ fn test_single_fib_and_dummy_trace_stark(log_trace_degree: usize) {
299301
per_trace.push((per_trace.len(), fib_ctx));
300302
let combined_ctx = ProvingContextV2::new(per_trace).into_sorted();
301303

302-
let l_skip = engine.config().l_skip;
303-
let mut pvs = vec![vec![]; 3];
304-
let (trace_id_to_air_ids, ns): (Vec<_>, Vec<_>) = combined_ctx
305-
.per_trace
306-
.iter()
307-
.map(|(air_idx, air_ctx)| {
308-
pvs[*air_idx] = air_ctx.public_values.clone();
309-
(
310-
*air_idx,
311-
log2_strict_usize(air_ctx.common_main.height()) as isize - l_skip as isize,
312-
)
313-
})
314-
.multiunzip();
315-
let omega_pows = F::two_adic_generator(l_skip)
316-
.powers()
317-
.take(1 << l_skip)
318-
.collect_vec();
319-
320-
let mut transcript = DuplexSponge::default();
321-
let ((gkr_proof, batch_proof), _) =
322-
engine
323-
.device()
324-
.prove_rap_constraints(&mut transcript, &combined_pk, combined_ctx);
325-
let mut transcript = DuplexSponge::default();
326-
verify_zerocheck_and_logup(
327-
&mut transcript,
328-
&combined_pk.get_vk().inner,
329-
&pvs,
330-
&gkr_proof,
331-
&batch_proof,
332-
&trace_id_to_air_ids,
333-
&ns,
334-
&omega_pows,
335-
)
336-
.unwrap();
304+
let proof = engine.prove(&combined_pk, combined_ctx);
305+
engine.verify(&combined_pk.get_vk(), &proof).unwrap();
337306
}
338307

339308
#[test]
@@ -380,9 +349,7 @@ fn test_gkr_verify_zero_interactions() -> eyre::Result<()> {
380349
let pk = engine.device().transport_pk_to_device(&pk);
381350
let ctx = fx.generate_proving_ctx().into_sorted();
382351
let mut transcript = DuplexSponge::default();
383-
let ((gkr_proof, _), _) = engine
384-
.device()
385-
.prove_rap_constraints(&mut transcript, &pk, ctx);
352+
let ((gkr_proof, _), _) = prove_up_to_batch_constraints(&engine, &mut transcript, &pk, ctx);
386353

387354
let mut transcript = DuplexSponge::default();
388355
assert!(transcript.check_witness(params.logup_pow_bits, gkr_proof.logup_pow_witness));
@@ -425,9 +392,7 @@ fn test_batch_constraints_with_interactions() -> eyre::Result<()> {
425392

426393
let mut transcript = DuplexSponge::default();
427394
let ((gkr_proof, batch_proof), _) =
428-
engine
429-
.device()
430-
.prove_rap_constraints(&mut transcript, &pk, ctx);
395+
prove_up_to_batch_constraints(&engine, &mut transcript, &pk, ctx);
431396
let mut transcript = DuplexSponge::default();
432397
verify_zerocheck_and_logup(
433398
&mut transcript,

0 commit comments

Comments
 (0)