diff --git a/Cargo.lock b/Cargo.lock index 13f92fde4d..85f13ded5d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5458,8 +5458,8 @@ dependencies = [ [[package]] name = "openvm-cuda-backend" -version = "1.2.1-rc.3" -source = "git+https://github.com/openvm-org/stark-backend.git?tag=v1.2.1-rc.3#9c649e1c084f72b3cbdaebbb62d572eb05d088a0" +version = "1.2.1-rc.5" +source = "git+https://github.com/openvm-org/stark-backend.git?branch=chore%2Fvpmm_v2#8b401499c8ff8e6fcf99e9b1bf69502184436428" dependencies = [ "bincode 2.0.1", "bincode_derive", @@ -5490,8 +5490,8 @@ dependencies = [ [[package]] name = "openvm-cuda-builder" -version = "1.2.1-rc.3" -source = "git+https://github.com/openvm-org/stark-backend.git?tag=v1.2.1-rc.3#9c649e1c084f72b3cbdaebbb62d572eb05d088a0" +version = "1.2.1-rc.5" +source = "git+https://github.com/openvm-org/stark-backend.git?branch=chore%2Fvpmm_v2#8b401499c8ff8e6fcf99e9b1bf69502184436428" dependencies = [ "cc", "glob", @@ -5499,8 +5499,8 @@ dependencies = [ [[package]] name = "openvm-cuda-common" -version = "1.2.1-rc.3" -source = "git+https://github.com/openvm-org/stark-backend.git?tag=v1.2.1-rc.3#9c649e1c084f72b3cbdaebbb62d572eb05d088a0" +version = "1.2.1-rc.5" +source = "git+https://github.com/openvm-org/stark-backend.git?branch=chore%2Fvpmm_v2#8b401499c8ff8e6fcf99e9b1bf69502184436428" dependencies = [ "bytesize", "ctor", @@ -5711,6 +5711,7 @@ dependencies = [ "serde", "strum", "tiny-keccak", + "tokio", ] [[package]] @@ -6163,6 +6164,7 @@ dependencies = [ "snark-verifier-sdk", "tempfile", "thiserror 1.0.69", + "tokio", "toml 0.8.23", "tracing", ] @@ -6244,8 +6246,8 @@ dependencies = [ [[package]] name = "openvm-stark-backend" -version = "1.2.1-rc.3" -source = "git+https://github.com/openvm-org/stark-backend.git?tag=v1.2.1-rc.3#9c649e1c084f72b3cbdaebbb62d572eb05d088a0" +version = "1.2.1-rc.5" +source = "git+https://github.com/openvm-org/stark-backend.git?branch=chore%2Fvpmm_v2#8b401499c8ff8e6fcf99e9b1bf69502184436428" dependencies = [ "bitcode", "cfg-if", @@ -6274,12 +6276,12 @@ dependencies = [ [[package]] name = "openvm-stark-sdk" -version = "1.2.1-rc.3" -source = "git+https://github.com/openvm-org/stark-backend.git?tag=v1.2.1-rc.3#9c649e1c084f72b3cbdaebbb62d572eb05d088a0" +version = "1.2.1-rc.5" +source = "git+https://github.com/openvm-org/stark-backend.git?branch=chore%2Fvpmm_v2#8b401499c8ff8e6fcf99e9b1bf69502184436428" dependencies = [ "dashmap", "derivative", - "derive_more 0.99.20", + "derive_more 1.0.0", "ff 0.13.1", "itertools 0.14.0", "metrics", @@ -9238,6 +9240,7 @@ dependencies = [ "io-uring", "libc", "mio", + "parking_lot", "pin-project-lite", "signal-hook-registry", "slab", diff --git a/Cargo.toml b/Cargo.toml index 5d6c8f6d88..3c6eb19820 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -113,11 +113,11 @@ lto = "thin" [workspace.dependencies] # Stark Backend -openvm-stark-backend = { git = "https://github.com/openvm-org/stark-backend.git", tag = "v1.2.1-rc.3", default-features = false } -openvm-stark-sdk = { git = "https://github.com/openvm-org/stark-backend.git", tag = "v1.2.1-rc.3", default-features = false } -openvm-cuda-backend = { git = "https://github.com/openvm-org/stark-backend.git", tag = "v1.2.1-rc.3", default-features = false } -openvm-cuda-builder = { git = "https://github.com/openvm-org/stark-backend.git", tag = "v1.2.1-rc.3", default-features = false } -openvm-cuda-common = { git = "https://github.com/openvm-org/stark-backend.git", tag = "v1.2.1-rc.3", default-features = false } +openvm-stark-backend = { git = "https://github.com/openvm-org/stark-backend.git", branch="chore/vpmm_v2", default-features = false } +openvm-stark-sdk = { git = "https://github.com/openvm-org/stark-backend.git", branch="chore/vpmm_v2", default-features = false } +openvm-cuda-backend = { git = "https://github.com/openvm-org/stark-backend.git", branch="chore/vpmm_v2", default-features = false } +openvm-cuda-builder = { git = "https://github.com/openvm-org/stark-backend.git", branch="chore/vpmm_v2", default-features = false } +openvm-cuda-common = { git = "https://github.com/openvm-org/stark-backend.git", branch="chore/vpmm_v2", default-features = false } # OpenVM openvm-sdk = { path = "crates/sdk", default-features = false } @@ -233,6 +233,7 @@ dashmap = "6.1.0" memmap2 = "0.9.5" libc = "0.2.175" tracing-subscriber = { version = "0.3.20", features = ["std", "env-filter"] } +tokio = "1" # >=1.0.0 to allow downstream flexibility # default-features = false for no_std for use in guest programs itertools = { version = "0.14.0", default-features = false } diff --git a/benchmarks/prove/Cargo.toml b/benchmarks/prove/Cargo.toml index 04252f224c..dd9b75b79e 100644 --- a/benchmarks/prove/Cargo.toml +++ b/benchmarks/prove/Cargo.toml @@ -19,9 +19,9 @@ openvm-native-circuit.workspace = true openvm-native-compiler.workspace = true openvm-native-recursion = { workspace = true, features = ["test-utils"] } -clap = { version = "4.5.9", features = ["derive", "env"] } +clap = { workspace = true, features = ["derive", "env"] } eyre.workspace = true -tokio = { version = "1.43.1", features = ["rt", "rt-multi-thread", "macros"] } +tokio = { workspace = true, features = ["rt", "rt-multi-thread", "macros"] } rand_chacha = { version = "0.3", default-features = false } k256 = { workspace = true, features = ["ecdsa"] } tiny-keccak.workspace = true @@ -33,11 +33,12 @@ metrics.workspace = true [dev-dependencies] [features] -default = ["parallel", "jemalloc", "metrics"] +default = ["parallel", "jemalloc", "metrics", "async"] metrics = ["openvm-sdk/metrics"] tco = ["openvm-sdk/tco"] perf-metrics = ["openvm-sdk/perf-metrics", "metrics"] stark-debug = ["openvm-sdk/stark-debug"] +async = ["openvm-sdk/async"] # runs leaf aggregation benchmarks: aggregation = [] evm = ["openvm-sdk/evm-verify"] @@ -63,3 +64,8 @@ path = "src/bin/fib_e2e.rs" [[bin]] name = "kitchen_sink" path = "src/bin/kitchen_sink.rs" + +[[bin]] +name = "async_regex" +path = "src/bin/async_regex.rs" +required-features = ["async"] diff --git a/benchmarks/prove/src/bin/async_regex.rs b/benchmarks/prove/src/bin/async_regex.rs new file mode 100644 index 0000000000..65dcfe85df --- /dev/null +++ b/benchmarks/prove/src/bin/async_regex.rs @@ -0,0 +1,53 @@ +use std::env::var; + +use clap::Parser; +use openvm_benchmarks_prove::util::BenchmarkCli; +use openvm_benchmarks_utils::get_programs_dir; +use openvm_sdk::{ + config::{SdkVmBuilder, SdkVmConfig}, + prover::AsyncAppProver, + DefaultStarkEngine, Sdk, StdIn, F, +}; +use openvm_stark_sdk::config::setup_tracing; + +#[tokio::main] +async fn main() -> eyre::Result<()> { + setup_tracing(); + let args = BenchmarkCli::parse(); + let mut config = SdkVmConfig::from_toml(include_str!("../../../guest/regex/openvm.toml"))?; + if let Some(max_height) = args.max_segment_length { + config + .app_vm_config + .as_mut() + .segmentation_limits + .max_trace_height = max_height; + } + if let Some(max_cells) = args.segment_max_cells { + config.app_vm_config.as_mut().segmentation_limits.max_cells = max_cells; + } + + let sdk = Sdk::new(config)?; + + let manifest_dir = get_programs_dir().join("regex"); + let elf = sdk.build(Default::default(), manifest_dir, &None, None)?; + let app_exe = sdk.convert_to_exe(elf)?; + + let data = include_str!("../../../guest/regex/regex_email.txt"); + let fe_bytes = data.to_owned().into_bytes(); + let input = StdIn::::from_bytes(&fe_bytes); + + let (app_pk, _app_vk) = sdk.app_keygen(); + + let max_par_jobs: usize = var("MAX_PAR_JOBS").map(|m| m.parse()).unwrap_or(Ok(1))?; + + let prover = AsyncAppProver::::new( + SdkVmBuilder, + app_pk.app_vm_pk.clone(), + app_exe, + app_pk.leaf_verifier_program_commit(), + max_par_jobs, + )?; + let _proof = prover.prove(input).await?; + + Ok(()) +} diff --git a/crates/sdk/Cargo.toml b/crates/sdk/Cargo.toml index 44197f1eab..e9e13bbffb 100644 --- a/crates/sdk/Cargo.toml +++ b/crates/sdk/Cargo.toml @@ -61,9 +61,10 @@ forge-fmt = { workspace = true, optional = true } rrs-lib.workspace = true num-bigint.workspace = true cfg-if.workspace = true +tokio = { workspace = true, features = ["rt", "sync"], optional = true } [features] -default = ["parallel", "jemalloc"] +default = ["parallel", "jemalloc", "async"] evm-prove = [ "openvm-continuations/static-verifier", "openvm-native-recursion/evm-prove", @@ -101,6 +102,7 @@ perf-metrics = [ # turns on stark-backend debugger in all proofs stark-debug = ["openvm-circuit/stark-debug"] test-utils = ["openvm-circuit/test-utils"] +async = ["tokio"] # performance features: # (rayon is always imported because of halo2, so "parallel" feature is redundant) parallel = ["openvm-circuit/parallel"] diff --git a/crates/sdk/src/prover/app.rs b/crates/sdk/src/prover/app.rs index c85cf413a2..7cc900fd1d 100644 --- a/crates/sdk/src/prover/app.rs +++ b/crates/sdk/src/prover/app.rs @@ -1,5 +1,7 @@ use std::sync::{Arc, OnceLock}; +#[cfg(feature = "async")] +pub use async_prover::*; use getset::Getters; use itertools::Itertools; use openvm_circuit::{ @@ -130,10 +132,7 @@ where + MeteredExecutor> + PreflightExecutor, VB::RecordArena>, { - assert!( - self.vm_config().as_ref().continuation_enabled, - "Use generate_app_proof_without_continuations instead." - ); + assert!(self.vm_config().as_ref().continuation_enabled); check_max_constraint_degrees( self.vm_config().as_ref(), &self.instance.vm.engine.fri_params(), @@ -238,3 +237,233 @@ pub fn verify_app_proof( user_public_values, }) } + +#[cfg(feature = "async")] +mod async_prover { + use derivative::Derivative; + use openvm_circuit::{ + arch::ExecutionError, system::memory::merkle::public_values::UserPublicValuesProof, + }; + use openvm_stark_sdk::config::FriParameters; + use tokio::{spawn, sync::Semaphore, task::spawn_blocking}; + use tracing::{instrument, Instrument}; + + use super::*; + + /// Thread-safe asynchronous app prover. + #[derive(Derivative, Getters)] + #[derivative(Clone)] + pub struct AsyncAppProver + where + E: StarkEngine, + VB: VmBuilder, + { + pub program_name: Option, + #[getset(get = "pub")] + vm_builder: VB, + #[getset(get = "pub")] + app_vm_pk: Arc>, + app_exe: Arc>>, + #[getset(get = "pub")] + leaf_verifier_program_commit: Com, + + semaphore: Arc, + } + + impl AsyncAppProver + where + E: StarkFriEngine + 'static, + VB: VmBuilder + Clone + Send + Sync + 'static, + VB::VmConfig: Send + Sync, + >>::Executor: Executor> + + MeteredExecutor> + + PreflightExecutor, VB::RecordArena>, + Val: PrimeField32, + Com: + AsRef<[Val; CHUNK]> + From<[Val; CHUNK]> + Into<[Val; CHUNK]>, + { + pub fn new( + vm_builder: VB, + app_vm_pk: Arc>, + app_exe: Arc>>, + leaf_verifier_program_commit: Com, + max_concurrency: usize, + ) -> Result { + Ok(Self { + program_name: None, + vm_builder, + app_vm_pk, + app_exe, + leaf_verifier_program_commit, + semaphore: Arc::new(Semaphore::new(max_concurrency)), + }) + } + + pub fn set_program_name(&mut self, program_name: impl AsRef) -> &mut Self { + self.program_name = Some(program_name.as_ref().to_string()); + self + } + pub fn with_program_name(mut self, program_name: impl AsRef) -> Self { + self.set_program_name(program_name); + self + } + + /// App Exe + pub fn exe(&self) -> Arc>> { + self.app_exe.clone() + } + + /// App VM config + pub fn vm_config(&self) -> &VB::VmConfig { + &self.app_vm_pk.vm_config + } + + pub fn fri_params(&self) -> FriParameters { + self.app_vm_pk.fri_params + } + + /// Creates an [AppProver] within a particular thread. The former instance is not + /// thread-safe and should **not** be moved between threads. + pub fn local(&self) -> Result, VirtualMachineError> { + AppProver::new( + self.vm_builder.clone(), + &self.app_vm_pk, + self.app_exe.clone(), + self.leaf_verifier_program_commit.clone(), + ) + } + + #[instrument( + name = "app proof", + skip_all, + fields( + group = self.program_name.as_ref().unwrap_or(&"app_proof".to_string()) + ) + )] + pub async fn prove( + self, + input: StdIn>, + ) -> eyre::Result> { + assert!(self.vm_config().as_ref().continuation_enabled); + check_max_constraint_degrees(self.vm_config().as_ref(), &self.fri_params()); + #[cfg(feature = "metrics")] + metrics::counter!("fri.log_blowup").absolute(self.fri_params().log_blowup as u64); + + // PERF[jpw]: it is possible to create metered_interpreter without creating vm. The + // latter is more convenient, but does unnecessary setup (e.g., transfer pk to + // device). Also, app_commit should be cached. + let mut local_prover = self.local()?; + let app_commit = local_prover.app_commit(); + local_prover.instance.reset_state(input.clone()); + let mut state = local_prover.instance.state_mut().take().unwrap(); + let vm = &mut local_prover.instance.vm; + let metered_ctx = vm.build_metered_ctx(&self.app_exe); + let metered_interpreter = vm.metered_interpreter(&self.app_exe)?; + let (segments, _) = metered_interpreter.execute_metered(input, metered_ctx)?; + drop(metered_interpreter); + let pure_interpreter = vm.interpreter(&self.app_exe)?; + let mut tasks = Vec::with_capacity(segments.len()); + let terminal_instret = segments + .last() + .map(|s| s.instret_start + s.num_insns) + .unwrap_or(u64::MAX); + for (seg_idx, segment) in segments.into_iter().enumerate() { + tracing::info!( + %seg_idx, + instret = state.instret(), + %segment.instret_start, + pc = state.pc(), + "Re-executing", + ); + let num_insns = segment.instret_start.checked_sub(state.instret()).unwrap(); + state = pure_interpreter.execute_from_state(state, Some(num_insns))?; + + let semaphore = self.semaphore.clone(); + let async_worker = self.clone(); + let start_state = state.clone(); + let task = spawn( + async move { + let _permit = semaphore.acquire().await?; + let span = tracing::Span::current(); + spawn_blocking(move || { + let _span = span.enter(); + info_span!("prove_segment", segment = seg_idx).in_scope( + || -> eyre::Result<_> { + // We need a separate span so the metric label includes + // "segment" + // from _segment_span + let _prove_span = info_span!( + "vm_prove", + thread_id = ?std::thread::current().id() + ) + .entered(); + let mut worker = async_worker.local()?; + let instance = &mut worker.instance; + let vm = &mut instance.vm; + let preflight_interpreter = &mut instance.interpreter; + let (segment_proof, _) = vm.prove( + preflight_interpreter, + start_state, + Some(segment.num_insns), + &segment.trace_heights, + )?; + Ok(segment_proof) + }, + ) + }) + .await? + } + .in_current_span(), + ); + tasks.push(task); + } + // Finish execution to termination + state = pure_interpreter.execute_from_state(state, None)?; + if state.instret() != terminal_instret { + tracing::warn!( + "Pure execution terminal instret={}, metered execution terminal instret={}", + state.instret(), + terminal_instret + ); + // This should never happen + return Err(ExecutionError::DidNotTerminate.into()); + } + let final_memory = &state.memory.memory; + let user_public_values = UserPublicValuesProof::compute( + vm.config().as_ref().memory_config.memory_dimensions(), + vm.config().as_ref().num_public_values, + &vm_poseidon2_hasher(), + final_memory, + ); + + let mut proofs = Vec::with_capacity(tasks.len()); + for task in tasks { + let proof = task.await??; + proofs.push(proof); + } + let cont_proof = ContinuationVmProof { + per_segment: proofs, + user_public_values, + }; + + // We skip verification of the user public values proof here because it is directly + // computed from the merkle tree above + let engine = E::new(self.fri_params()); + let res = verify_segments( + &engine, + &self.app_vm_pk.vm_pk.get_vk(), + &cont_proof.per_segment, + )?; + let app_exe_commit_u32s = app_commit.app_exe_commit.to_u32_digest(); + let exe_commit_u32s = res.exe_commit.map(|x| x.as_canonical_u32()); + if exe_commit_u32s != app_exe_commit_u32s { + return Err(VmVerificationError::ExeCommitMismatch { + expected: app_exe_commit_u32s, + actual: exe_commit_u32s, + } + .into()); + } + Ok(cont_proof) + } + } +} diff --git a/crates/vm/src/arch/execution.rs b/crates/vm/src/arch/execution.rs index 4431c2f1e3..b6c8270585 100644 --- a/crates/vm/src/arch/execution.rs +++ b/crates/vm/src/arch/execution.rs @@ -66,6 +66,8 @@ pub enum ExecutionError { FailedWithExitCode(u32), #[error("trace buffer out of bounds: requested {requested} but capacity is {capacity}")] TraceBufferOutOfBounds { requested: usize, capacity: usize }, + #[error("instruction counter overflow: {instret} + {num_insns} > u64::MAX")] + InstretOverflow { instret: u64, num_insns: u64 }, #[error("inventory error: {0}")] Inventory(#[from] ExecutorInventoryError), #[error("static program error: {0}")] diff --git a/crates/vm/src/arch/interpreter.rs b/crates/vm/src/arch/interpreter.rs index 03acbf245c..52b1980424 100644 --- a/crates/vm/src/arch/interpreter.rs +++ b/crates/vm/src/arch/interpreter.rs @@ -104,6 +104,12 @@ macro_rules! run { #[cfg(feature = "tco")] { tracing::debug!("execute_tco"); + + if $ctx::should_suspend($instret, $pc, $arg, &mut $exec_state) { + $exec_state.set_instret_and_pc($instret, $pc); + return Ok(()); + } + let handler = $interpreter .get_handler($pc) .ok_or(ExecutionError::PcOutOfBounds($pc))?; @@ -353,10 +359,21 @@ where from_state: VmState, num_insns: Option, ) -> Result, ExecutionError> { - let ctx = ExecutionCtx::new(num_insns); + let instret = from_state.instret(); + let instret_end = if let Some(n) = num_insns { + let end = instret + .checked_add(n) + .ok_or(ExecutionError::InstretOverflow { + instret, + num_insns: n, + })?; + Some(end) + } else { + None + }; + let ctx = ExecutionCtx::new(instret_end); let mut exec_state = VmExecState::new(from_state, ctx); - let instret = exec_state.instret(); let pc = exec_state.pc(); let instret_end = exec_state.ctx.instret_end; run!( diff --git a/extensions/keccak256/circuit/Cargo.toml b/extensions/keccak256/circuit/Cargo.toml index 24e10deb10..c8cdacf8d1 100644 --- a/extensions/keccak256/circuit/Cargo.toml +++ b/extensions/keccak256/circuit/Cargo.toml @@ -35,6 +35,7 @@ cfg-if.workspace = true openvm-stark-sdk = { workspace = true } openvm-circuit = { workspace = true, features = ["test-utils"] } hex.workspace = true +tokio = { version = "1.47.1", features = ["full"] } [build-dependencies] openvm-cuda-builder = { workspace = true, optional = true } diff --git a/extensions/keccak256/circuit/src/tests.rs b/extensions/keccak256/circuit/src/tests.rs index 26f820928e..e874c3fb83 100644 --- a/extensions/keccak256/circuit/src/tests.rs +++ b/extensions/keccak256/circuit/src/tests.rs @@ -406,3 +406,103 @@ fn test_keccak256_cuda_tracegen() { .simple_test() .unwrap(); } + +#[cfg(feature = "cuda")] +#[test] +fn test_keccak256_cuda_tracegen_multi() { + let num_threads: usize = std::env::var("NUM_THREADS") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(2); + + let num_tasks: usize = std::env::var("NUM_TASKS") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(num_threads * 4); + + let runtime = tokio::runtime::Builder::new_multi_thread() + .max_blocking_threads(num_threads) + .enable_all() + .build() + .unwrap(); + + runtime.block_on(async { + let tasks_per_thread = num_tasks.div_ceil(num_threads); + let mut worker_handles = Vec::new(); + + for worker_idx in 0..num_threads { + let start_task = worker_idx * tasks_per_thread; + let end_task = std::cmp::min(start_task + tasks_per_thread, num_tasks); + + let worker_handle = tokio::task::spawn(async move { + for task_id in start_task..end_task { + tokio::task::spawn_blocking(move || { + println!("[worker {}, task {}] Starting test", worker_idx, task_id); + + let mut rng = create_seeded_rng(); + let mut tester = GpuChipTestBuilder::default() + .with_bitwise_op_lookup(default_bitwise_lookup_bus()); + + let mut harness = create_cuda_harness(&tester); + + let num_ops: usize = 10; + for _ in 0..num_ops { + set_and_execute( + &mut tester, + &mut harness.executor, + &mut harness.dense_arena, + &mut rng, + KECCAK256, + None, + None, + None, + ); + } + + for len in [0, 135, 136, 137, 2000] { + set_and_execute( + &mut tester, + &mut harness.executor, + &mut harness.dense_arena, + &mut rng, + KECCAK256, + None, + Some(len), + None, + ); + } + + harness + .dense_arena + .get_record_seeker::() + .transfer_to_matrix_arena(&mut harness.matrix_arena); + + tester + .build() + .load_gpu_harness(harness) + .finalize() + .simple_test() + .unwrap(); + + println!( + "[worker {}, task {}] Test completed ✅", + worker_idx, task_id + ); + }) + .await + .expect("task failed"); + } + }); + worker_handles.push(worker_handle); + } + + for handle in worker_handles { + handle.await.expect("worker failed"); + } + + println!( + "\nAll {} tasks completed on {} workers.", + num_tasks, num_threads + ); + }); +}