Skip to content

Commit c93da50

Browse files
committed
Extract do_with_trace
1 parent 187778f commit c93da50

File tree

2 files changed

+120
-55
lines changed

2 files changed

+120
-55
lines changed

openvm/src/lib.rs

Lines changed: 8 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,12 @@ use eyre::Result;
88
use itertools::Itertools;
99
use openvm_build::{build_guest_package, find_unique_executable, get_package, TargetFilter};
1010
use openvm_circuit::arch::execution_mode::metered::segment_ctx::SegmentationLimits;
11-
use openvm_circuit::arch::execution_mode::Segment;
1211
use openvm_circuit::arch::instructions::exe::VmExe;
1312
use openvm_circuit::arch::{
1413
debug_proving_ctx, AirInventory, AirInventoryError, ChipInventory, ChipInventoryError,
1514
ExecutorInventory, ExecutorInventoryError, InitFileGenerator, MatrixRecordArena,
16-
PreflightExecutionOutput, RowMajorMatrixArena, SystemConfig, VmBuilder, VmChipComplex,
17-
VmCircuitConfig, VmCircuitExtension, VmExecutionConfig, VmInstance, VmProverExtension,
15+
RowMajorMatrixArena, SystemConfig, VmBuilder, VmChipComplex, VmCircuitConfig,
16+
VmCircuitExtension, VmExecutionConfig, VmProverExtension,
1817
};
1918
use openvm_circuit::system::SystemChipInventory;
2019
use openvm_circuit::{circuit_derive::Chip, derive::AnyEnum};
@@ -23,7 +22,6 @@ use openvm_sdk::config::SdkVmCpuBuilder;
2322

2423
use openvm_instructions::program::{Program, DEFAULT_PC_STEP};
2524
use openvm_sdk::config::TranspilerConfig;
26-
use openvm_sdk::prover::vm::new_local_prover;
2725
use openvm_sdk::prover::{verify_app_proof, AggStarkProver};
2826
use openvm_sdk::GenericSdk;
2927
use openvm_sdk::{
@@ -63,7 +61,8 @@ use std::{
6361
use crate::customize_exe::OpenVmApcCandidate;
6462
pub use crate::customize_exe::Prog;
6563
use crate::powdr_extension::chip::PowdrAir;
66-
use tracing::{info_span, Level};
64+
use crate::trace_generation::do_with_trace;
65+
use tracing::Level;
6766

6867
#[cfg(test)]
6968
use crate::extraction_utils::AirWidthsDiff;
@@ -77,6 +76,7 @@ pub mod cuda_abi;
7776
pub mod extraction_utils;
7877
pub mod opcode;
7978
pub mod symbolic_instruction_builder;
79+
pub mod trace_generation;
8080
mod utils;
8181
pub use opcode::instruction_allowlist;
8282
pub use powdr_autoprecompiles::DegreeBound;
@@ -874,56 +874,9 @@ pub fn prove(
874874
#[cfg(not(feature = "cuda"))]
875875
let sdk = PowdrSdkCpu::new(app_config).unwrap();
876876
if mock {
877-
// Build owned vm instance, so we can mutate it later
878-
let vm_builder = sdk.app_vm_builder().clone();
879-
let vm_pk = sdk.app_pk().app_vm_pk.clone();
880-
let exe = sdk.convert_to_exe(exe.clone())?;
881-
let mut vm_instance: VmInstance<_, _> = new_local_prover(vm_builder, &vm_pk, exe.clone())?;
882-
883-
vm_instance.reset_state(inputs.clone());
884-
let metered_ctx = vm_instance.vm.build_metered_ctx(&exe);
885-
let metered_interpreter = vm_instance.vm.metered_interpreter(vm_instance.exe())?;
886-
let (segments, _) = metered_interpreter.execute_metered(inputs.clone(), metered_ctx)?;
887-
let mut state = vm_instance.state_mut().take();
888-
889-
// Get reusable inputs for `debug_proving_ctx`, the mock prover API from OVM.
890-
let vm = &mut vm_instance.vm;
891-
let air_inv = vm.config().create_airs().unwrap();
892-
#[cfg(feature = "cuda")]
893-
let pk = air_inv.keygen::<GpuBabyBearPoseidon2Engine>(&vm.engine);
894-
#[cfg(not(feature = "cuda"))]
895-
let pk = air_inv.keygen::<BabyBearPoseidon2Engine>(&vm.engine);
896-
897-
for (seg_idx, segment) in segments.into_iter().enumerate() {
898-
let _segment_span = info_span!("prove_segment", segment = seg_idx).entered();
899-
// We need a separate span so the metric label includes "segment" from _segment_span
900-
let _prove_span = info_span!("total_proof").entered();
901-
let Segment {
902-
instret_start,
903-
num_insns,
904-
trace_heights,
905-
} = segment;
906-
assert_eq!(state.as_ref().unwrap().instret(), instret_start);
907-
let from_state = Option::take(&mut state).unwrap();
908-
vm.transport_init_memory_to_device(&from_state.memory);
909-
let PreflightExecutionOutput {
910-
system_records,
911-
record_arenas,
912-
to_state,
913-
} = vm.execute_preflight(
914-
&mut vm_instance.interpreter,
915-
from_state,
916-
Some(num_insns),
917-
&trace_heights,
918-
)?;
919-
state = Some(to_state);
920-
921-
// Generate proving context for each segment
922-
let ctx = vm.generate_proving_ctx(system_records, record_arenas)?;
923-
924-
// Run the mock prover for each segment
925-
debug_proving_ctx(vm, &pk, &ctx);
926-
}
877+
do_with_trace(program, inputs, |vm, pk, ctx| {
878+
debug_proving_ctx(vm, pk, &ctx);
879+
});
927880
} else {
928881
let mut app_prover = sdk.app_prover(exe.clone())?;
929882

openvm/src/trace_generation.rs

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
use openvm_circuit::arch::{
2+
execution_mode::Segment, PreflightExecutionOutput, VirtualMachine, VmCircuitConfig, VmInstance,
3+
};
4+
use openvm_sdk::{
5+
config::{AppConfig, DEFAULT_APP_LOG_BLOWUP},
6+
prover::vm::new_local_prover,
7+
StdIn,
8+
};
9+
use openvm_stark_backend::{keygen::types::MultiStarkProvingKey, prover::types::ProvingContext};
10+
use openvm_stark_sdk::{config::FriParameters, engine::StarkEngine};
11+
use tracing::info_span;
12+
13+
use crate::{BabyBearSC, CompiledProgram, SpecializedConfigCpuBuilder};
14+
15+
#[cfg(not(feature = "cuda"))]
16+
use crate::PowdrSdkCpu;
17+
#[cfg(feature = "cuda")]
18+
use crate::PowdrSdkGpu;
19+
20+
#[cfg(not(feature = "cuda"))]
21+
use openvm_stark_sdk::config::baby_bear_poseidon2::BabyBearPoseidon2Engine;
22+
#[cfg(feature = "cuda")]
23+
use openvm_stark_sdk::config::gpu_baby_bear_poseidon2::GpuBabyBearPoseidon2Engine;
24+
25+
/// Given a program and input, generates the trace segment by segment and calls the provided
26+
/// callback with the VM, proving key, and proving context (containing the trace) for each segment.
27+
pub fn do_with_trace(
28+
program: &CompiledProgram,
29+
inputs: StdIn,
30+
mut callback: impl FnMut(
31+
&VirtualMachine<BabyBearPoseidon2Engine, SpecializedConfigCpuBuilder>,
32+
&MultiStarkProvingKey<BabyBearSC>,
33+
ProvingContext<<BabyBearPoseidon2Engine as StarkEngine>::PB>,
34+
),
35+
) {
36+
let exe = &program.exe;
37+
let vm_config = program.vm_config.clone();
38+
39+
// Set app configuration
40+
let app_fri_params =
41+
FriParameters::standard_with_100_bits_conjectured_security(DEFAULT_APP_LOG_BLOWUP);
42+
let app_config = AppConfig::new(app_fri_params, vm_config.clone());
43+
44+
// Create the SDK
45+
#[cfg(feature = "cuda")]
46+
let sdk = PowdrSdkGpu::new(app_config).unwrap();
47+
#[cfg(not(feature = "cuda"))]
48+
let sdk = PowdrSdkCpu::new(app_config).unwrap();
49+
// Build owned vm instance, so we can mutate it later
50+
let vm_builder = sdk.app_vm_builder().clone();
51+
let vm_pk = sdk.app_pk().app_vm_pk.clone();
52+
let exe = sdk.convert_to_exe(exe.clone()).unwrap();
53+
let mut vm_instance: VmInstance<_, _> =
54+
new_local_prover(vm_builder, &vm_pk, exe.clone()).unwrap();
55+
56+
vm_instance.reset_state(inputs.clone());
57+
let metered_ctx = vm_instance.vm.build_metered_ctx(&exe);
58+
let metered_interpreter = vm_instance
59+
.vm
60+
.metered_interpreter(vm_instance.exe())
61+
.unwrap();
62+
let (segments, _) = metered_interpreter
63+
.execute_metered(inputs.clone(), metered_ctx)
64+
.unwrap();
65+
let mut state = vm_instance.state_mut().take();
66+
67+
// Move `vm` and `interpreter` out of `vm_instance`
68+
// (after this, you can't use `vm_instance` anymore).
69+
let mut vm = vm_instance.vm;
70+
let mut interpreter = vm_instance.interpreter;
71+
72+
// Get reusable inputs for `debug_proving_ctx`, the mock prover API from OVM.
73+
let air_inv = vm.config().create_airs().unwrap();
74+
#[cfg(feature = "cuda")]
75+
let pk = air_inv.keygen::<GpuBabyBearPoseidon2Engine>(&vm.engine);
76+
#[cfg(not(feature = "cuda"))]
77+
let pk = air_inv.keygen::<BabyBearPoseidon2Engine>(&vm.engine);
78+
79+
for (seg_idx, segment) in segments.into_iter().enumerate() {
80+
let _segment_span = info_span!("prove_segment", segment = seg_idx).entered();
81+
// We need a separate span so the metric label includes "segment" from _segment_span
82+
let _prove_span = info_span!("total_proof").entered();
83+
let Segment {
84+
instret_start,
85+
num_insns,
86+
trace_heights,
87+
} = segment;
88+
assert_eq!(state.as_ref().unwrap().instret(), instret_start);
89+
let from_state = Option::take(&mut state).unwrap();
90+
vm.transport_init_memory_to_device(&from_state.memory);
91+
let PreflightExecutionOutput {
92+
system_records,
93+
record_arenas,
94+
to_state,
95+
} = vm
96+
.execute_preflight(
97+
&mut interpreter,
98+
from_state,
99+
Some(num_insns),
100+
&trace_heights,
101+
)
102+
.unwrap();
103+
state = Some(to_state);
104+
105+
// Generate proving context for each segment
106+
let ctx = vm
107+
.generate_proving_ctx(system_records, record_arenas)
108+
.unwrap();
109+
110+
callback(&vm, &pk, ctx);
111+
}
112+
}

0 commit comments

Comments
 (0)