diff --git a/.github/workflows/benchmark-call.yml b/.github/workflows/benchmark-call.yml index 71aa9bc54a..7ad9687a01 100644 --- a/.github/workflows/benchmark-call.yml +++ b/.github/workflows/benchmark-call.yml @@ -107,7 +107,7 @@ on: env: S3_METRICS_PATH: s3://openvm-public-data-sandbox-us-east-1/benchmark/github/metrics S3_FLAMEGRAPHS_PATH: s3://openvm-public-data-sandbox-us-east-1/benchmark/github/flamegraphs - FEATURE_FLAGS: "metrics,parallel,nightly-features" + FEATURE_FLAGS: "metrics,parallel,nightly-features,tco" INPUT_ARGS: "" CARGO_NET_GIT_FETCH_WITH_CLI: "true" diff --git a/.github/workflows/benchmarks-execute.yml b/.github/workflows/benchmarks-execute.yml index fba8eca81c..5108c7d48d 100644 --- a/.github/workflows/benchmarks-execute.yml +++ b/.github/workflows/benchmarks-execute.yml @@ -2,8 +2,7 @@ name: "Execution benchmarks" on: push: - # TODO(ayush): remove after feat/new-execution is merged - branches: ["main", "feat/new-execution"] + branches: ["main"] pull_request: types: [opened, synchronize, reopened, labeled] branches: ["**"] @@ -28,6 +27,7 @@ env: CARGO_TERM_COLOR: always S3_FIXTURES_PATH: s3://openvm-public-data-sandbox-us-east-1/benchmark/fixtures JEMALLOC_SYS_WITH_MALLOC_CONF: "retain:true,background_thread:true,metadata_thp:always,thp:always,dirty_decay_ms:10000,muzzy_decay_ms:10000,abort_conf:true" + TOOLCHAIN: "+nightly-2025-08-19" jobs: codspeed-walltime-benchmarks: @@ -66,12 +66,12 @@ jobs: - name: Build benchmarks working-directory: benchmarks/execute - run: cargo codspeed build --profile maxperf + run: cargo $TOOLCHAIN codspeed build --profile maxperf --features tco - name: Run benchmarks uses: CodSpeedHQ/action@v3 with: working-directory: benchmarks/execute - run: cargo codspeed run + run: cargo $TOOLCHAIN codspeed run token: ${{ secrets.CODSPEED_TOKEN }} codspeed-instrumentation-benchmarks: @@ -111,10 +111,10 @@ jobs: - name: Build benchmarks working-directory: benchmarks/execute - run: cargo codspeed build + run: cargo $TOOLCHAIN 
codspeed build --features tco - name: Run benchmarks uses: CodSpeedHQ/action@v3 with: working-directory: benchmarks/execute - run: cargo codspeed run + run: cargo $TOOLCHAIN codspeed run token: ${{ secrets.CODSPEED_TOKEN }} diff --git a/benchmarks/execute/Cargo.toml b/benchmarks/execute/Cargo.toml index fdfd4bf0e9..5fcf58b1de 100644 --- a/benchmarks/execute/Cargo.toml +++ b/benchmarks/execute/Cargo.toml @@ -46,6 +46,7 @@ divan = { package = "codspeed-divan-compat", version = "3.0.2" } [features] default = ["jemalloc"] +tco = ["openvm-sdk/tco"] mimalloc = ["openvm-circuit/mimalloc"] jemalloc = ["openvm-circuit/jemalloc"] jemalloc-prof = ["openvm-circuit/jemalloc-prof"] diff --git a/benchmarks/prove/Cargo.toml b/benchmarks/prove/Cargo.toml index 88f0784e95..786be53ae3 100644 --- a/benchmarks/prove/Cargo.toml +++ b/benchmarks/prove/Cargo.toml @@ -33,8 +33,9 @@ metrics.workspace = true [dev-dependencies] [features] -default = ["parallel", "jemalloc", "metrics", "evm"] +default = ["parallel", "jemalloc", "metrics"] metrics = ["openvm-sdk/metrics"] +tco = ["openvm-sdk/tco"] perf-metrics = ["openvm-sdk/perf-metrics", "metrics"] stark-debug = ["openvm-sdk/stark-debug"] # runs leaf aggregation benchmarks: diff --git a/ci/scripts/bench.py b/ci/scripts/bench.py index 9bf87f622f..8584999d48 100644 --- a/ci/scripts/bench.py +++ b/ci/scripts/bench.py @@ -15,9 +15,12 @@ def run_cargo_command( kzg_params_dir, profile="release" ): + toolchain = "+1.86" + if "tco" in feature_flags: + toolchain = "+nightly-2025-08-19" # Command to run (for best performance but slower builds, use --profile maxperf) command = [ - "cargo", "run", "--no-default-features", "-p", "openvm-benchmarks-prove", "--bin", bin_name, "--profile", profile, "--features", ",".join(feature_flags), "--" + "cargo", toolchain, "run", "--no-default-features", "-p", "openvm-benchmarks-prove", "--bin", bin_name, "--profile", profile, "--features", ",".join(feature_flags), "--" ] if app_log_blowup is not None: diff --git 
a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml index 4983992663..66363ecc0d 100644 --- a/crates/cli/Cargo.toml +++ b/crates/cli/Cargo.toml @@ -45,6 +45,7 @@ default = ["parallel", "jemalloc", "evm-verify", "metrics"] evm-prove = ["openvm-sdk/evm-prove"] evm-verify = ["evm-prove", "openvm-sdk/evm-verify"] metrics = ["openvm-sdk/metrics"] +tco = ["openvm-sdk/tco"] # for guest profiling: perf-metrics = ["openvm-sdk/perf-metrics", "metrics"] # performance features: diff --git a/crates/cli/src/lib.rs b/crates/cli/src/lib.rs index 1b58c45920..4516207ade 100644 --- a/crates/cli/src/lib.rs +++ b/crates/cli/src/lib.rs @@ -1,3 +1,6 @@ +#![cfg_attr(feature = "tco", allow(incomplete_features))] +#![cfg_attr(feature = "tco", feature(explicit_tail_calls))] + pub mod commands; pub mod default; pub mod input; diff --git a/crates/sdk/Cargo.toml b/crates/sdk/Cargo.toml index 8e1bdd449a..f5d2f9b141 100644 --- a/crates/sdk/Cargo.toml +++ b/crates/sdk/Cargo.toml @@ -79,6 +79,17 @@ metrics = [ "openvm-native-recursion/metrics", "openvm-native-compiler/metrics", ] +tco = [ + "openvm-circuit/tco", + "openvm-rv32im-circuit/tco", + "openvm-native-circuit/tco", + "openvm-sha256-circuit/tco", + "openvm-keccak256-circuit/tco", + "openvm-bigint-circuit/tco", + "openvm-algebra-circuit/tco", + "openvm-ecc-circuit/tco", + "openvm-pairing-circuit/tco" +] # for guest profiling: perf-metrics = ["openvm-circuit/perf-metrics", "openvm-transpiler/function-span"] # turns on stark-backend debugger in all proofs diff --git a/crates/sdk/src/lib.rs b/crates/sdk/src/lib.rs index 69ea7fc5d9..1df612540f 100644 --- a/crates/sdk/src/lib.rs +++ b/crates/sdk/src/lib.rs @@ -1,3 +1,5 @@ +#![cfg_attr(feature = "tco", allow(incomplete_features))] +#![cfg_attr(feature = "tco", feature(explicit_tail_calls))] use std::{ borrow::Borrow, fs::read, diff --git a/crates/vm/Cargo.toml b/crates/vm/Cargo.toml index 55d2e16030..7cd7e2ca26 100644 --- a/crates/vm/Cargo.toml +++ b/crates/vm/Cargo.toml @@ -68,6 +68,9 @@ basic-memory 
= [] # turns on stark-backend debugger in all proofs stark-debug = [] test-utils = ["openvm-stark-sdk"] +# Tail call optimizations. This requires nightly for the `become` keyword (https://github.com/rust-lang/rust/pull/144232). +# However tail call elimination is still an incomplete feature in Rust, so the `tco` feature remains experimental until then. +tco = ["openvm-circuit-derive/tco"] # performance features: mimalloc = ["openvm-stark-backend/mimalloc"] jemalloc = ["openvm-stark-backend/jemalloc"] diff --git a/crates/vm/derive/Cargo.toml b/crates/vm/derive/Cargo.toml index d2d11dcc78..f3fd65e2e9 100644 --- a/crates/vm/derive/Cargo.toml +++ b/crates/vm/derive/Cargo.toml @@ -10,7 +10,10 @@ license.workspace = true proc-macro = true [dependencies] -syn = { version = "2.0", features = ["parsing"] } +syn = { version = "2.0", features = ["parsing", "full"] } quote = "1.0" proc-macro2 = "1.0" itertools = { workspace = true } + +[features] +tco = [] diff --git a/crates/vm/derive/src/lib.rs b/crates/vm/derive/src/lib.rs index a43053e0cd..33051b3532 100644 --- a/crates/vm/derive/src/lib.rs +++ b/crates/vm/derive/src/lib.rs @@ -9,6 +9,9 @@ use syn::{ GenericParam, Ident, Meta, Token, }; +#[cfg(feature = "tco")] +mod tco; + #[proc_macro_derive(PreflightExecutor)] pub fn preflight_executor_derive(input: TokenStream) -> TokenStream { let ast: syn::DeriveInput = syn::parse(input).unwrap(); @@ -155,6 +158,25 @@ pub fn executor_derive(input: TokenStream) -> TokenStream { where_clause .predicates .push(syn::parse_quote! { #inner_ty: ::openvm_circuit::arch::Executor }); + + // We use the macro's feature to decide whether to generate the impl or not. This avoids + // the target crate needing the "tco" feature defined. + #[cfg(feature = "tco")] + let handler = quote! 
{ + fn handler( + &self, + pc: u32, + inst: &::openvm_circuit::arch::instructions::instruction::Instruction, + data: &mut [u8], + ) -> Result<::openvm_circuit::arch::Handler, ::openvm_circuit::arch::StaticProgramError> + where + Ctx: ::openvm_circuit::arch::execution_mode::ExecutionCtxTrait, { + self.0.handler(pc, inst, data) + } + }; + #[cfg(not(feature = "tco"))] + let handler = quote! {}; + quote! { impl #impl_generics ::openvm_circuit::arch::Executor for #name #ty_generics #where_clause { #[inline(always)] @@ -172,6 +194,8 @@ pub fn executor_derive(input: TokenStream) -> TokenStream { Ctx: ::openvm_circuit::arch::execution_mode::ExecutionCtxTrait, { self.0.pre_compute(pc, inst, data) } + + #handler } } .into() @@ -205,7 +229,7 @@ pub fn executor_derive(input: TokenStream) -> TokenStream { }); // Use full path ::openvm_circuit... so it can be used either within or outside the vm // crate. Assume F is already generic of the field. - let (pre_compute_size_arms, pre_compute_arms, where_predicates): (Vec<_>, Vec<_>, Vec<_>) = multiunzip(variants.iter().map(|(variant_name, field)| { + let (pre_compute_size_arms, pre_compute_arms, _handler_arms, where_predicates): (Vec<_>, Vec<_>, Vec<_>, Vec<_>) = multiunzip(variants.iter().map(|(variant_name, field)| { let field_ty = &field.ty; let pre_compute_size_arm = quote! { #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::Executor<#first_ty_generic>>::pre_compute_size(x) @@ -213,15 +237,38 @@ pub fn executor_derive(input: TokenStream) -> TokenStream { let pre_compute_arm = quote! { #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::Executor<#first_ty_generic>>::pre_compute(x, pc, instruction, data) }; + let handler_arm = quote! { + #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::Executor<#first_ty_generic>>::handler(x, pc, instruction, data) + }; let where_predicate = syn::parse_quote! 
{ #field_ty: ::openvm_circuit::arch::Executor<#first_ty_generic> }; - (pre_compute_size_arm, pre_compute_arm, where_predicate) + (pre_compute_size_arm, pre_compute_arm, handler_arm, where_predicate) })); let where_clause = new_generics.make_where_clause(); for predicate in where_predicates { where_clause.predicates.push(predicate); } + // We use the macro's feature to decide whether to generate the impl or not. This avoids + // the target crate needing the "tco" feature defined. + #[cfg(feature = "tco")] + let handler = quote! { + fn handler( + &self, + pc: u32, + instruction: &::openvm_circuit::arch::instructions::instruction::Instruction, + data: &mut [u8], + ) -> Result<::openvm_circuit::arch::Handler, ::openvm_circuit::arch::StaticProgramError> + where + Ctx: ::openvm_circuit::arch::execution_mode::ExecutionCtxTrait, { + match self { + #(#_handler_arms,)* + } + } + }; + #[cfg(not(feature = "tco"))] + let handler = quote! {}; + // Don't use these ty_generics because it might have extra "F" let (impl_generics, _, where_clause) = new_generics.split_for_impl(); @@ -247,6 +294,8 @@ pub fn executor_derive(input: TokenStream) -> TokenStream { #(#pre_compute_arms,)* } } + + #handler } } .into() @@ -282,6 +331,26 @@ pub fn metered_executor_derive(input: TokenStream) -> TokenStream { where_clause .predicates .push(syn::parse_quote! { #inner_ty: ::openvm_circuit::arch::MeteredExecutor }); + + // We use the macro's feature to decide whether to generate the impl or not. This avoids + // the target crate needing the "tco" feature defined. + #[cfg(feature = "tco")] + let metered_handler = quote! 
{ + fn metered_handler( + &self, + chip_idx: usize, + pc: u32, + inst: &::openvm_circuit::arch::instructions::instruction::Instruction, + data: &mut [u8], + ) -> Result<::openvm_circuit::arch::Handler, ::openvm_circuit::arch::StaticProgramError> + where + Ctx: ::openvm_circuit::arch::execution_mode::MeteredExecutionCtxTrait, { + self.0.metered_handler(chip_idx, pc, inst, data) + } + }; + #[cfg(not(feature = "tco"))] + let metered_handler = quote! {}; + quote! { impl #impl_generics ::openvm_circuit::arch::MeteredExecutor for #name #ty_generics #where_clause { #[inline(always)] @@ -300,6 +369,7 @@ pub fn metered_executor_derive(input: TokenStream) -> TokenStream { Ctx: ::openvm_circuit::arch::execution_mode::MeteredExecutionCtxTrait, { self.0.metered_pre_compute(chip_idx, pc, inst, data) } + #metered_handler } } .into() @@ -333,7 +403,7 @@ pub fn metered_executor_derive(input: TokenStream) -> TokenStream { }); // Use full path ::openvm_circuit... so it can be used either within or outside the vm // crate. Assume F is already generic of the field. - let (pre_compute_size_arms, metered_pre_compute_arms, where_predicates): (Vec<_>, Vec<_>, Vec<_>) = multiunzip(variants.iter().map(|(variant_name, field)| { + let (pre_compute_size_arms, metered_pre_compute_arms, _metered_handler_arms, where_predicates): (Vec<_>, Vec<_>, Vec<_>, Vec<_>) = multiunzip(variants.iter().map(|(variant_name, field)| { let field_ty = &field.ty; let pre_compute_size_arm = quote! { #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::MeteredExecutor<#first_ty_generic>>::metered_pre_compute_size(x) @@ -341,10 +411,13 @@ pub fn metered_executor_derive(input: TokenStream) -> TokenStream { let metered_pre_compute_arm = quote! { #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::MeteredExecutor<#first_ty_generic>>::metered_pre_compute(x, chip_idx, pc, instruction, data) }; + let metered_handler_arm = quote! 
{ + #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::MeteredExecutor<#first_ty_generic>>::metered_handler(x, chip_idx, pc, instruction, data) + }; let where_predicate = syn::parse_quote! { #field_ty: ::openvm_circuit::arch::MeteredExecutor<#first_ty_generic> }; - (pre_compute_size_arm, metered_pre_compute_arm, where_predicate) + (pre_compute_size_arm, metered_pre_compute_arm, metered_handler_arm, where_predicate) })); let where_clause = new_generics.make_where_clause(); for predicate in where_predicates { @@ -353,6 +426,28 @@ pub fn metered_executor_derive(input: TokenStream) -> TokenStream { // Don't use these ty_generics because it might have extra "F" let (impl_generics, _, where_clause) = new_generics.split_for_impl(); + // We use the macro's feature to decide whether to generate the impl or not. This avoids + // the target crate needing the "tco" feature defined. + #[cfg(feature = "tco")] + let metered_handler = quote! { + fn metered_handler( + &self, + chip_idx: usize, + pc: u32, + instruction: &::openvm_circuit::arch::instructions::instruction::Instruction, + data: &mut [u8], + ) -> Result<::openvm_circuit::arch::Handler, ::openvm_circuit::arch::StaticProgramError> + where + Ctx: ::openvm_circuit::arch::execution_mode::MeteredExecutionCtxTrait, + { + match self { + #(#_metered_handler_arms,)* + } + } + }; + #[cfg(not(feature = "tco"))] + let metered_handler = quote! {}; + quote! 
{ impl #impl_generics ::openvm_circuit::arch::MeteredExecutor<#first_ty_generic> for #name #ty_generics #where_clause { #[inline(always)] @@ -376,6 +471,8 @@ pub fn metered_executor_derive(input: TokenStream) -> TokenStream { #(#metered_pre_compute_arms,)* } } + + #metered_handler } } .into() @@ -501,8 +598,12 @@ fn generate_config_traits_impl(name: &Ident, inner: &DataStruct) -> syn::Result< .iter() .filter(|f| f.attrs.iter().any(|attr| attr.path().is_ident("config"))) .exactly_one() - .clone() - .expect("Exactly one field must have the #[config] attribute"); + .map_err(|_| { + syn::Error::new( + name.span(), + "Exactly one field must have the #[config] attribute", + ) + })?; let (source_name, source_name_upper) = gen_name_with_uppercase_idents(source_field.ident.as_ref().unwrap()); @@ -700,3 +801,44 @@ fn parse_executor_type( }) } } + +/// An attribute procedural macro for creating TCO (Tail Call Optimization) handlers. +/// +/// This macro generates a handler function that wraps an execute implementation +/// with tail call optimization using the `become` keyword. It extracts the generics +/// and where clauses from the original function. +/// +/// # Usage +/// +/// Place this attribute above a function definition: +/// ``` +/// #[create_tco_handler] +/// unsafe fn execute_e1_impl( +/// pre_compute: &[u8], +/// state: &mut VmExecState, +/// ) where +/// CTX: ExecutionCtxTrait, +/// { +/// // function body +/// } +/// ``` +/// +/// This will generate a TCO handler function with the same generics and where clauses. +/// +/// # Safety +/// +/// Do not use this macro if your function wants to terminate execution without error with a +/// specific exit code. The handler generated by this macro assumes that execution should continue +/// unless the execute_impl sets an error in `exit_code`. This is done for performance to skip an +/// exit code check. 
+#[proc_macro_attribute] +pub fn create_tco_handler(_attr: TokenStream, item: TokenStream) -> TokenStream { + #[cfg(feature = "tco")] + { + tco::tco_impl(item) + } + #[cfg(not(feature = "tco"))] + { + item + } +} diff --git a/crates/vm/derive/src/tco.rs b/crates/vm/derive/src/tco.rs new file mode 100644 index 0000000000..9019acd1e6 --- /dev/null +++ b/crates/vm/derive/src/tco.rs @@ -0,0 +1,128 @@ +use proc_macro::TokenStream; +use quote::{format_ident, quote}; +use syn::{parse_macro_input, ItemFn}; + +/// Implementation of the TCO handler generation logic. +/// This is called from the proc macro attribute in lib.rs. +pub fn tco_impl(item: TokenStream) -> TokenStream { + // Parse the input function + let input_fn = parse_macro_input!(item as ItemFn); + + // Extract information from the function + let fn_name = &input_fn.sig.ident; + let generics = &input_fn.sig.generics; + let where_clause = &generics.where_clause; + + // Extract the first two generic type parameters (F and CTX) + let (f_type, ctx_type) = extract_f_and_ctx_types(generics); + // Derive new function name: + // If original ends with `_impl`, replace with `_tco_handler`, else append suffix. + let new_name_str = fn_name + .to_string() + .strip_suffix("_impl") + .map(|base| format!("{base}_tco_handler")) + .unwrap_or_else(|| format!("{fn_name}_tco_handler")); + let handler_name = format_ident!("{}", new_name_str); + + // Build the generic parameters for the handler, preserving all original generics + let handler_generics = generics.clone(); + + // Build the function call with all the generics + let generic_args = build_generic_args(generics); + let execute_call = if generic_args.is_empty() { + quote! { #fn_name(pre_compute, exec_state) } + } else { + quote! { #fn_name::<#(#generic_args),*>(pre_compute, exec_state) } + }; + + // Generate the TCO handler function + let handler_fn = quote! 
{ + #[inline(never)] + unsafe fn #handler_name #handler_generics ( + interpreter: &::openvm_circuit::arch::interpreter::InterpretedInstance<#f_type, #ctx_type>, + exec_state: &mut ::openvm_circuit::arch::VmExecState< + #f_type, + ::openvm_circuit::system::memory::online::GuestMemory, + #ctx_type, + >, + ) + #where_clause + { + use ::openvm_circuit::arch::ExecutionError; + + let pre_compute = interpreter.get_pre_compute(exec_state.vm_state.pc); + #execute_call; + + if exec_state.exit_code.is_err() { + // stop execution + return; + } + if #ctx_type::should_suspend(exec_state) { + return; + } + // exec_state.pc should have been updated by execute_impl at this point + let next_handler = interpreter.get_handler(exec_state.vm_state.pc); + if next_handler.is_none() { + exec_state.exit_code = Err(ExecutionError::PcOutOfBounds (exec_state.vm_state.pc)); + return; + } + let next_handler = next_handler.unwrap_unchecked(); + + // NOTE: `become` is a keyword that requires Rust Nightly. + // It is part of the explicit tail calls RFC (https://github.com/rust-lang/rfcs/pull/3407), + // which is still incomplete. + become next_handler(interpreter, exec_state) + } + }; + + // Return both the original function and the new handler + let output = quote! 
{ + #input_fn + + #handler_fn + }; + + TokenStream::from(output) +} + +fn extract_f_and_ctx_types(generics: &syn::Generics) -> (syn::Ident, syn::Ident) { + let mut type_params = generics.params.iter().filter_map(|param| { + if let syn::GenericParam::Type(type_param) = param { + Some(&type_param.ident) + } else { + None + } + }); + + let f_type = type_params + .next() + .expect("Function must have at least one type parameter (F)") + .clone(); + let ctx_type = type_params + .next() + .expect("Function must have at least two type parameters (F and CTX)") + .clone(); + + (f_type, ctx_type) +} + +fn build_generic_args(generics: &syn::Generics) -> Vec { + generics + .params + .iter() + .map(|param| match param { + syn::GenericParam::Type(type_param) => { + let ident = &type_param.ident; + quote! { #ident } + } + syn::GenericParam::Lifetime(lifetime) => { + let lifetime = &lifetime.lifetime; + quote! { #lifetime } + } + syn::GenericParam::Const(const_param) => { + let ident = &const_param.ident; + quote! 
{ #ident } + } + }) + .collect() +} diff --git a/crates/vm/src/arch/execution.rs b/crates/vm/src/arch/execution.rs index 4e3f11804c..6bd3d3b90a 100644 --- a/crates/vm/src/arch/execution.rs +++ b/crates/vm/src/arch/execution.rs @@ -12,6 +12,8 @@ use serde::{Deserialize, Serialize}; use thiserror::Error; use super::{execution_mode::ExecutionCtxTrait, Streams, VmExecState}; +#[cfg(feature = "tco")] +use crate::arch::interpreter::InterpretedInstance; #[cfg(feature = "metrics")] use crate::metrics::VmMetrics; use crate::{ @@ -26,12 +28,8 @@ use crate::{ pub enum ExecutionError { #[error("execution failed at pc {pc}, err: {msg}")] Fail { pc: u32, msg: &'static str }, - #[error("pc {pc} out of bounds for program of length {program_len}, with pc_base {pc_base}")] - PcOutOfBounds { - pc: u32, - pc_base: u32, - program_len: usize, - }, + #[error("pc {0} out of bounds")] + PcOutOfBounds(u32), #[error("unreachable instruction at pc {0}")] Unreachable(u32), #[error("at pc {pc}, opcode {opcode} was not enabled")] @@ -91,7 +89,20 @@ pub enum StaticProgramError { /// The `pre_compute: &[u8]` is a pre-computed buffer of data corresponding to a single instruction. /// The contents of `pre_compute` are determined from the program code as specified by the /// [Executor] and [MeteredExecutor] traits. -pub type ExecuteFunc = unsafe fn(&[u8], &mut VmExecState); +pub type ExecuteFunc = + unsafe fn(pre_compute: &[u8], exec_state: &mut VmExecState); + +/// Handler for tail call elimination. A handler typically executes the instruction at the current +/// pc and then tail-calls the handler for the updated pc. +/// +/// - `interpreter` is the interpreted instance, from which the handler looks up the pre-computed +/// buffer for the current pc and the handler for the next pc. +/// - `exec_state` is the VM execution state that the handler reads and updates. 
+#[cfg(feature = "tco")] +pub type Handler = unsafe fn( + interpreter: &InterpretedInstance, + exec_state: &mut VmExecState, +); /// Trait for pure execution via a host interpreter. The trait methods provide the methods to /// pre-process the program code into function pointers which operate on `pre_compute` instruction @@ -108,6 +119,20 @@ pub trait Executor { ) -> Result, StaticProgramError> where Ctx: ExecutionCtxTrait; + + /// Returns a function pointer with tail call optimization. The handler function assumes that + /// the pre-compute buffer it receives is the populated `data`. + // NOTE: we could have used `pre_compute` above to populate `data`, but the implementations were + // simpler to keep `handler` entirely separate from `pre_compute`. + #[cfg(feature = "tco")] + fn handler( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait; } /// Trait for metered execution via a host interpreter. The trait methods provide the methods to @@ -126,6 +151,22 @@ pub trait MeteredExecutor { ) -> Result, StaticProgramError> where Ctx: MeteredExecutionCtxTrait; + + /// Returns a function pointer with tail call optimization. The handler function assumes that + /// the pre-compute buffer it receives is the populated `data`. + // NOTE: we could have used `metered_pre_compute` above to populate `data`, but the + // implementations were simpler to keep `metered_handler` entirely separate from + // `metered_pre_compute`. + #[cfg(feature = "tco")] + fn metered_handler( + &self, + air_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: MeteredExecutionCtxTrait; } /// Trait for preflight execution via a host interpreter. 
The trait methods allow execution of diff --git a/crates/vm/src/arch/execution_mode/pure.rs b/crates/vm/src/arch/execution_mode/pure.rs index 176a8c8a2b..83001d7b64 100644 --- a/crates/vm/src/arch/execution_mode/pure.rs +++ b/crates/vm/src/arch/execution_mode/pure.rs @@ -19,12 +19,6 @@ impl ExecutionCtx { } } -impl Default for ExecutionCtx { - fn default() -> Self { - Self::new(None) - } -} - impl ExecutionCtxTrait for ExecutionCtx { #[inline(always)] fn on_memory_operation(&mut self, _address_space: u32, _ptr: u32, _size: u32) {} diff --git a/crates/vm/src/arch/interpreter.rs b/crates/vm/src/arch/interpreter.rs index 7962fdc7d9..504059c286 100644 --- a/crates/vm/src/arch/interpreter.rs +++ b/crates/vm/src/arch/interpreter.rs @@ -1,6 +1,9 @@ +#[cfg(feature = "tco")] +use std::marker::PhantomData; use std::{ alloc::{alloc, dealloc, handle_alloc_error, Layout}, borrow::{Borrow, BorrowMut}, + iter::repeat_n, ptr::NonNull, }; @@ -15,6 +18,8 @@ use openvm_instructions::{ use openvm_stark_backend::p3_field::PrimeField32; use tracing::info_span; +#[cfg(feature = "tco")] +use crate::arch::Handler; use crate::{ arch::{ execution_mode::{ @@ -42,15 +47,25 @@ pub struct InterpretedInstance<'a, F, Ctx> { #[allow(dead_code)] pre_compute_buf: AlignedBuf, /// Instruction table of function pointers and pointers to the pre-computed buffer. Indexed by - /// `pc_index = (pc - pc_base) / DEFAULT_PC_STEP`. + /// `pc_index = pc / DEFAULT_PC_STEP`. + /// SAFETY: The first `pc_base / DEFAULT_PC_STEP` entries will be unreachable. We do this to + /// avoid needing to subtract `pc_base` during runtime. + #[cfg(not(feature = "tco"))] pre_compute_insns: Vec>, + #[cfg(feature = "tco")] + pre_compute_max_size: usize, + /// Handler function pointers for tail call optimization. 
+ #[cfg(feature = "tco")] + handlers: Vec>, - pc_base: u32, pc_start: u32, init_memory: SparseMemoryImage, + #[cfg(feature = "tco")] + phantom: PhantomData<&'a ()>, } +#[cfg_attr(feature = "tco", allow(dead_code))] struct PreComputeInstruction<'a, F, Ctx> { pub handler: ExecuteFunc, pub pre_compute: &'a [u8], @@ -62,17 +77,46 @@ struct TerminatePreCompute { exit_code: u32, } -macro_rules! execute_with_metrics { - ($span:literal, $pc_base:expr, $exec_state:expr, $pre_compute_insts:expr) => {{ +macro_rules! run { + ($span:literal, $interpreter:ident, $exec_state:ident, $ctx:ident) => {{ #[cfg(feature = "metrics")] let start = std::time::Instant::now(); #[cfg(feature = "metrics")] let start_instret = $exec_state.instret; - // SAFETY: pre_compute_insts contains valid function pointers and pre-computed data - info_span!($span).in_scope(|| unsafe { - execute_trampoline($pc_base, $exec_state, $pre_compute_insts); - }); + info_span!($span).in_scope(|| -> Result<(), ExecutionError> { + // SAFETY: + // - it is the responsibility of each Executor to ensure that pre_compute_insts contains + // valid function pointers and pre-computed data + #[cfg(not(feature = "tco"))] + unsafe { + tracing::debug!("execute_trampoline"); + execute_trampoline(&mut $exec_state, &$interpreter.pre_compute_insns); + } + #[cfg(feature = "tco")] + { + tracing::debug!("execute_tco"); + let handler = $interpreter + .get_handler($exec_state.pc) + .ok_or(ExecutionError::PcOutOfBounds($exec_state.pc))?; + // SAFETY: + // - handler is generated by Executor, MeteredExecutor traits + // - it is the responsibility of each Executor to ensure handler is safe given a + // valid VM state + unsafe { + handler($interpreter, &mut $exec_state); + } + + if $exec_state + .exit_code + .as_ref() + .is_ok_and(|exit_code| exit_code.is_some()) + { + $ctx::on_terminate(&mut $exec_state); + } + } + Ok(()) + })?; #[cfg(feature = "metrics")] { @@ -106,27 +150,90 @@ where { let program = &exe.program; let pre_compute_max_size = 
get_pre_compute_max_size(program, inventory); - let mut pre_compute_buf = alloc_pre_compute_buf(program.len(), pre_compute_max_size); + let mut pre_compute_buf = alloc_pre_compute_buf(program, pre_compute_max_size); let mut split_pre_compute_buf = split_pre_compute_buf(program, &mut pre_compute_buf, pre_compute_max_size); + #[cfg_attr(feature = "tco", allow(unused_variables))] let pre_compute_insns = get_pre_compute_instructions::( program, inventory, &mut split_pre_compute_buf, )?; - let pc_base = program.pc_base; let pc_start = exe.pc_start; let init_memory = exe.init_memory.clone(); + #[cfg(feature = "tco")] + let handlers = repeat_n(&None, get_pc_index(program.pc_base)) + .chain(program.instructions_and_debug_infos.iter()) + .zip_eq(split_pre_compute_buf.iter_mut()) + .enumerate() + .map( + |(pc_idx, (inst_opt, pre_compute))| -> Result, StaticProgramError> { + if let Some((inst, _)) = inst_opt { + let pc = pc_idx as u32 * DEFAULT_PC_STEP; + if get_system_opcode_handler::(inst, pre_compute).is_some() { + Ok(terminate_execute_e12_tco_handler) + } else { + // unwrap because get_pre_compute_instructions would have errored + // already on DisabledOperation + let executor = inventory.get_executor(inst.opcode).unwrap(); + executor.handler(pc, inst, pre_compute) + } + } else { + Ok(unreachable_tco_handler) + } + }, + ) + .collect::, _>>()?; Ok(Self { system_config: inventory.config().clone(), pre_compute_buf, + #[cfg(not(feature = "tco"))] pre_compute_insns, - pc_base, pc_start, init_memory, + #[cfg(feature = "tco")] + pre_compute_max_size, + #[cfg(feature = "tco")] + handlers, + #[cfg(feature = "tco")] + phantom: PhantomData, }) } + + /// # Safety + /// - This function assumes that the `pc` is within program bounds - this should be the case if + /// the pc is checked to be in bounds before jumping to it. 
+ /// - The returned slice may not be entirely initialized, but it is the job of each Executor to + /// initialize the parts of the buffer that the instruction handler will use. + #[cfg(feature = "tco")] + #[inline(always)] + pub fn get_pre_compute(&self, pc: u32) -> &[u8] { + let pc_idx = get_pc_index(pc); + // SAFETY: + // - we assume that pc is in bounds + // - pre_compute_buf is allocated for pre_compute_max_size * program_len bytes, with each + // instruction getting pre_compute_max_size bytes + // - self.pre_compute_buf.ptr is non-null + // - initialization of the contents of the slice is the responsibility of each Executor + debug_assert!( + (pc_idx + 1) * self.pre_compute_max_size <= self.pre_compute_buf.layout.size() + ); + unsafe { + let ptr = self + .pre_compute_buf + .ptr + .add(pc_idx * self.pre_compute_max_size); + std::slice::from_raw_parts(ptr, self.pre_compute_max_size) + } + } + + #[cfg(feature = "tco")] + #[inline(always)] + pub fn get_handler(&self, pc: u32) -> Option> { + let pc_idx = get_pc_index(pc); + self.handlers.get(pc_idx).copied() + } } impl<'a, F, Ctx> InterpretedInstance<'a, F, Ctx> @@ -146,9 +253,10 @@ where { let program = &exe.program; let pre_compute_max_size = get_metered_pre_compute_max_size(program, inventory); - let mut pre_compute_buf = alloc_pre_compute_buf(program.len(), pre_compute_max_size); + let mut pre_compute_buf = alloc_pre_compute_buf(program, pre_compute_max_size); let mut split_pre_compute_buf = split_pre_compute_buf(program, &mut pre_compute_buf, pre_compute_max_size); + #[cfg_attr(feature = "tco", allow(unused_variables))] let pre_compute_insns = get_metered_pre_compute_instructions::( program, inventory, @@ -156,17 +264,47 @@ where &mut split_pre_compute_buf, )?; - let pc_base = program.pc_base; let pc_start = exe.pc_start; let init_memory = exe.init_memory.clone(); + #[cfg(feature = "tco")] + let handlers = repeat_n(&None, get_pc_index(program.pc_base)) + .chain(program.instructions_and_debug_infos.iter()) + 
.zip_eq(split_pre_compute_buf.iter_mut()) + .enumerate() + .map( + |(pc_idx, (inst_opt, pre_compute))| -> Result, StaticProgramError> { + if let Some((inst, _)) = inst_opt { + let pc = pc_idx as u32 * DEFAULT_PC_STEP; + if get_system_opcode_handler::(inst, pre_compute).is_some() { + Ok(terminate_execute_e12_tco_handler) + } else { + // indexing is expected not to panic: get_metered_pre_compute_instructions + // should already have errored on DisabledOperation + let executor_idx = inventory.instruction_lookup[&inst.opcode] as usize; + let executor = &inventory.executors[executor_idx]; + let air_idx = executor_idx_to_air_idx[executor_idx]; + executor.metered_handler(air_idx, pc, inst, pre_compute) + } + } else { + Ok(unreachable_tco_handler) + } + }, + ) + .collect::, _>>()?; Ok(Self { system_config: inventory.config().clone(), pre_compute_buf, + #[cfg(not(feature = "tco"))] pre_compute_insns, - pc_base, pc_start, init_memory, + #[cfg(feature = "tco")] + pre_compute_max_size, + #[cfg(feature = "tco")] + handlers, + #[cfg(feature = "tco")] + phantom: PhantomData, }) } } @@ -208,13 +346,7 @@ where ) -> Result, ExecutionError> { let ctx = ExecutionCtx::new(num_insns); let mut exec_state = VmExecState::new(from_state, ctx); - // Start execution - execute_with_metrics!( - "execute_e1", - self.pc_base, - &mut exec_state, - &self.pre_compute_insns - ); + run!("execute_e1", self, exec_state, ExecutionCtx); if num_insns.is_some() { check_exit_code(exec_state.exit_code)?; } else { @@ -261,12 +393,7 @@ where ) -> Result<(Vec, VmState), ExecutionError> { let mut exec_state = VmExecState::new(from_state, ctx); // Start execution - execute_with_metrics!( - "execute_metered", - self.pc_base, - &mut exec_state, - &self.pre_compute_insns - ); + run!("execute_metered", self, exec_state, MeteredCtx); check_termination(exec_state.exit_code)?; let VmExecState { vm_state, ctx, .. 
} = exec_state; Ok((ctx.into_segments(), vm_state)) @@ -306,12 +433,7 @@ where ) -> Result { let mut exec_state = VmExecState::new(from_state, ctx); // Start execution - execute_with_metrics!( - "execute_metered_cost", - self.pc_base, - &mut exec_state, - &self.pre_compute_insns - ); + run!("execute_metered_cost", self, exec_state, MeteredCostCtx); check_exit_code(exec_state.exit_code)?; let VmExecState { ctx, vm_state, .. } = exec_state; let output = MeteredCostExecutionOutput::new(vm_state.instret, ctx.cost); @@ -319,8 +441,10 @@ where } } -fn alloc_pre_compute_buf(program_len: usize, pre_compute_max_size: usize) -> AlignedBuf { - let buf_len = program_len * pre_compute_max_size; +fn alloc_pre_compute_buf(program: &Program, pre_compute_max_size: usize) -> AlignedBuf { + let base_idx = get_pc_index(program.pc_base); + let padded_program_len = base_idx + program.instructions_and_debug_infos.len(); + let buf_len = padded_program_len * pre_compute_max_size; AlignedBuf::uninit(buf_len, pre_compute_max_size) } @@ -329,20 +453,16 @@ fn split_pre_compute_buf<'a, F>( pre_compute_buf: &'a mut AlignedBuf, pre_compute_max_size: usize, ) -> Vec<&'a mut [u8]> { - let program_len = program.instructions_and_debug_infos.len(); - let buf_len = program_len * pre_compute_max_size; + let base_idx = get_pc_index(program.pc_base); + let padded_program_len = base_idx + program.instructions_and_debug_infos.len(); + let buf_len = padded_program_len * pre_compute_max_size; // SAFETY: // - pre_compute_buf.ptr was allocated with exactly buf_len bytes // - lifetime 'a ensures the returned slices don't outlive the AlignedBuf - let mut pre_compute_buf_ptr = - unsafe { std::slice::from_raw_parts_mut(pre_compute_buf.ptr, buf_len) }; - let mut split_pre_compute_buf = Vec::with_capacity(program_len); - for _ in 0..program_len { - let (first, last) = pre_compute_buf_ptr.split_at_mut(pre_compute_max_size); - pre_compute_buf_ptr = last; - split_pre_compute_buf.push(first); - } - split_pre_compute_buf 
+ let pre_compute_buf = unsafe { std::slice::from_raw_parts_mut(pre_compute_buf.ptr, buf_len) }; + pre_compute_buf + .chunks_exact_mut(pre_compute_max_size) + .collect() } /// Executes using function pointers with the trampoline (loop) approach. @@ -351,7 +471,6 @@ fn split_pre_compute_buf<'a, F>( /// The `fn_ptrs` pointer to pre-computed buffers that outlive this function. #[inline(always)] unsafe fn execute_trampoline( - pc_base: u32, vm_state: &mut VmExecState, fn_ptrs: &[PreComputeInstruction], ) { @@ -363,16 +482,12 @@ unsafe fn execute_trampoline( if Ctx::should_suspend(vm_state) { break; } - let pc_index = get_pc_index(pc_base, vm_state.pc); + let pc_index = get_pc_index(vm_state.pc); if let Some(inst) = fn_ptrs.get(pc_index) { // SAFETY: pre_compute assumed to live long enough unsafe { (inst.handler)(inst.pre_compute, vm_state) }; } else { - vm_state.exit_code = Err(ExecutionError::PcOutOfBounds { - pc: vm_state.pc, - pc_base, - program_len: fn_ptrs.len(), - }); + vm_state.exit_code = Err(ExecutionError::PcOutOfBounds(vm_state.pc)); } } if vm_state @@ -385,8 +500,8 @@ unsafe fn execute_trampoline( } #[inline(always)] -pub fn get_pc_index(pc_base: u32, pc: u32) -> usize { - ((pc - pc_base) / DEFAULT_PC_STEP) as usize +pub fn get_pc_index(pc: u32) -> usize { + (pc / DEFAULT_PC_STEP) as usize } /// Bytes allocated according to the given Layout @@ -428,6 +543,7 @@ impl Drop for AlignedBuf { } } +#[inline(always)] unsafe fn terminate_execute_e12_impl( pre_compute: &[u8], vm_state: &mut VmExecState, @@ -437,6 +553,22 @@ unsafe fn terminate_execute_e12_impl( vm_state.exit_code = Ok(Some(pre_compute.exit_code)); } +#[cfg(feature = "tco")] +unsafe fn terminate_execute_e12_tco_handler( + interpreter: &InterpretedInstance, + vm_state: &mut VmExecState, +) { + let pre_compute = interpreter.get_pre_compute(vm_state.pc); + terminate_execute_e12_impl(pre_compute, vm_state); +} +#[cfg(feature = "tco")] +unsafe fn unreachable_tco_handler( + _: &InterpretedInstance, + 
vm_state: &mut VmExecState, +) { + vm_state.exit_code = Err(ExecutionError::Unreachable(vm_state.pc)); +} + fn get_pre_compute_max_size>( program: &Program, inventory: &ExecutorInventory, @@ -506,15 +638,19 @@ where Ctx: ExecutionCtxTrait, E: Executor, { - program - .instructions_and_debug_infos - .iter() + let unreachable_handler: ExecuteFunc = |_, vm_state| { + vm_state.exit_code = Err(ExecutionError::Unreachable(vm_state.pc)); + }; + + repeat_n(&None, get_pc_index(program.pc_base)) + .chain(program.instructions_and_debug_infos.iter()) .zip_eq(pre_compute.iter_mut()) .enumerate() .map(|(i, (inst_opt, buf))| { - // SAFETY: we cast to raw pointer and then borrow to remove the lifetime. This is safe - // only in the current context because `buf` comes from `pre_compute_buf` which will - // outlive the returned `PreComputeInstruction`s. + // SAFETY: we cast to raw pointer and then borrow to remove the lifetime. This + // is safe only in the current context because `buf` comes + // from `pre_compute_buf` which will outlive the returned + // `PreComputeInstruction`s. 
let buf: &mut [u8] = unsafe { &mut *(*buf as *mut [u8]) }; let pre_inst = if let Some((inst, _)) = inst_opt { tracing::trace!("get_pre_compute_instruction {inst:?}"); @@ -538,9 +674,7 @@ where } else { // Dead instruction at this pc PreComputeInstruction { - handler: |_, vm_state| { - vm_state.exit_code = Err(ExecutionError::Unreachable(vm_state.pc)); - }, + handler: unreachable_handler, pre_compute: buf, } }; @@ -560,15 +694,18 @@ where Ctx: MeteredExecutionCtxTrait, E: MeteredExecutor, { - program - .instructions_and_debug_infos - .iter() + let unreachable_handler: ExecuteFunc = |_, vm_state| { + vm_state.exit_code = Err(ExecutionError::Unreachable(vm_state.pc)); + }; + repeat_n(&None, get_pc_index(program.pc_base)) + .chain(program.instructions_and_debug_infos.iter()) .zip_eq(pre_compute.iter_mut()) .enumerate() .map(|(i, (inst_opt, buf))| { - // SAFETY: we cast to raw pointer and then borrow to remove the lifetime. This is safe - // only in the current context because `buf` comes from `pre_compute_buf` which will - // outlive the returned `PreComputeInstruction`s. + // SAFETY: we cast to raw pointer and then borrow to remove the lifetime. This + // is safe only in the current context because `buf` comes + // from `pre_compute_buf` which will outlive the returned + // `PreComputeInstruction`s. 
let buf: &mut [u8] = unsafe { &mut *(*buf as *mut [u8]) }; let pre_inst = if let Some((inst, _)) = inst_opt { tracing::trace!("get_metered_pre_compute_instruction {inst:?}"); @@ -597,9 +734,7 @@ where } } else { PreComputeInstruction { - handler: |_, vm_state| { - vm_state.exit_code = Err(ExecutionError::Unreachable(vm_state.pc)); - }, + handler: unreachable_handler, pre_compute: buf, } }; diff --git a/crates/vm/src/arch/interpreter_preflight.rs b/crates/vm/src/arch/interpreter_preflight.rs index 7fb8006157..1b5530b2cb 100644 --- a/crates/vm/src/arch/interpreter_preflight.rs +++ b/crates/vm/src/arch/interpreter_preflight.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::{iter::repeat_n, sync::Arc}; use openvm_instructions::{instruction::Instruction, program::Program, LocalOpcode, SystemOpcode}; use openvm_stark_backend::{ @@ -36,6 +36,7 @@ pub struct PreflightInterpretedInstance { } #[repr(C)] +#[derive(Clone)] pub struct PcEntry { // NOTE[jpw]: revisit storing only smaller `precompute` for better cache locality. 
Currently // VmOpcode is usize so align=8 and there are 7 u32 operands so we store ExecutorId(u32) after @@ -60,7 +61,10 @@ impl PreflightInterpretedInstance { return Err(StaticProgramError::TooManyExecutors); } let len = program.instructions_and_debug_infos.len(); - let mut pc_handler = Vec::with_capacity(len); + let pc_base = program.pc_base; + let base_idx = get_pc_index(pc_base); + let mut pc_handler = Vec::with_capacity(base_idx + len); + pc_handler.extend(repeat_n(PcEntry::undefined(), base_idx)); for insn_and_debug_info in &program.instructions_and_debug_infos { if let Some((insn, _)) = insn_and_debug_info { let insn = insn.clone(); @@ -86,9 +90,9 @@ impl PreflightInterpretedInstance { } Ok(Self { inventory, - execution_frequencies: vec![0u32; len], + execution_frequencies: vec![0u32; base_idx + len], + pc_base, pc_handler, - pc_base: program.pc_base, executor_idx_to_air_idx, }) } @@ -101,9 +105,11 @@ impl PreflightInterpretedInstance { where E: Send + Sync, { + let base_idx = get_pc_index(self.pc_base); self.pc_handler .par_iter() .enumerate() + .skip(base_idx) .filter(|(_, entry)| entry.is_some()) .map(|(i, _)| self.execution_frequencies[i]) .collect() @@ -157,15 +163,11 @@ impl PreflightInterpretedInstance { E: PreflightExecutor, { let pc = state.pc; - let pc_idx = get_pc_index(self.pc_base, pc); - let pc_entry = - self.pc_handler - .get(pc_idx) - .ok_or_else(|| ExecutionError::PcOutOfBounds { - pc, - pc_base: self.pc_base, - program_len: self.pc_handler.len(), - })?; + let pc_idx = get_pc_index(pc); + let pc_entry = self + .pc_handler + .get(pc_idx) + .ok_or_else(|| ExecutionError::PcOutOfBounds(pc))?; // SAFETY: `execution_frequencies` has the same length as `pc_handler` so `get_pc_entry` // already does the bounds check unsafe { diff --git a/crates/vm/src/arch/mod.rs b/crates/vm/src/arch/mod.rs index 974b86008e..545a463883 100644 --- a/crates/vm/src/arch/mod.rs +++ b/crates/vm/src/arch/mod.rs @@ -30,6 +30,8 @@ pub use execution::*; pub use 
execution_mode::{ExecutionCtxTrait, MeteredExecutionCtxTrait}; pub use extensions::*; pub use integration_api::*; +pub use interpreter::InterpretedInstance; +pub use openvm_circuit_derive::create_tco_handler; pub use openvm_instructions as instructions; pub use record_arena::*; pub use state::*; diff --git a/crates/vm/src/arch/state.rs b/crates/vm/src/arch/state.rs index 6fe2e4e4f8..611094ecfc 100644 --- a/crates/vm/src/arch/state.rs +++ b/crates/vm/src/arch/state.rs @@ -5,6 +5,7 @@ use std::{ use openvm_instructions::exe::SparseMemoryImage; use rand::{rngs::StdRng, SeedableRng}; +use tracing::instrument; use super::{create_memory_image, ExecutionError, Streams}; #[cfg(feature = "metrics")] @@ -53,6 +54,7 @@ impl VmState { } impl VmState { + #[instrument(name = "VmState::initial", level = "debug", skip_all)] pub fn initial( system_config: &SystemConfig, init_memory: &SparseMemoryImage, diff --git a/crates/vm/src/lib.rs b/crates/vm/src/lib.rs index 271ea04b82..138549fb70 100644 --- a/crates/vm/src/lib.rs +++ b/crates/vm/src/lib.rs @@ -1,3 +1,5 @@ +#![cfg_attr(feature = "tco", allow(incomplete_features))] +#![cfg_attr(feature = "tco", feature(explicit_tail_calls))] extern crate self as openvm_circuit; pub use openvm_circuit_derive as derive; diff --git a/crates/vm/src/system/phantom/execution.rs b/crates/vm/src/system/phantom/execution.rs index e7e1775052..155b5d5713 100644 --- a/crates/vm/src/system/phantom/execution.rs +++ b/crates/vm/src/system/phantom/execution.rs @@ -7,8 +7,11 @@ use openvm_instructions::{ use openvm_stark_backend::p3_field::PrimeField32; use rand::rngs::StdRng; +#[cfg(feature = "tco")] +use crate::arch::Handler; use crate::{ arch::{ + create_tco_handler, execution_mode::{ExecutionCtxTrait, MeteredExecutionCtxTrait}, E2PreCompute, ExecuteFunc, ExecutionError, Executor, MeteredExecutor, PhantomSubExecutor, StaticProgramError, Streams, VmExecState, @@ -53,6 +56,20 @@ where self.pre_compute_impl(inst, data); Ok(execute_e1_impl) } + #[cfg(feature = 
"tco")] + fn handler( + &self, + _pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let data: &mut PhantomPreCompute = data.borrow_mut(); + self.pre_compute_impl(inst, data); + Ok(execute_e1_tco_handler) + } } pub(super) struct PhantomStateMut<'a, F> { @@ -85,6 +102,7 @@ unsafe fn execute_e12_impl( vm_state.instret += 1; } +#[create_tco_handler] #[inline(always)] unsafe fn execute_e1_impl( pre_compute: &[u8], @@ -94,6 +112,7 @@ unsafe fn execute_e1_impl( execute_e12_impl(pre_compute, vm_state); } +#[create_tco_handler] #[inline(always)] unsafe fn execute_e2_impl( pre_compute: &[u8], @@ -189,4 +208,21 @@ where self.pre_compute_impl(inst, &mut e2_data.data); Ok(execute_e2_impl) } + + #[cfg(feature = "tco")] + fn metered_handler( + &self, + chip_idx: usize, + _pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: MeteredExecutionCtxTrait, + { + let e2_data: &mut E2PreCompute> = data.borrow_mut(); + e2_data.chip_idx = chip_idx as u32; + self.pre_compute_impl(inst, &mut e2_data.data); + Ok(execute_e2_tco_handler) + } } diff --git a/crates/vm/src/system/public_values/execution.rs b/crates/vm/src/system/public_values/execution.rs index 34c1f22ff0..dcc25b3bd9 100644 --- a/crates/vm/src/system/public_values/execution.rs +++ b/crates/vm/src/system/public_values/execution.rs @@ -7,8 +7,11 @@ use openvm_instructions::{ use openvm_stark_backend::p3_field::PrimeField32; use super::PublicValuesExecutor; +#[cfg(feature = "tco")] +use crate::arch::Handler; use crate::{ arch::{ + create_tco_handler, execution_mode::{ExecutionCtxTrait, MeteredExecutionCtxTrait}, E2PreCompute, ExecuteFunc, Executor, MeteredExecutor, StaticProgramError, VmExecState, }, @@ -57,6 +60,17 @@ where } } +macro_rules! 
dispatch { + ($execute_impl:ident, $b_is_imm:ident, $c_is_imm:ident) => { + match ($b_is_imm, $c_is_imm) { + (true, true) => Ok($execute_impl::<_, _, true, true>), + (true, false) => Ok($execute_impl::<_, _, true, false>), + (false, true) => Ok($execute_impl::<_, _, false, true>), + (false, false) => Ok($execute_impl::<_, _, false, false>), + } + }; +} + impl Executor for PublicValuesExecutor where F: PrimeField32, @@ -79,13 +93,23 @@ where let data: &mut PublicValuesPreCompute = data.borrow_mut(); let (b_is_imm, c_is_imm) = self.pre_compute_impl(inst, data); - let fn_ptr = match (b_is_imm, c_is_imm) { - (true, true) => execute_e1_impl::<_, _, true, true>, - (true, false) => execute_e1_impl::<_, _, true, false>, - (false, true) => execute_e1_impl::<_, _, false, true>, - (false, false) => execute_e1_impl::<_, _, false, false>, - }; - Ok(fn_ptr) + dispatch!(execute_e1_impl, b_is_imm, c_is_imm) + } + + #[cfg(feature = "tco")] + fn handler( + &self, + _pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let data: &mut PublicValuesPreCompute = data.borrow_mut(); + let (b_is_imm, c_is_imm) = self.pre_compute_impl(inst, data); + + dispatch!(execute_e1_tco_handler, b_is_imm, c_is_imm) } } @@ -111,13 +135,25 @@ where data.chip_idx = chip_idx as u32; let (b_is_imm, c_is_imm) = self.pre_compute_impl(inst, &mut data.data); - let fn_ptr = match (b_is_imm, c_is_imm) { - (true, true) => execute_e2_impl::<_, _, true, true>, - (true, false) => execute_e2_impl::<_, _, true, false>, - (false, true) => execute_e2_impl::<_, _, false, true>, - (false, false) => execute_e2_impl::<_, _, false, false>, - }; - Ok(fn_ptr) + dispatch!(execute_e2_impl, b_is_imm, c_is_imm) + } + + #[cfg(feature = "tco")] + fn metered_handler( + &self, + chip_idx: usize, + _pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: MeteredExecutionCtxTrait, + { + let data: &mut E2PreCompute = 
data.borrow_mut(); + data.chip_idx = chip_idx as u32; + let (b_is_imm, c_is_imm) = self.pre_compute_impl(inst, &mut data.data); + + dispatch!(execute_e2_tco_handler, b_is_imm, c_is_imm) } } @@ -155,6 +191,7 @@ unsafe fn execute_e12_impl( pre_compute: &[u8], @@ -166,6 +203,7 @@ unsafe fn execute_e1_impl(pre_compute, state); } +#[create_tco_handler] #[inline(always)] unsafe fn execute_e2_impl( pre_compute: &[u8], diff --git a/extensions/algebra/circuit/Cargo.toml b/extensions/algebra/circuit/Cargo.toml index 7d0eb389e6..4e4e7f7357 100644 --- a/extensions/algebra/circuit/Cargo.toml +++ b/extensions/algebra/circuit/Cargo.toml @@ -38,5 +38,9 @@ openvm-rv32-adapters = { workspace = true, features = ["test-utils"] } openvm-pairing-guest = { workspace = true, features = ["halo2curves"] } test-case = { workspace = true } +[features] +default = [] +tco = ["openvm-rv32im-circuit/tco"] + [package.metadata.cargo-shear] ignored = ["derive_more"] diff --git a/extensions/algebra/circuit/src/execution.rs b/extensions/algebra/circuit/src/execution.rs index a99c4ba37b..e626e08f59 100644 --- a/extensions/algebra/circuit/src/execution.rs +++ b/extensions/algebra/circuit/src/execution.rs @@ -74,6 +74,83 @@ macro_rules! generate_fp2_dispatch { }; } +macro_rules! 
dispatch { + ($execute_impl:ident,$execute_generic_impl:ident,$execute_setup_impl:ident,$pre_compute:ident,$op:ident) => { + if let Some(op) = $op { + let modulus = &$pre_compute.expr.prime; + if IS_FP2 { + if let Some(field_type) = get_fp2_field_type(modulus) { + generate_fp2_dispatch!( + field_type, + op, + BLOCKS, + BLOCK_SIZE, + $execute_impl, + [ + (BN254Coordinate, Add), + (BN254Coordinate, Sub), + (BN254Coordinate, Mul), + (BN254Coordinate, Div), + (BLS12_381Coordinate, Add), + (BLS12_381Coordinate, Sub), + (BLS12_381Coordinate, Mul), + (BLS12_381Coordinate, Div), + ] + ) + } else { + Ok($execute_generic_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>) + } + } else if let Some(field_type) = get_field_type(modulus) { + generate_field_dispatch!( + field_type, + op, + BLOCKS, + BLOCK_SIZE, + $execute_impl, + [ + (K256Coordinate, Add), + (K256Coordinate, Sub), + (K256Coordinate, Mul), + (K256Coordinate, Div), + (K256Scalar, Add), + (K256Scalar, Sub), + (K256Scalar, Mul), + (K256Scalar, Div), + (P256Coordinate, Add), + (P256Coordinate, Sub), + (P256Coordinate, Mul), + (P256Coordinate, Div), + (P256Scalar, Add), + (P256Scalar, Sub), + (P256Scalar, Mul), + (P256Scalar, Div), + (BN254Coordinate, Add), + (BN254Coordinate, Sub), + (BN254Coordinate, Mul), + (BN254Coordinate, Div), + (BN254Scalar, Add), + (BN254Scalar, Sub), + (BN254Scalar, Mul), + (BN254Scalar, Div), + (BLS12_381Coordinate, Add), + (BLS12_381Coordinate, Sub), + (BLS12_381Coordinate, Mul), + (BLS12_381Coordinate, Div), + (BLS12_381Scalar, Add), + (BLS12_381Scalar, Sub), + (BLS12_381Scalar, Mul), + (BLS12_381Scalar, Div), + ] + ) + } else { + Ok($execute_generic_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>) + } + } else { + Ok($execute_setup_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>) + } + }; +} + #[derive(AlignedBytesBorrow, Clone)] #[repr(C)] struct FieldExpressionPreCompute<'a> { @@ -192,81 +269,37 @@ impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + 
Ctx: ExecutionCtxTrait, + { + let pre_compute: &mut FieldExpressionPreCompute = data.borrow_mut(); let op = self.pre_compute_impl(pc, inst, pre_compute)?; - if let Some(op) = op { - let modulus = &pre_compute.expr.prime; - if IS_FP2 { - if let Some(field_type) = get_fp2_field_type(modulus) { - generate_fp2_dispatch!( - field_type, - op, - BLOCKS, - BLOCK_SIZE, - execute_e1_impl, - [ - (BN254Coordinate, Add), - (BN254Coordinate, Sub), - (BN254Coordinate, Mul), - (BN254Coordinate, Div), - (BLS12_381Coordinate, Add), - (BLS12_381Coordinate, Sub), - (BLS12_381Coordinate, Mul), - (BLS12_381Coordinate, Div), - ] - ) - } else { - Ok(execute_e1_generic_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>) - } - } else if let Some(field_type) = get_field_type(modulus) { - generate_field_dispatch!( - field_type, - op, - BLOCKS, - BLOCK_SIZE, - execute_e1_impl, - [ - (K256Coordinate, Add), - (K256Coordinate, Sub), - (K256Coordinate, Mul), - (K256Coordinate, Div), - (K256Scalar, Add), - (K256Scalar, Sub), - (K256Scalar, Mul), - (K256Scalar, Div), - (P256Coordinate, Add), - (P256Coordinate, Sub), - (P256Coordinate, Mul), - (P256Coordinate, Div), - (P256Scalar, Add), - (P256Scalar, Sub), - (P256Scalar, Mul), - (P256Scalar, Div), - (BN254Coordinate, Add), - (BN254Coordinate, Sub), - (BN254Coordinate, Mul), - (BN254Coordinate, Div), - (BN254Scalar, Add), - (BN254Scalar, Sub), - (BN254Scalar, Mul), - (BN254Scalar, Div), - (BLS12_381Coordinate, Add), - (BLS12_381Coordinate, Sub), - (BLS12_381Coordinate, Mul), - (BLS12_381Coordinate, Div), - (BLS12_381Scalar, Add), - (BLS12_381Scalar, Sub), - (BLS12_381Scalar, Mul), - (BLS12_381Scalar, Div), - ] - ) - } else { - Ok(execute_e1_generic_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>) - } - } else { - Ok(execute_e1_setup_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>) - } + dispatch!( + execute_e1_tco_handler, + execute_e1_generic_tco_handler, + execute_e1_setup_tco_handler, + pre_compute, + op + ) } } @@ -291,80 +324,42 @@ impl = data.borrow_mut(); 
pre_compute.chip_idx = chip_idx as u32; - let op = self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + let pre_compute_pure = &mut pre_compute.data; + let op = self.pre_compute_impl(pc, inst, pre_compute_pure)?; - if let Some(op) = op { - let modulus = &pre_compute.data.expr.prime; - if IS_FP2 { - if let Some(field_type) = get_fp2_field_type(modulus) { - generate_fp2_dispatch!( - field_type, - op, - BLOCKS, - BLOCK_SIZE, - execute_e2_impl, - [ - (BN254Coordinate, Add), - (BN254Coordinate, Sub), - (BN254Coordinate, Mul), - (BN254Coordinate, Div), - (BLS12_381Coordinate, Add), - (BLS12_381Coordinate, Sub), - (BLS12_381Coordinate, Mul), - (BLS12_381Coordinate, Div), - ] - ) - } else { - Ok(execute_e2_generic_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>) - } - } else if let Some(field_type) = get_field_type(modulus) { - generate_field_dispatch!( - field_type, - op, - BLOCKS, - BLOCK_SIZE, - execute_e2_impl, - [ - (K256Coordinate, Add), - (K256Coordinate, Sub), - (K256Coordinate, Mul), - (K256Coordinate, Div), - (K256Scalar, Add), - (K256Scalar, Sub), - (K256Scalar, Mul), - (K256Scalar, Div), - (P256Coordinate, Add), - (P256Coordinate, Sub), - (P256Coordinate, Mul), - (P256Coordinate, Div), - (P256Scalar, Add), - (P256Scalar, Sub), - (P256Scalar, Mul), - (P256Scalar, Div), - (BN254Coordinate, Add), - (BN254Coordinate, Sub), - (BN254Coordinate, Mul), - (BN254Coordinate, Div), - (BN254Scalar, Add), - (BN254Scalar, Sub), - (BN254Scalar, Mul), - (BN254Scalar, Div), - (BLS12_381Coordinate, Add), - (BLS12_381Coordinate, Sub), - (BLS12_381Coordinate, Mul), - (BLS12_381Coordinate, Div), - (BLS12_381Scalar, Add), - (BLS12_381Scalar, Sub), - (BLS12_381Scalar, Mul), - (BLS12_381Scalar, Div), - ] - ) - } else { - Ok(execute_e2_generic_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>) - } - } else { - Ok(execute_e2_setup_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>) - } + dispatch!( + execute_e2_impl, + execute_e2_generic_impl, + execute_e2_setup_impl, + pre_compute_pure, + op + ) + } + + 
#[cfg(feature = "tco")] + fn metered_handler( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: MeteredExecutionCtxTrait, + { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + let pre_compute_pure = &mut pre_compute.data; + let op = self.pre_compute_impl(pc, inst, pre_compute_pure)?; + + dispatch!( + execute_e2_tco_handler, + execute_e2_generic_tco_handler, + execute_e2_setup_tco_handler, + pre_compute_pure, + op + ) } } unsafe fn execute_e12_impl< @@ -496,6 +491,7 @@ unsafe fn execute_e12_setup_impl< vm_state.instret += 1; } +#[create_tco_handler] unsafe fn execute_e1_setup_impl< F: PrimeField32, CTX: ExecutionCtxTrait, @@ -510,6 +506,7 @@ unsafe fn execute_e1_setup_impl< execute_e12_setup_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>(pre_compute, vm_state); } +#[create_tco_handler] unsafe fn execute_e2_setup_impl< F: PrimeField32, CTX: MeteredExecutionCtxTrait, @@ -527,6 +524,7 @@ unsafe fn execute_e2_setup_impl< execute_e12_setup_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>(&pre_compute.data, vm_state); } +#[create_tco_handler] unsafe fn execute_e1_impl< F: PrimeField32, CTX: ExecutionCtxTrait, @@ -543,6 +541,7 @@ unsafe fn execute_e1_impl< execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2, FIELD_TYPE, OP>(pre_compute, vm_state); } +#[create_tco_handler] unsafe fn execute_e2_impl< F: PrimeField32, CTX: MeteredExecutionCtxTrait, @@ -565,6 +564,7 @@ unsafe fn execute_e2_impl< ); } +#[create_tco_handler] unsafe fn execute_e1_generic_impl< F: PrimeField32, CTX: ExecutionCtxTrait, @@ -579,6 +579,7 @@ unsafe fn execute_e1_generic_impl< execute_e12_generic_impl::<_, _, BLOCKS, BLOCK_SIZE>(pre_compute, vm_state); } +#[create_tco_handler] unsafe fn execute_e2_generic_impl< F: PrimeField32, CTX: MeteredExecutionCtxTrait, diff --git a/extensions/algebra/circuit/src/lib.rs b/extensions/algebra/circuit/src/lib.rs index b4e494c812..08a69c650a 100644 
--- a/extensions/algebra/circuit/src/lib.rs +++ b/extensions/algebra/circuit/src/lib.rs @@ -1,3 +1,6 @@ +#![cfg_attr(feature = "tco", allow(incomplete_features))] +#![cfg_attr(feature = "tco", feature(explicit_tail_calls))] + use derive_more::derive::{Deref, DerefMut}; use openvm_circuit_derive::PreflightExecutor; use openvm_mod_circuit_builder::FieldExpressionExecutor; diff --git a/extensions/algebra/circuit/src/modular_chip/is_eq.rs b/extensions/algebra/circuit/src/modular_chip/is_eq.rs index 16e5c60fc9..28dd0488d7 100644 --- a/extensions/algebra/circuit/src/modular_chip/is_eq.rs +++ b/extensions/algebra/circuit/src/modular_chip/is_eq.rs @@ -528,6 +528,16 @@ impl { + Ok(if $is_setup { + $execute_impl::<_, _, NUM_LANES, LANE_SIZE, TOTAL_READ_SIZE, true> + } else { + $execute_impl::<_, _, NUM_LANES, LANE_SIZE, TOTAL_READ_SIZE, false> + }) + }; +} + impl Executor for VmModularIsEqualExecutor where @@ -545,15 +555,25 @@ where data: &mut [u8], ) -> Result, StaticProgramError> { let pre_compute: &mut ModularIsEqualPreCompute = data.borrow_mut(); + let is_setup = self.pre_compute_impl(pc, inst, pre_compute)?; + + dispatch!(execute_e1_impl, is_setup) + } + #[cfg(feature = "tco")] + fn handler( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let pre_compute: &mut ModularIsEqualPreCompute = data.borrow_mut(); let is_setup = self.pre_compute_impl(pc, inst, pre_compute)?; - let fn_ptr = if is_setup { - execute_e1_impl::<_, _, NUM_LANES, LANE_SIZE, TOTAL_READ_SIZE, true> - } else { - execute_e1_impl::<_, _, NUM_LANES, LANE_SIZE, TOTAL_READ_SIZE, false> - }; - Ok(fn_ptr) + dispatch!(execute_e1_tco_handler, is_setup) } } @@ -579,16 +599,29 @@ where pre_compute.chip_idx = chip_idx as u32; let is_setup = self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; - let fn_ptr = if is_setup { - execute_e2_impl::<_, _, NUM_LANES, LANE_SIZE, TOTAL_READ_SIZE, true> - } else { - execute_e2_impl::<_, 
_, NUM_LANES, LANE_SIZE, TOTAL_READ_SIZE, false> - }; - Ok(fn_ptr) + dispatch!(execute_e2_impl, is_setup) + } + + #[cfg(feature = "tco")] + fn metered_handler( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let pre_compute: &mut E2PreCompute> = + data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + let is_setup = self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + + dispatch!(execute_e2_tco_handler, is_setup) } } +#[create_tco_handler] unsafe fn execute_e1_impl< F: PrimeField32, CTX: ExecutionCtxTrait, @@ -608,6 +641,7 @@ unsafe fn execute_e1_impl< ); } +#[create_tco_handler] unsafe fn execute_e2_impl< F: PrimeField32, CTX: MeteredExecutionCtxTrait, diff --git a/extensions/bigint/circuit/Cargo.toml b/extensions/bigint/circuit/Cargo.toml index aa9114c34a..a745dd33cc 100644 --- a/extensions/bigint/circuit/Cargo.toml +++ b/extensions/bigint/circuit/Cargo.toml @@ -36,6 +36,7 @@ alloy-primitives = { version = "1.2.1" } default = ["parallel", "jemalloc"] parallel = ["openvm-circuit/parallel"] test-utils = ["openvm-circuit/test-utils"] +tco = ["openvm-rv32im-circuit/tco"] # performance features: mimalloc = ["openvm-circuit/mimalloc"] jemalloc = ["openvm-circuit/jemalloc"] diff --git a/extensions/bigint/circuit/src/base_alu.rs b/extensions/bigint/circuit/src/base_alu.rs index 617bfab5f4..53f2efb7b9 100644 --- a/extensions/bigint/circuit/src/base_alu.rs +++ b/extensions/bigint/circuit/src/base_alu.rs @@ -34,6 +34,18 @@ struct BaseAluPreCompute { c: u8, } +macro_rules! 
dispatch { + ($execute_impl:ident, $local_opcode:ident) => { + Ok(match $local_opcode { + BaseAluOpcode::ADD => $execute_impl::<_, _, AddOp>, + BaseAluOpcode::SUB => $execute_impl::<_, _, SubOp>, + BaseAluOpcode::XOR => $execute_impl::<_, _, XorOp>, + BaseAluOpcode::OR => $execute_impl::<_, _, OrOp>, + BaseAluOpcode::AND => $execute_impl::<_, _, AndOp>, + }) + }; +} + impl Executor for Rv32BaseAlu256Executor { fn pre_compute_size(&self) -> usize { size_of::() @@ -50,14 +62,24 @@ impl Executor for Rv32BaseAlu256Executor { { let data: &mut BaseAluPreCompute = data.borrow_mut(); let local_opcode = self.pre_compute_impl(pc, inst, data)?; - let fn_ptr = match local_opcode { - BaseAluOpcode::ADD => execute_e1_impl::<_, _, AddOp>, - BaseAluOpcode::SUB => execute_e1_impl::<_, _, SubOp>, - BaseAluOpcode::XOR => execute_e1_impl::<_, _, XorOp>, - BaseAluOpcode::OR => execute_e1_impl::<_, _, OrOp>, - BaseAluOpcode::AND => execute_e1_impl::<_, _, AndOp>, - }; - Ok(fn_ptr) + + dispatch!(execute_e1_impl, local_opcode) + } + + #[cfg(feature = "tco")] + fn handler( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let data: &mut BaseAluPreCompute = data.borrow_mut(); + let local_opcode = self.pre_compute_impl(pc, inst, data)?; + + dispatch!(execute_e1_tco_handler, local_opcode) } } @@ -79,14 +101,26 @@ impl MeteredExecutor for Rv32BaseAlu256Executor { let data: &mut E2PreCompute = data.borrow_mut(); data.chip_idx = chip_idx as u32; let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?; - let fn_ptr = match local_opcode { - BaseAluOpcode::ADD => execute_e2_impl::<_, _, AddOp>, - BaseAluOpcode::SUB => execute_e2_impl::<_, _, SubOp>, - BaseAluOpcode::XOR => execute_e2_impl::<_, _, XorOp>, - BaseAluOpcode::OR => execute_e2_impl::<_, _, OrOp>, - BaseAluOpcode::AND => execute_e2_impl::<_, _, AndOp>, - }; - Ok(fn_ptr) + + dispatch!(execute_e2_impl, local_opcode) + } + + #[cfg(feature = 
"tco")] + fn metered_handler( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: MeteredExecutionCtxTrait, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?; + + dispatch!(execute_e2_tco_handler, local_opcode) } } @@ -106,6 +140,7 @@ unsafe fn execute_e12_impl( vm_state.instret += 1; } +#[create_tco_handler] unsafe fn execute_e1_impl( pre_compute: &[u8], vm_state: &mut VmExecState, @@ -114,6 +149,7 @@ unsafe fn execute_e1_impl( execute_e12_impl::(pre_compute, vm_state); } +#[create_tco_handler] unsafe fn execute_e2_impl( pre_compute: &[u8], vm_state: &mut VmExecState, diff --git a/extensions/bigint/circuit/src/branch_eq.rs b/extensions/bigint/circuit/src/branch_eq.rs index 3a421bdceb..70ba68fb31 100644 --- a/extensions/bigint/circuit/src/branch_eq.rs +++ b/extensions/bigint/circuit/src/branch_eq.rs @@ -32,6 +32,15 @@ struct BranchEqPreCompute { b: u8, } +macro_rules! 
dispatch { + ($execute_impl:ident, $local_opcode:ident) => { + match $local_opcode { + BranchEqualOpcode::BEQ => Ok($execute_impl::<_, _, false>), + BranchEqualOpcode::BNE => Ok($execute_impl::<_, _, true>), + } + }; +} + impl Executor for Rv32BranchEqual256Executor { fn pre_compute_size(&self) -> usize { size_of::() @@ -48,11 +57,22 @@ impl Executor for Rv32BranchEqual256Executor { { let data: &mut BranchEqPreCompute = data.borrow_mut(); let local_opcode = self.pre_compute_impl(pc, inst, data)?; - let fn_ptr = match local_opcode { - BranchEqualOpcode::BEQ => execute_e1_impl::<_, _, false>, - BranchEqualOpcode::BNE => execute_e1_impl::<_, _, true>, - }; - Ok(fn_ptr) + dispatch!(execute_e1_impl, local_opcode) + } + + #[cfg(feature = "tco")] + fn handler( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let data: &mut BranchEqPreCompute = data.borrow_mut(); + let local_opcode = self.pre_compute_impl(pc, inst, data)?; + dispatch!(execute_e1_tco_handler, local_opcode) } } @@ -74,11 +94,24 @@ impl MeteredExecutor for Rv32BranchEqual256Executor { let data: &mut E2PreCompute = data.borrow_mut(); data.chip_idx = chip_idx as u32; let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?; - let fn_ptr = match local_opcode { - BranchEqualOpcode::BEQ => execute_e2_impl::<_, _, false>, - BranchEqualOpcode::BNE => execute_e2_impl::<_, _, true>, - }; - Ok(fn_ptr) + dispatch!(execute_e2_impl, local_opcode) + } + + #[cfg(feature = "tco")] + fn metered_handler( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: MeteredExecutionCtxTrait, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?; + dispatch!(execute_e2_tco_handler, local_opcode) } } @@ -101,6 +134,7 @@ unsafe fn execute_e12_impl( pre_compute: 
&[u8], vm_state: &mut VmExecState, @@ -109,6 +143,7 @@ unsafe fn execute_e1_impl(pre_compute, vm_state); } +#[create_tco_handler] unsafe fn execute_e2_impl( pre_compute: &[u8], vm_state: &mut VmExecState, diff --git a/extensions/bigint/circuit/src/branch_lt.rs b/extensions/bigint/circuit/src/branch_lt.rs index 7a701d4812..b161fa0091 100644 --- a/extensions/bigint/circuit/src/branch_lt.rs +++ b/extensions/bigint/circuit/src/branch_lt.rs @@ -35,6 +35,17 @@ struct BranchLtPreCompute { b: u8, } +macro_rules! dispatch { + ($execute_impl:ident, $local_opcode:ident) => { + Ok(match $local_opcode { + BranchLessThanOpcode::BLT => $execute_impl::<_, _, BltOp>, + BranchLessThanOpcode::BLTU => $execute_impl::<_, _, BltuOp>, + BranchLessThanOpcode::BGE => $execute_impl::<_, _, BgeOp>, + BranchLessThanOpcode::BGEU => $execute_impl::<_, _, BgeuOp>, + }) + }; +} + impl Executor for Rv32BranchLessThan256Executor { fn pre_compute_size(&self) -> usize { size_of::() @@ -51,13 +62,22 @@ impl Executor for Rv32BranchLessThan256Executor { { let data: &mut BranchLtPreCompute = data.borrow_mut(); let local_opcode = self.pre_compute_impl(pc, inst, data)?; - let fn_ptr = match local_opcode { - BranchLessThanOpcode::BLT => execute_e1_impl::<_, _, BltOp>, - BranchLessThanOpcode::BLTU => execute_e1_impl::<_, _, BltuOp>, - BranchLessThanOpcode::BGE => execute_e1_impl::<_, _, BgeOp>, - BranchLessThanOpcode::BGEU => execute_e1_impl::<_, _, BgeuOp>, - }; - Ok(fn_ptr) + dispatch!(execute_e1_impl, local_opcode) + } + + #[cfg(feature = "tco")] + fn handler( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let data: &mut BranchLtPreCompute = data.borrow_mut(); + let local_opcode = self.pre_compute_impl(pc, inst, data)?; + dispatch!(execute_e1_tco_handler, local_opcode) } } @@ -79,13 +99,24 @@ impl MeteredExecutor for Rv32BranchLessThan256Executor { let data: &mut E2PreCompute = data.borrow_mut(); data.chip_idx = 
chip_idx as u32; let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?; - let fn_ptr = match local_opcode { - BranchLessThanOpcode::BLT => execute_e2_impl::<_, _, BltOp>, - BranchLessThanOpcode::BLTU => execute_e2_impl::<_, _, BltuOp>, - BranchLessThanOpcode::BGE => execute_e2_impl::<_, _, BgeOp>, - BranchLessThanOpcode::BGEU => execute_e2_impl::<_, _, BgeuOp>, - }; - Ok(fn_ptr) + dispatch!(execute_e2_impl, local_opcode) + } + + #[cfg(feature = "tco")] + fn metered_handler( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: MeteredExecutionCtxTrait, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?; + dispatch!(execute_e2_tco_handler, local_opcode) } } @@ -107,6 +138,7 @@ unsafe fn execute_e12_impl( pre_compute: &[u8], vm_state: &mut VmExecState, @@ -115,6 +147,7 @@ unsafe fn execute_e1_impl(pre_compute, vm_state); } +#[create_tco_handler] unsafe fn execute_e2_impl( pre_compute: &[u8], vm_state: &mut VmExecState, diff --git a/extensions/bigint/circuit/src/less_than.rs b/extensions/bigint/circuit/src/less_than.rs index e153a6221e..85bfd152ce 100644 --- a/extensions/bigint/circuit/src/less_than.rs +++ b/extensions/bigint/circuit/src/less_than.rs @@ -32,6 +32,15 @@ struct LessThanPreCompute { c: u8, } +macro_rules! 
dispatch { + ($execute_impl:ident, $local_opcode:ident) => { + Ok(match $local_opcode { + LessThanOpcode::SLT => $execute_impl::<_, _, false>, + LessThanOpcode::SLTU => $execute_impl::<_, _, true>, + }) + }; +} + impl Executor for Rv32LessThan256Executor { fn pre_compute_size(&self) -> usize { size_of::() @@ -48,11 +57,22 @@ impl Executor for Rv32LessThan256Executor { { let data: &mut LessThanPreCompute = data.borrow_mut(); let local_opcode = self.pre_compute_impl(pc, inst, data)?; - let fn_ptr = match local_opcode { - LessThanOpcode::SLT => execute_e1_impl::<_, _, false>, - LessThanOpcode::SLTU => execute_e1_impl::<_, _, true>, - }; - Ok(fn_ptr) + dispatch!(execute_e1_impl, local_opcode) + } + + #[cfg(feature = "tco")] + fn handler( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let data: &mut LessThanPreCompute = data.borrow_mut(); + let local_opcode = self.pre_compute_impl(pc, inst, data)?; + dispatch!(execute_e1_tco_handler, local_opcode) } } @@ -74,11 +94,24 @@ impl MeteredExecutor for Rv32LessThan256Executor { let data: &mut E2PreCompute = data.borrow_mut(); data.chip_idx = chip_idx as u32; let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?; - let fn_ptr = match local_opcode { - LessThanOpcode::SLT => execute_e2_impl::<_, _, false>, - LessThanOpcode::SLTU => execute_e2_impl::<_, _, true>, - }; - Ok(fn_ptr) + dispatch!(execute_e2_impl, local_opcode) + } + + #[cfg(feature = "tco")] + fn metered_handler( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: MeteredExecutionCtxTrait, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?; + dispatch!(execute_e2_tco_handler, local_opcode) } } @@ -105,6 +138,7 @@ unsafe fn execute_e12_impl( pre_compute: &[u8], vm_state: &mut 
VmExecState, @@ -113,6 +147,7 @@ unsafe fn execute_e1_impl(pre_compute, vm_state); } +#[create_tco_handler] unsafe fn execute_e2_impl( pre_compute: &[u8], vm_state: &mut VmExecState, diff --git a/extensions/bigint/circuit/src/lib.rs b/extensions/bigint/circuit/src/lib.rs index 0dd5a5b4d4..0109a3f88e 100644 --- a/extensions/bigint/circuit/src/lib.rs +++ b/extensions/bigint/circuit/src/lib.rs @@ -1,3 +1,5 @@ +#![cfg_attr(feature = "tco", allow(incomplete_features))] +#![cfg_attr(feature = "tco", feature(explicit_tail_calls))] use openvm_circuit::{ self, arch::{ diff --git a/extensions/bigint/circuit/src/mult.rs b/extensions/bigint/circuit/src/mult.rs index c48a025d98..fae0e65894 100644 --- a/extensions/bigint/circuit/src/mult.rs +++ b/extensions/bigint/circuit/src/mult.rs @@ -53,6 +53,21 @@ impl Executor for Rv32Multiplication256Executor { self.pre_compute_impl(pc, inst, data)?; Ok(execute_e1_impl) } + + #[cfg(feature = "tco")] + fn handler( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let data: &mut MultPreCompute = data.borrow_mut(); + self.pre_compute_impl(pc, inst, data)?; + Ok(execute_e1_tco_handler) + } } impl MeteredExecutor for Rv32Multiplication256Executor { @@ -75,6 +90,23 @@ impl MeteredExecutor for Rv32Multiplication256Executor { self.pre_compute_impl(pc, inst, &mut data.data)?; Ok(execute_e2_impl) } + + #[cfg(feature = "tco")] + fn metered_handler( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: MeteredExecutionCtxTrait, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + self.pre_compute_impl(pc, inst, &mut data.data)?; + Ok(execute_e2_tco_handler) + } } #[inline(always)] @@ -94,6 +126,7 @@ unsafe fn execute_e12_impl( vm_state.instret += 1; } +#[create_tco_handler] unsafe fn execute_e1_impl( pre_compute: &[u8], vm_state: &mut VmExecState, @@ 
-102,6 +135,7 @@ unsafe fn execute_e1_impl( execute_e12_impl(pre_compute, vm_state); } +#[create_tco_handler] unsafe fn execute_e2_impl( pre_compute: &[u8], vm_state: &mut VmExecState, diff --git a/extensions/bigint/circuit/src/shift.rs b/extensions/bigint/circuit/src/shift.rs index ba669b0eef..462a8f7af5 100644 --- a/extensions/bigint/circuit/src/shift.rs +++ b/extensions/bigint/circuit/src/shift.rs @@ -35,6 +35,16 @@ struct ShiftPreCompute { c: u8, } +macro_rules! dispatch { + ($execute_impl:ident, $local_opcode:ident) => { + Ok(match $local_opcode { + ShiftOpcode::SLL => $execute_impl::<_, _, SllOp>, + ShiftOpcode::SRA => $execute_impl::<_, _, SraOp>, + ShiftOpcode::SRL => $execute_impl::<_, _, SrlOp>, + }) + }; +} + impl Executor for Rv32Shift256Executor { fn pre_compute_size(&self) -> usize { size_of::() @@ -51,12 +61,22 @@ impl Executor for Rv32Shift256Executor { { let data: &mut ShiftPreCompute = data.borrow_mut(); let local_opcode = self.pre_compute_impl(pc, inst, data)?; - let fn_ptr = match local_opcode { - ShiftOpcode::SLL => execute_e1_impl::<_, _, SllOp>, - ShiftOpcode::SRA => execute_e1_impl::<_, _, SraOp>, - ShiftOpcode::SRL => execute_e1_impl::<_, _, SrlOp>, - }; - Ok(fn_ptr) + dispatch!(execute_e1_impl, local_opcode) + } + + #[cfg(feature = "tco")] + fn handler( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let data: &mut ShiftPreCompute = data.borrow_mut(); + let local_opcode = self.pre_compute_impl(pc, inst, data)?; + dispatch!(execute_e1_tco_handler, local_opcode) } } @@ -78,12 +98,24 @@ impl MeteredExecutor for Rv32Shift256Executor { let data: &mut E2PreCompute = data.borrow_mut(); data.chip_idx = chip_idx as u32; let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?; - let fn_ptr = match local_opcode { - ShiftOpcode::SLL => execute_e2_impl::<_, _, SllOp>, - ShiftOpcode::SRA => execute_e2_impl::<_, _, SraOp>, - ShiftOpcode::SRL => 
execute_e2_impl::<_, _, SrlOp>, - }; - Ok(fn_ptr) + dispatch!(execute_e2_impl, local_opcode) + } + + #[cfg(feature = "tco")] + fn metered_handler( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: MeteredExecutionCtxTrait, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?; + dispatch!(execute_e2_tco_handler, local_opcode) } } @@ -103,6 +135,7 @@ unsafe fn execute_e12_impl vm_state.instret += 1; } +#[create_tco_handler] unsafe fn execute_e1_impl( pre_compute: &[u8], vm_state: &mut VmExecState, @@ -110,6 +143,7 @@ unsafe fn execute_e1_impl( let pre_compute: &ShiftPreCompute = pre_compute.borrow(); execute_e12_impl::(pre_compute, vm_state); } +#[create_tco_handler] unsafe fn execute_e2_impl( pre_compute: &[u8], vm_state: &mut VmExecState, diff --git a/extensions/ecc/circuit/Cargo.toml b/extensions/ecc/circuit/Cargo.toml index a194b5ac5a..c6ed2f14e1 100644 --- a/extensions/ecc/circuit/Cargo.toml +++ b/extensions/ecc/circuit/Cargo.toml @@ -39,5 +39,9 @@ openvm-circuit = { workspace = true, features = ["test-utils"] } openvm-rv32-adapters = { workspace = true, features = ["test-utils"] } lazy_static = { workspace = true } +[features] +default = [] +tco = ["openvm-algebra-circuit/tco"] + [package.metadata.cargo-shear] ignored = ["rand"] diff --git a/extensions/ecc/circuit/src/lib.rs b/extensions/ecc/circuit/src/lib.rs index 9986dca696..713088e864 100644 --- a/extensions/ecc/circuit/src/lib.rs +++ b/extensions/ecc/circuit/src/lib.rs @@ -1,3 +1,6 @@ +#![cfg_attr(feature = "tco", allow(incomplete_features))] +#![cfg_attr(feature = "tco", feature(explicit_tail_calls))] + mod weierstrass_chip; pub use weierstrass_chip::*; diff --git a/extensions/ecc/circuit/src/weierstrass_chip/add_ne/execution.rs b/extensions/ecc/circuit/src/weierstrass_chip/add_ne/execution.rs index 5c63a05c41..a983321e10 
100644 --- a/extensions/ecc/circuit/src/weierstrass_chip/add_ne/execution.rs +++ b/extensions/ecc/circuit/src/weierstrass_chip/add_ne/execution.rs @@ -92,33 +92,14 @@ impl<'a, const BLOCKS: usize, const BLOCK_SIZE: usize> EcAddNeExecutor Executor - for EcAddNeExecutor -{ - #[inline(always)] - fn pre_compute_size(&self) -> usize { - std::mem::size_of::() - } - - fn pre_compute( - &self, - pc: u32, - inst: &Instruction, - data: &mut [u8], - ) -> Result, StaticProgramError> - where - Ctx: ExecutionCtxTrait, - { - let pre_compute: &mut EcAddNePreCompute = data.borrow_mut(); - - let is_setup = self.pre_compute_impl(pc, inst, pre_compute)?; - +macro_rules! dispatch { + ($execute_impl:ident, $pre_compute:ident, $is_setup:ident) => { if let Some(field_type) = { - let modulus = &pre_compute.expr.builder.prime; + let modulus = &$pre_compute.expr.builder.prime; get_field_type(modulus) } { - match (is_setup, field_type) { - (true, FieldType::K256Coordinate) => Ok(execute_e1_impl::< + match ($is_setup, field_type) { + (true, FieldType::K256Coordinate) => Ok($execute_impl::< _, _, BLOCKS, @@ -126,7 +107,7 @@ impl Executor { FieldType::K256Coordinate as u8 }, true, >), - (true, FieldType::P256Coordinate) => Ok(execute_e1_impl::< + (true, FieldType::P256Coordinate) => Ok($execute_impl::< _, _, BLOCKS, @@ -134,7 +115,7 @@ impl Executor { FieldType::P256Coordinate as u8 }, true, >), - (true, FieldType::BN254Coordinate) => Ok(execute_e1_impl::< + (true, FieldType::BN254Coordinate) => Ok($execute_impl::< _, _, BLOCKS, @@ -142,7 +123,7 @@ impl Executor { FieldType::BN254Coordinate as u8 }, true, >), - (true, FieldType::BLS12_381Coordinate) => Ok(execute_e1_impl::< + (true, FieldType::BLS12_381Coordinate) => Ok($execute_impl::< _, _, BLOCKS, @@ -150,7 +131,7 @@ impl Executor { FieldType::BLS12_381Coordinate as u8 }, true, >), - (false, FieldType::K256Coordinate) => Ok(execute_e1_impl::< + (false, FieldType::K256Coordinate) => Ok($execute_impl::< _, _, BLOCKS, @@ -158,7 +139,7 @@ impl 
Executor { FieldType::K256Coordinate as u8 }, false, >), - (false, FieldType::P256Coordinate) => Ok(execute_e1_impl::< + (false, FieldType::P256Coordinate) => Ok($execute_impl::< _, _, BLOCKS, @@ -166,7 +147,7 @@ impl Executor { FieldType::P256Coordinate as u8 }, false, >), - (false, FieldType::BN254Coordinate) => Ok(execute_e1_impl::< + (false, FieldType::BN254Coordinate) => Ok($execute_impl::< _, _, BLOCKS, @@ -174,7 +155,7 @@ impl Executor { FieldType::BN254Coordinate as u8 }, false, >), - (false, FieldType::BLS12_381Coordinate) => Ok(execute_e1_impl::< + (false, FieldType::BLS12_381Coordinate) => Ok($execute_impl::< _, _, BLOCKS, @@ -184,11 +165,50 @@ impl Executor >), _ => panic!("Unsupported field type"), } - } else if is_setup { - Ok(execute_e1_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }, true>) + } else if $is_setup { + Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }, true>) } else { - Ok(execute_e1_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }, false>) + Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }, false>) } + }; +} +impl Executor + for EcAddNeExecutor +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + std::mem::size_of::() + } + + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let pre_compute: &mut EcAddNePreCompute = data.borrow_mut(); + let is_setup = self.pre_compute_impl(pc, inst, pre_compute)?; + + dispatch!(execute_e1_impl, pre_compute, is_setup) + } + + #[cfg(feature = "tco")] + fn handler( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let pre_compute: &mut EcAddNePreCompute = data.borrow_mut(); + let is_setup = self.pre_compute_impl(pc, inst, pre_compute)?; + + dispatch!(execute_e1_tco_handler, pre_compute, is_setup) } } @@ -213,84 +233,28 @@ impl MeteredExecu let pre_compute: &mut E2PreCompute = data.borrow_mut(); 
pre_compute.chip_idx = chip_idx as u32; - let is_setup = self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + let pre_compute_pure = &mut pre_compute.data; + let is_setup = self.pre_compute_impl(pc, inst, pre_compute_pure)?; + dispatch!(execute_e2_impl, pre_compute_pure, is_setup) + } - if let Some(field_type) = { - let modulus = &pre_compute.data.expr.builder.prime; - get_field_type(modulus) - } { - match (is_setup, field_type) { - (true, FieldType::K256Coordinate) => Ok(execute_e2_impl::< - _, - _, - BLOCKS, - BLOCK_SIZE, - { FieldType::K256Coordinate as u8 }, - true, - >), - (true, FieldType::P256Coordinate) => Ok(execute_e2_impl::< - _, - _, - BLOCKS, - BLOCK_SIZE, - { FieldType::P256Coordinate as u8 }, - true, - >), - (true, FieldType::BN254Coordinate) => Ok(execute_e2_impl::< - _, - _, - BLOCKS, - BLOCK_SIZE, - { FieldType::BN254Coordinate as u8 }, - true, - >), - (true, FieldType::BLS12_381Coordinate) => Ok(execute_e2_impl::< - _, - _, - BLOCKS, - BLOCK_SIZE, - { FieldType::BLS12_381Coordinate as u8 }, - true, - >), - (false, FieldType::K256Coordinate) => Ok(execute_e2_impl::< - _, - _, - BLOCKS, - BLOCK_SIZE, - { FieldType::K256Coordinate as u8 }, - false, - >), - (false, FieldType::P256Coordinate) => Ok(execute_e2_impl::< - _, - _, - BLOCKS, - BLOCK_SIZE, - { FieldType::P256Coordinate as u8 }, - false, - >), - (false, FieldType::BN254Coordinate) => Ok(execute_e2_impl::< - _, - _, - BLOCKS, - BLOCK_SIZE, - { FieldType::BN254Coordinate as u8 }, - false, - >), - (false, FieldType::BLS12_381Coordinate) => Ok(execute_e2_impl::< - _, - _, - BLOCKS, - BLOCK_SIZE, - { FieldType::BLS12_381Coordinate as u8 }, - false, - >), - _ => panic!("Unsupported field type"), - } - } else if is_setup { - Ok(execute_e2_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }, true>) - } else { - Ok(execute_e2_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }, false>) - } + #[cfg(feature = "tco")] + fn metered_handler( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: 
&mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: MeteredExecutionCtxTrait, + { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + let pre_compute_pure = &mut pre_compute.data; + let is_setup = self.pre_compute_impl(pc, inst, pre_compute_pure)?; + dispatch!(execute_e2_tco_handler, pre_compute_pure, is_setup) } } @@ -351,6 +315,7 @@ unsafe fn execute_e12_impl< vm_state.instret += 1; } +#[create_tco_handler] unsafe fn execute_e1_impl< F: PrimeField32, CTX: ExecutionCtxTrait, @@ -366,6 +331,7 @@ unsafe fn execute_e1_impl< execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, FIELD_TYPE, IS_SETUP>(pre_compute, vm_state); } +#[create_tco_handler] unsafe fn execute_e2_impl< F: PrimeField32, CTX: MeteredExecutionCtxTrait, diff --git a/extensions/ecc/circuit/src/weierstrass_chip/double/execution.rs b/extensions/ecc/circuit/src/weierstrass_chip/double/execution.rs index 8e755aa6f7..b6d569442a 100644 --- a/extensions/ecc/circuit/src/weierstrass_chip/double/execution.rs +++ b/extensions/ecc/circuit/src/weierstrass_chip/double/execution.rs @@ -84,45 +84,24 @@ impl<'a, const BLOCKS: usize, const BLOCK_SIZE: usize> EcDoubleExecutor Executor - for EcDoubleExecutor -{ - #[inline(always)] - fn pre_compute_size(&self) -> usize { - std::mem::size_of::() - } - - fn pre_compute( - &self, - pc: u32, - inst: &Instruction, - data: &mut [u8], - ) -> Result, StaticProgramError> - where - Ctx: ExecutionCtxTrait, - { - let pre_compute: &mut EcDoublePreCompute = data.borrow_mut(); - - let is_setup = self.pre_compute_impl(pc, inst, pre_compute)?; - +macro_rules! 
dispatch { + ($execute_impl:ident,$pre_compute:ident,$is_setup:ident) => { if let Some(curve_type) = { - let modulus = &pre_compute.expr.builder.prime; - let a_coeff = &pre_compute.expr.setup_values[0]; + let modulus = &$pre_compute.expr.builder.prime; + let a_coeff = &$pre_compute.expr.setup_values[0]; get_curve_type(modulus, a_coeff) } { - match (is_setup, curve_type) { + match ($is_setup, curve_type) { (true, CurveType::K256) => { - Ok(execute_e1_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::K256 as u8 }, true>) + Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::K256 as u8 }, true>) } (true, CurveType::P256) => { - Ok(execute_e1_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::P256 as u8 }, true>) + Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::P256 as u8 }, true>) } (true, CurveType::BN254) => { - Ok( - execute_e1_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::BN254 as u8 }, true>, - ) + Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::BN254 as u8 }, true>) } - (true, CurveType::BLS12_381) => Ok(execute_e1_impl::< + (true, CurveType::BLS12_381) => Ok($execute_impl::< _, _, BLOCKS, @@ -131,24 +110,15 @@ impl Executor true, >), (false, CurveType::K256) => { - Ok( - execute_e1_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::K256 as u8 }, false>, - ) + Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::K256 as u8 }, false>) } (false, CurveType::P256) => { - Ok( - execute_e1_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::P256 as u8 }, false>, - ) + Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::P256 as u8 }, false>) } - (false, CurveType::BN254) => Ok(execute_e1_impl::< - _, - _, - BLOCKS, - BLOCK_SIZE, - { CurveType::BN254 as u8 }, - false, - >), - (false, CurveType::BLS12_381) => Ok(execute_e1_impl::< + (false, CurveType::BN254) => { + Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::BN254 as u8 }, false>) + } + (false, CurveType::BLS12_381) => Ok($execute_impl::< _, _, BLOCKS, @@ -157,11 +127,51 @@ impl Executor 
false, >), } - } else if is_setup { - Ok(execute_e1_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }, true>) + } else if $is_setup { + Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }, true>) } else { - Ok(execute_e1_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }, false>) + Ok($execute_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }, false>) } + }; +} + +impl Executor + for EcDoubleExecutor +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + std::mem::size_of::() + } + + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let pre_compute: &mut EcDoublePreCompute = data.borrow_mut(); + let is_setup = self.pre_compute_impl(pc, inst, pre_compute)?; + + dispatch!(execute_e1_impl, pre_compute, is_setup) + } + + #[cfg(feature = "tco")] + fn handler( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let pre_compute: &mut EcDoublePreCompute = data.borrow_mut(); + let is_setup = self.pre_compute_impl(pc, inst, pre_compute)?; + + dispatch!(execute_e1_tco_handler, pre_compute, is_setup) } } @@ -185,66 +195,29 @@ impl MeteredExecu { let pre_compute: &mut E2PreCompute = data.borrow_mut(); pre_compute.chip_idx = chip_idx as u32; + let pre_compute_pure = &mut pre_compute.data; + let is_setup = self.pre_compute_impl(pc, inst, pre_compute_pure)?; - let is_setup = self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + dispatch!(execute_e2_impl, pre_compute_pure, is_setup) + } - if let Some(curve_type) = { - let modulus = &pre_compute.data.expr.builder.prime; - let a_coeff = &pre_compute.data.expr.setup_values[0]; - get_curve_type(modulus, a_coeff) - } { - match (is_setup, curve_type) { - (true, CurveType::K256) => { - Ok(execute_e2_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::K256 as u8 }, true>) - } - (true, CurveType::P256) => { - Ok(execute_e2_impl::<_, _, BLOCKS, BLOCK_SIZE, { 
CurveType::P256 as u8 }, true>) - } - (true, CurveType::BN254) => { - Ok( - execute_e2_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::BN254 as u8 }, true>, - ) - } - (true, CurveType::BLS12_381) => Ok(execute_e2_impl::< - _, - _, - BLOCKS, - BLOCK_SIZE, - { CurveType::BLS12_381 as u8 }, - true, - >), - (false, CurveType::K256) => { - Ok( - execute_e2_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::K256 as u8 }, false>, - ) - } - (false, CurveType::P256) => { - Ok( - execute_e2_impl::<_, _, BLOCKS, BLOCK_SIZE, { CurveType::P256 as u8 }, false>, - ) - } - (false, CurveType::BN254) => Ok(execute_e2_impl::< - _, - _, - BLOCKS, - BLOCK_SIZE, - { CurveType::BN254 as u8 }, - false, - >), - (false, CurveType::BLS12_381) => Ok(execute_e2_impl::< - _, - _, - BLOCKS, - BLOCK_SIZE, - { CurveType::BLS12_381 as u8 }, - false, - >), - } - } else if is_setup { - Ok(execute_e2_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }, true>) - } else { - Ok(execute_e2_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }, false>) - } + #[cfg(feature = "tco")] + fn metered_handler( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: MeteredExecutionCtxTrait, + { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + let pre_compute_pure = &mut pre_compute.data; + let is_setup = self.pre_compute_impl(pc, inst, pre_compute_pure)?; + + dispatch!(execute_e2_tco_handler, pre_compute_pure, is_setup) } } @@ -318,6 +291,7 @@ unsafe fn execute_e12_impl< vm_state.instret += 1; } +#[create_tco_handler] unsafe fn execute_e1_impl< F: PrimeField32, CTX: ExecutionCtxTrait, @@ -333,6 +307,7 @@ unsafe fn execute_e1_impl< execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, CURVE_TYPE, IS_SETUP>(pre_compute, vm_state); } +#[create_tco_handler] unsafe fn execute_e2_impl< F: PrimeField32, CTX: MeteredExecutionCtxTrait, diff --git a/extensions/ecc/tests/Cargo.toml b/extensions/ecc/tests/Cargo.toml index 
5f90e77fa4..7ce8df032c 100644 --- a/extensions/ecc/tests/Cargo.toml +++ b/extensions/ecc/tests/Cargo.toml @@ -28,3 +28,4 @@ halo2curves-axiom = { workspace = true } [features] default = ["parallel"] parallel = ["openvm-circuit/parallel"] +tco = ["openvm-ecc-circuit/tco"] diff --git a/extensions/keccak256/circuit/Cargo.toml b/extensions/keccak256/circuit/Cargo.toml index 2299a0599a..6b603a88f9 100644 --- a/extensions/keccak256/circuit/Cargo.toml +++ b/extensions/keccak256/circuit/Cargo.toml @@ -37,6 +37,7 @@ hex.workspace = true default = ["parallel", "jemalloc"] parallel = ["openvm-circuit/parallel"] test-utils = ["openvm-circuit/test-utils"] +tco = ["openvm-rv32im-circuit/tco"] # performance features: mimalloc = ["openvm-circuit/mimalloc"] jemalloc = ["openvm-circuit/jemalloc"] diff --git a/extensions/keccak256/circuit/src/execution.rs b/extensions/keccak256/circuit/src/execution.rs index b095fec4c4..28d96f0a2a 100644 --- a/extensions/keccak256/circuit/src/execution.rs +++ b/extensions/keccak256/circuit/src/execution.rs @@ -1,4 +1,7 @@ -use std::borrow::{Borrow, BorrowMut}; +use std::{ + borrow::{Borrow, BorrowMut}, + mem::size_of, +}; use openvm_circuit::{arch::*, system::memory::online::GuestMemory}; use openvm_circuit_primitives_derive::AlignedBytesBorrow; @@ -71,6 +74,21 @@ impl Executor for KeccakVmExecutor { self.pre_compute_impl(pc, inst, data)?; Ok(execute_e1_impl::<_, _>) } + + #[cfg(feature = "tco")] + fn handler( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let data: &mut KeccakPreCompute = data.borrow_mut(); + self.pre_compute_impl(pc, inst, data)?; + Ok(execute_e1_tco_handler) + } } impl MeteredExecutor for KeccakVmExecutor { @@ -93,6 +111,23 @@ impl MeteredExecutor for KeccakVmExecutor { self.pre_compute_impl(pc, inst, &mut data.data)?; Ok(execute_e2_impl::<_, _>) } + + #[cfg(feature = "tco")] + fn metered_handler( + &self, + chip_idx: usize, + pc: u32, + inst: 
&Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: MeteredExecutionCtxTrait, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + self.pre_compute_impl(pc, inst, &mut data.data)?; + Ok(execute_e2_tco_handler::<_, _>) + } } #[inline(always)] @@ -134,6 +169,7 @@ unsafe fn execute_e12_impl( pre_compute: &[u8], vm_state: &mut VmExecState, @@ -142,6 +178,7 @@ unsafe fn execute_e1_impl( execute_e12_impl::(pre_compute, vm_state); } +#[create_tco_handler] unsafe fn execute_e2_impl( pre_compute: &[u8], vm_state: &mut VmExecState, diff --git a/extensions/keccak256/circuit/src/lib.rs b/extensions/keccak256/circuit/src/lib.rs index 13bd7b27db..bc0b41026a 100644 --- a/extensions/keccak256/circuit/src/lib.rs +++ b/extensions/keccak256/circuit/src/lib.rs @@ -1,3 +1,5 @@ +#![cfg_attr(feature = "tco", allow(incomplete_features))] +#![cfg_attr(feature = "tco", feature(explicit_tail_calls))] //! Stateful keccak256 hasher. Handles full keccak sponge (padding, absorb, keccak-f) on //! variable length inputs read from VM memory. 
diff --git a/extensions/native/circuit/Cargo.toml b/extensions/native/circuit/Cargo.toml index f9b9bd78c5..8661cc9091 100644 --- a/extensions/native/circuit/Cargo.toml +++ b/extensions/native/circuit/Cargo.toml @@ -20,7 +20,6 @@ openvm-rv32im-circuit = { workspace = true } openvm-rv32im-transpiler = { workspace = true } openvm-native-compiler = { workspace = true } - strum.workspace = true itertools.workspace = true derive-new.workspace = true @@ -40,6 +39,8 @@ test-case = { workspace = true } test-log = { workspace = true } [features] -default = ["parallel"] +default = ["parallel", "jemalloc"] +tco = ["openvm-rv32im-circuit/tco"] +jemalloc = ["openvm-circuit/jemalloc"] parallel = ["openvm-circuit/parallel"] test-utils = ["openvm-circuit/test-utils"] diff --git a/extensions/native/circuit/src/branch_eq/execution.rs b/extensions/native/circuit/src/branch_eq/execution.rs index bbd8051214..003ac378bd 100644 --- a/extensions/native/circuit/src/branch_eq/execution.rs +++ b/extensions/native/circuit/src/branch_eq/execution.rs @@ -75,10 +75,42 @@ impl NativeBranchEqualExecutor { } } +macro_rules! 
dispatch { + ($execute_impl:ident, $a_is_imm:ident, $b_is_imm:ident, $is_bne:ident) => { + match ($a_is_imm, $b_is_imm, $is_bne) { + (true, true, true) => Ok($execute_impl::<_, _, true, true, true>), + (true, true, false) => Ok($execute_impl::<_, _, true, true, false>), + (true, false, true) => Ok($execute_impl::<_, _, true, false, true>), + (true, false, false) => Ok($execute_impl::<_, _, true, false, false>), + (false, true, true) => Ok($execute_impl::<_, _, false, true, true>), + (false, true, false) => Ok($execute_impl::<_, _, false, true, false>), + (false, false, true) => Ok($execute_impl::<_, _, false, false, true>), + (false, false, false) => Ok($execute_impl::<_, _, false, false, false>), + } + }; +} + impl Executor for NativeBranchEqualExecutor where F: PrimeField32, { + #[cfg(feature = "tco")] + fn handler( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let pre_compute: &mut NativeBranchEqualPreCompute = data.borrow_mut(); + + let (a_is_imm, b_is_imm, is_bne) = self.pre_compute_impl(pc, inst, pre_compute)?; + + dispatch!(execute_e1_tco_handler, a_is_imm, b_is_imm, is_bne) + } + #[inline(always)] fn pre_compute_size(&self) -> usize { size_of::() @@ -95,18 +127,7 @@ where let (a_is_imm, b_is_imm, is_bne) = self.pre_compute_impl(pc, inst, pre_compute)?; - let fn_ptr = match (a_is_imm, b_is_imm, is_bne) { - (true, true, true) => execute_e1_impl::<_, _, true, true, true>, - (true, true, false) => execute_e1_impl::<_, _, true, true, false>, - (true, false, true) => execute_e1_impl::<_, _, true, false, true>, - (true, false, false) => execute_e1_impl::<_, _, true, false, false>, - (false, true, true) => execute_e1_impl::<_, _, false, true, true>, - (false, true, false) => execute_e1_impl::<_, _, false, true, false>, - (false, false, true) => execute_e1_impl::<_, _, false, false, true>, - (false, false, false) => execute_e1_impl::<_, _, false, false, false>, - }; - - 
Ok(fn_ptr) + dispatch!(execute_e1_impl, a_is_imm, b_is_imm, is_bne) } } @@ -133,18 +154,24 @@ where let (a_is_imm, b_is_imm, is_bne) = self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; - let fn_ptr = match (a_is_imm, b_is_imm, is_bne) { - (true, true, true) => execute_e2_impl::<_, _, true, true, true>, - (true, true, false) => execute_e2_impl::<_, _, true, true, false>, - (true, false, true) => execute_e2_impl::<_, _, true, false, true>, - (true, false, false) => execute_e2_impl::<_, _, true, false, false>, - (false, true, true) => execute_e2_impl::<_, _, false, true, true>, - (false, true, false) => execute_e2_impl::<_, _, false, true, false>, - (false, false, true) => execute_e2_impl::<_, _, false, false, true>, - (false, false, false) => execute_e2_impl::<_, _, false, false, false>, - }; + dispatch!(execute_e2_impl, a_is_imm, b_is_imm, is_bne) + } + + #[cfg(feature = "tco")] + fn metered_handler( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + let (a_is_imm, b_is_imm, is_bne) = + self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; - Ok(fn_ptr) + dispatch!(execute_e2_tco_handler, a_is_imm, b_is_imm, is_bne) } } @@ -177,6 +204,7 @@ unsafe fn execute_e12_impl< vm_state.instret += 1; } +#[create_tco_handler] unsafe fn execute_e1_impl< F: PrimeField32, CTX: ExecutionCtxTrait, @@ -191,6 +219,7 @@ unsafe fn execute_e1_impl< execute_e12_impl::<_, _, A_IS_IMM, B_IS_IMM, IS_NE>(pre_compute, vm_state); } +#[create_tco_handler] unsafe fn execute_e2_impl< F: PrimeField32, CTX: MeteredExecutionCtxTrait, diff --git a/extensions/native/circuit/src/castf/execution.rs b/extensions/native/circuit/src/castf/execution.rs index b477620e4a..99d52913ad 100644 --- a/extensions/native/circuit/src/castf/execution.rs +++ b/extensions/native/circuit/src/castf/execution.rs @@ -52,6 +52,25 @@ impl Executor for 
CastFCoreExecutor where F: PrimeField32, { + #[cfg(feature = "tco")] + fn handler( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let pre_compute: &mut CastFPreCompute = data.borrow_mut(); + + self.pre_compute_impl(pc, inst, pre_compute)?; + + let fn_ptr = execute_e1_tco_handler::<_, _>; + + Ok(fn_ptr) + } + #[inline(always)] fn pre_compute_size(&self) -> usize { size_of::() @@ -100,8 +119,27 @@ where Ok(fn_ptr) } + + #[cfg(feature = "tco")] + fn metered_handler( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + + let fn_ptr = execute_e2_tco_handler::<_, _>; + + Ok(fn_ptr) + } } +#[create_tco_handler] unsafe fn execute_e1_impl( pre_compute: &[u8], vm_state: &mut VmExecState, @@ -110,6 +148,7 @@ unsafe fn execute_e1_impl( execute_e12_impl(pre_compute, vm_state); } +#[create_tco_handler] unsafe fn execute_e2_impl( pre_compute: &[u8], vm_state: &mut VmExecState, diff --git a/extensions/native/circuit/src/field_arithmetic/execution.rs b/extensions/native/circuit/src/field_arithmetic/execution.rs index cac0770181..38c6453763 100644 --- a/extensions/native/circuit/src/field_arithmetic/execution.rs +++ b/extensions/native/circuit/src/field_arithmetic/execution.rs @@ -76,78 +76,99 @@ impl FieldArithmeticCoreExecutor { } } -impl Executor for FieldArithmeticCoreExecutor -where - F: PrimeField32, -{ - #[inline(always)] - fn pre_compute_size(&self) -> usize { - size_of::() - } - - #[inline(always)] - fn pre_compute( - &self, - pc: u32, - inst: &Instruction, - data: &mut [u8], - ) -> Result, StaticProgramError> { - let pre_compute: &mut FieldArithmeticPreCompute = data.borrow_mut(); - - let (a_is_imm, b_is_imm, local_opcode) = self.pre_compute_impl(pc, inst, 
pre_compute)?; - - let fn_ptr = match (local_opcode, a_is_imm, b_is_imm) { +macro_rules! dispatch { + ($execute_impl:ident, $local_opcode:ident, $a_is_imm:ident, $b_is_imm:ident) => { + match ($local_opcode, $a_is_imm, $b_is_imm) { (FieldArithmeticOpcode::ADD, true, true) => { - execute_e1_impl::<_, _, true, true, { FieldArithmeticOpcode::ADD as u8 }> + Ok($execute_impl::<_, _, true, true, { FieldArithmeticOpcode::ADD as u8 }>) } (FieldArithmeticOpcode::ADD, true, false) => { - execute_e1_impl::<_, _, true, false, { FieldArithmeticOpcode::ADD as u8 }> + Ok($execute_impl::<_, _, true, false, { FieldArithmeticOpcode::ADD as u8 }>) } (FieldArithmeticOpcode::ADD, false, true) => { - execute_e1_impl::<_, _, false, true, { FieldArithmeticOpcode::ADD as u8 }> + Ok($execute_impl::<_, _, false, true, { FieldArithmeticOpcode::ADD as u8 }>) } (FieldArithmeticOpcode::ADD, false, false) => { - execute_e1_impl::<_, _, false, false, { FieldArithmeticOpcode::ADD as u8 }> + Ok($execute_impl::<_, _, false, false, { FieldArithmeticOpcode::ADD as u8 }>) } (FieldArithmeticOpcode::SUB, true, true) => { - execute_e1_impl::<_, _, true, true, { FieldArithmeticOpcode::SUB as u8 }> + Ok($execute_impl::<_, _, true, true, { FieldArithmeticOpcode::SUB as u8 }>) } (FieldArithmeticOpcode::SUB, true, false) => { - execute_e1_impl::<_, _, true, false, { FieldArithmeticOpcode::SUB as u8 }> + Ok($execute_impl::<_, _, true, false, { FieldArithmeticOpcode::SUB as u8 }>) } (FieldArithmeticOpcode::SUB, false, true) => { - execute_e1_impl::<_, _, false, true, { FieldArithmeticOpcode::SUB as u8 }> + Ok($execute_impl::<_, _, false, true, { FieldArithmeticOpcode::SUB as u8 }>) } (FieldArithmeticOpcode::SUB, false, false) => { - execute_e1_impl::<_, _, false, false, { FieldArithmeticOpcode::SUB as u8 }> + Ok($execute_impl::<_, _, false, false, { FieldArithmeticOpcode::SUB as u8 }>) } (FieldArithmeticOpcode::MUL, true, true) => { - execute_e1_impl::<_, _, true, true, { FieldArithmeticOpcode::MUL as u8 }> + 
Ok($execute_impl::<_, _, true, true, { FieldArithmeticOpcode::MUL as u8 }>) } (FieldArithmeticOpcode::MUL, true, false) => { - execute_e1_impl::<_, _, true, false, { FieldArithmeticOpcode::MUL as u8 }> + Ok($execute_impl::<_, _, true, false, { FieldArithmeticOpcode::MUL as u8 }>) } (FieldArithmeticOpcode::MUL, false, true) => { - execute_e1_impl::<_, _, false, true, { FieldArithmeticOpcode::MUL as u8 }> + Ok($execute_impl::<_, _, false, true, { FieldArithmeticOpcode::MUL as u8 }>) } (FieldArithmeticOpcode::MUL, false, false) => { - execute_e1_impl::<_, _, false, false, { FieldArithmeticOpcode::MUL as u8 }> + Ok($execute_impl::<_, _, false, false, { FieldArithmeticOpcode::MUL as u8 }>) } (FieldArithmeticOpcode::DIV, true, true) => { - execute_e1_impl::<_, _, true, true, { FieldArithmeticOpcode::DIV as u8 }> + Ok($execute_impl::<_, _, true, true, { FieldArithmeticOpcode::DIV as u8 }>) } (FieldArithmeticOpcode::DIV, true, false) => { - execute_e1_impl::<_, _, true, false, { FieldArithmeticOpcode::DIV as u8 }> + Ok($execute_impl::<_, _, true, false, { FieldArithmeticOpcode::DIV as u8 }>) } (FieldArithmeticOpcode::DIV, false, true) => { - execute_e1_impl::<_, _, false, true, { FieldArithmeticOpcode::DIV as u8 }> + Ok($execute_impl::<_, _, false, true, { FieldArithmeticOpcode::DIV as u8 }>) } (FieldArithmeticOpcode::DIV, false, false) => { - execute_e1_impl::<_, _, false, false, { FieldArithmeticOpcode::DIV as u8 }> + Ok($execute_impl::<_, _, false, false, { FieldArithmeticOpcode::DIV as u8 }>) } - }; + } + }; +} + +impl Executor for FieldArithmeticCoreExecutor +where + F: PrimeField32, +{ + #[cfg(feature = "tco")] + fn handler( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let pre_compute: &mut FieldArithmeticPreCompute = data.borrow_mut(); + + let (a_is_imm, b_is_imm, local_opcode) = self.pre_compute_impl(pc, inst, pre_compute)?; - Ok(fn_ptr) + dispatch!(execute_e1_tco_handler, 
local_opcode, a_is_imm, b_is_imm) + } + + #[inline(always)] + fn pre_compute_size(&self) -> usize { + size_of::() + } + + #[inline(always)] + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let pre_compute: &mut FieldArithmeticPreCompute = data.borrow_mut(); + + let (a_is_imm, b_is_imm, local_opcode) = self.pre_compute_impl(pc, inst, pre_compute)?; + + dispatch!(execute_e1_impl, local_opcode, a_is_imm, b_is_imm) } } @@ -174,58 +195,24 @@ where let (a_is_imm, b_is_imm, local_opcode) = self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; - let fn_ptr = match (local_opcode, a_is_imm, b_is_imm) { - (FieldArithmeticOpcode::ADD, true, true) => { - execute_e2_impl::<_, _, true, true, { FieldArithmeticOpcode::ADD as u8 }> - } - (FieldArithmeticOpcode::ADD, true, false) => { - execute_e2_impl::<_, _, true, false, { FieldArithmeticOpcode::ADD as u8 }> - } - (FieldArithmeticOpcode::ADD, false, true) => { - execute_e2_impl::<_, _, false, true, { FieldArithmeticOpcode::ADD as u8 }> - } - (FieldArithmeticOpcode::ADD, false, false) => { - execute_e2_impl::<_, _, false, false, { FieldArithmeticOpcode::ADD as u8 }> - } - (FieldArithmeticOpcode::SUB, true, true) => { - execute_e2_impl::<_, _, true, true, { FieldArithmeticOpcode::SUB as u8 }> - } - (FieldArithmeticOpcode::SUB, true, false) => { - execute_e2_impl::<_, _, true, false, { FieldArithmeticOpcode::SUB as u8 }> - } - (FieldArithmeticOpcode::SUB, false, true) => { - execute_e2_impl::<_, _, false, true, { FieldArithmeticOpcode::SUB as u8 }> - } - (FieldArithmeticOpcode::SUB, false, false) => { - execute_e2_impl::<_, _, false, false, { FieldArithmeticOpcode::SUB as u8 }> - } - (FieldArithmeticOpcode::MUL, true, true) => { - execute_e2_impl::<_, _, true, true, { FieldArithmeticOpcode::MUL as u8 }> - } - (FieldArithmeticOpcode::MUL, true, false) => { - execute_e2_impl::<_, _, true, false, { FieldArithmeticOpcode::MUL as u8 }> - } - 
(FieldArithmeticOpcode::MUL, false, true) => { - execute_e2_impl::<_, _, false, true, { FieldArithmeticOpcode::MUL as u8 }> - } - (FieldArithmeticOpcode::MUL, false, false) => { - execute_e2_impl::<_, _, false, false, { FieldArithmeticOpcode::MUL as u8 }> - } - (FieldArithmeticOpcode::DIV, true, true) => { - execute_e2_impl::<_, _, true, true, { FieldArithmeticOpcode::DIV as u8 }> - } - (FieldArithmeticOpcode::DIV, true, false) => { - execute_e2_impl::<_, _, true, false, { FieldArithmeticOpcode::DIV as u8 }> - } - (FieldArithmeticOpcode::DIV, false, true) => { - execute_e2_impl::<_, _, false, true, { FieldArithmeticOpcode::DIV as u8 }> - } - (FieldArithmeticOpcode::DIV, false, false) => { - execute_e2_impl::<_, _, false, false, { FieldArithmeticOpcode::DIV as u8 }> - } - }; + dispatch!(execute_e2_impl, local_opcode, a_is_imm, b_is_imm) + } + + #[cfg(feature = "tco")] + fn metered_handler( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + let (a_is_imm, b_is_imm, local_opcode) = + self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; - Ok(fn_ptr) + dispatch!(execute_e2_tco_handler, local_opcode, a_is_imm, b_is_imm) } } @@ -276,6 +263,7 @@ unsafe fn execute_e12_impl< vm_state.instret += 1; } +#[create_tco_handler] unsafe fn execute_e1_impl< F: PrimeField32, CTX: ExecutionCtxTrait, @@ -290,6 +278,7 @@ unsafe fn execute_e1_impl< execute_e12_impl::(pre_compute, vm_state); } +#[create_tco_handler] unsafe fn execute_e2_impl< F: PrimeField32, CTX: MeteredExecutionCtxTrait, diff --git a/extensions/native/circuit/src/field_extension/execution.rs b/extensions/native/circuit/src/field_extension/execution.rs index 7b4802987e..2752a05c44 100644 --- a/extensions/native/circuit/src/field_extension/execution.rs +++ b/extensions/native/circuit/src/field_extension/execution.rs @@ -57,10 +57,39 @@ impl 
FieldExtensionCoreExecutor { } } +macro_rules! dispatch { + ($execute_impl:ident, $opcode:ident) => { + match $opcode { + 0 => Ok($execute_impl::<_, _, 0>), // FE4ADD + 1 => Ok($execute_impl::<_, _, 1>), // FE4SUB + 2 => Ok($execute_impl::<_, _, 2>), // BBE4MUL + 3 => Ok($execute_impl::<_, _, 3>), // BBE4DIV + _ => panic!("Invalid field extension opcode: {}", $opcode), + } + }; +} + impl Executor for FieldExtensionCoreExecutor where F: PrimeField32, { + #[cfg(feature = "tco")] + fn handler( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let pre_compute: &mut FieldExtensionPreCompute = data.borrow_mut(); + + let opcode = self.pre_compute_impl(pc, inst, pre_compute)?; + + dispatch!(execute_e1_tco_handler, opcode) + } + #[inline(always)] fn pre_compute_size(&self) -> usize { size_of::() @@ -77,15 +106,7 @@ where let opcode = self.pre_compute_impl(pc, inst, pre_compute)?; - let fn_ptr = match opcode { - 0 => execute_e1_impl::<_, _, 0>, // FE4ADD - 1 => execute_e1_impl::<_, _, 1>, // FE4SUB - 2 => execute_e1_impl::<_, _, 2>, // BBE4MUL - 3 => execute_e1_impl::<_, _, 3>, // BBE4DIV - _ => panic!("Invalid field extension opcode: {opcode}"), - }; - - Ok(fn_ptr) + dispatch!(execute_e1_impl, opcode) } } @@ -111,15 +132,23 @@ where let opcode = self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; - let fn_ptr = match opcode { - 0 => execute_e2_impl::<_, _, 0>, // FE4ADD - 1 => execute_e2_impl::<_, _, 1>, // FE4SUB - 2 => execute_e2_impl::<_, _, 2>, // BBE4MUL - 3 => execute_e2_impl::<_, _, 3>, // BBE4DIV - _ => panic!("Invalid field extension opcode: {opcode}"), - }; + dispatch!(execute_e2_impl, opcode) + } + + #[cfg(feature = "tco")] + fn metered_handler( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + let opcode 
= self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; - Ok(fn_ptr) + dispatch!(execute_e2_tco_handler, opcode) } } @@ -145,6 +174,7 @@ unsafe fn execute_e12_impl( pre_compute: &[u8], vm_state: &mut VmExecState, @@ -153,6 +183,7 @@ unsafe fn execute_e1_impl(pre_compute, vm_state); } +#[create_tco_handler] unsafe fn execute_e2_impl( pre_compute: &[u8], vm_state: &mut VmExecState, diff --git a/extensions/native/circuit/src/fri/execution.rs b/extensions/native/circuit/src/fri/execution.rs index 7af4034ed9..13297cc260 100644 --- a/extensions/native/circuit/src/fri/execution.rs +++ b/extensions/native/circuit/src/fri/execution.rs @@ -66,6 +66,24 @@ impl Executor for FriReducedOpeningExecutor where F: PrimeField32, { + #[cfg(feature = "tco")] + fn handler( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let pre_compute: &mut FriReducedOpeningPreCompute = data.borrow_mut(); + + self.pre_compute_impl(pc, inst, pre_compute)?; + + let fn_ptr = execute_e1_tco_handler; + Ok(fn_ptr) + } + #[inline(always)] fn pre_compute_size(&self) -> usize { size_of::() @@ -112,8 +130,26 @@ where let fn_ptr = execute_e2_impl; Ok(fn_ptr) } + + #[cfg(feature = "tco")] + fn metered_handler( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + + let fn_ptr = execute_e2_tco_handler; + Ok(fn_ptr) + } } +#[create_tco_handler] unsafe fn execute_e1_impl( pre_compute: &[u8], vm_state: &mut VmExecState, @@ -122,6 +158,7 @@ unsafe fn execute_e1_impl( execute_e12_impl(pre_compute, vm_state); } +#[create_tco_handler] unsafe fn execute_e2_impl( pre_compute: &[u8], vm_state: &mut VmExecState, diff --git a/extensions/native/circuit/src/jal_rangecheck/execution.rs 
b/extensions/native/circuit/src/jal_rangecheck/execution.rs index f9cf17d7af..296aa536af 100644 --- a/extensions/native/circuit/src/jal_rangecheck/execution.rs +++ b/extensions/native/circuit/src/jal_rangecheck/execution.rs @@ -109,6 +109,28 @@ where Ok(execute_range_check_e1_impl) } } + + #[cfg(feature = "tco")] + fn handler( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let &Instruction { opcode, .. } = inst; + + let is_jal = opcode == NativeJalOpcode::JAL.global_opcode(); + + if is_jal { + let jal_data: &mut JalPreCompute = data.borrow_mut(); + self.pre_compute_jal_impl(pc, inst, jal_data)?; + Ok(execute_jal_e1_tco_handler) + } else { + let range_check_data: &mut RangeCheckPreCompute = data.borrow_mut(); + self.pre_compute_range_check_impl(pc, inst, range_check_data)?; + Ok(execute_range_check_e1_tco_handler) + } + } } impl MeteredExecutor for JalRangeCheckExecutor @@ -149,6 +171,33 @@ where Ok(execute_range_check_e2_impl) } } + + #[cfg(feature = "tco")] + fn metered_handler( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let &Instruction { opcode, .. 
} = inst; + + let is_jal = opcode == NativeJalOpcode::JAL.global_opcode(); + + if is_jal { + let pre_compute: &mut E2PreCompute> = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + self.pre_compute_jal_impl(pc, inst, &mut pre_compute.data)?; + Ok(execute_jal_e2_tco_handler) + } else { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + self.pre_compute_range_check_impl(pc, inst, &mut pre_compute.data)?; + Ok(execute_range_check_e2_tco_handler) + } + } } #[inline(always)] @@ -190,6 +239,7 @@ unsafe fn execute_range_check_e12_impl( vm_state.instret += 1; } +#[create_tco_handler] unsafe fn execute_jal_e1_impl( pre_compute: &[u8], vm_state: &mut VmExecState, @@ -198,6 +248,7 @@ unsafe fn execute_jal_e1_impl( execute_jal_e12_impl(pre_compute, vm_state); } +#[create_tco_handler] unsafe fn execute_jal_e2_impl( pre_compute: &[u8], vm_state: &mut VmExecState, @@ -209,6 +260,7 @@ unsafe fn execute_jal_e2_impl( execute_jal_e12_impl(&pre_compute.data, vm_state); } +#[create_tco_handler] unsafe fn execute_range_check_e1_impl( pre_compute: &[u8], vm_state: &mut VmExecState, @@ -217,6 +269,7 @@ unsafe fn execute_range_check_e1_impl( execute_range_check_e12_impl(pre_compute, vm_state); } +#[create_tco_handler] unsafe fn execute_range_check_e2_impl( pre_compute: &[u8], vm_state: &mut VmExecState, diff --git a/extensions/native/circuit/src/lib.rs b/extensions/native/circuit/src/lib.rs index 01c0d0ba5b..94f4a7d75f 100644 --- a/extensions/native/circuit/src/lib.rs +++ b/extensions/native/circuit/src/lib.rs @@ -1,3 +1,6 @@ +#![cfg_attr(feature = "tco", allow(incomplete_features))] +#![cfg_attr(feature = "tco", feature(explicit_tail_calls))] + use openvm_circuit::{ arch::{ AirInventory, ChipInventoryError, InitFileGenerator, MatrixRecordArena, MemoryConfig, diff --git a/extensions/native/circuit/src/loadstore/execution.rs b/extensions/native/circuit/src/loadstore/execution.rs index a31efb831e..b0b3e4b726 100644 --- 
a/extensions/native/circuit/src/loadstore/execution.rs +++ b/extensions/native/circuit/src/loadstore/execution.rs @@ -82,6 +82,28 @@ where Ok(fn_ptr) } + + #[cfg(feature = "tco")] + fn handler( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let pre_compute: &mut NativeLoadStorePreCompute = data.borrow_mut(); + + let local_opcode = self.pre_compute_impl(pc, inst, pre_compute)?; + + let fn_ptr = match local_opcode { + NativeLoadStoreOpcode::LOADW => execute_e1_loadw_tco_handler::, + NativeLoadStoreOpcode::STOREW => execute_e1_storew_tco_handler::, + NativeLoadStoreOpcode::HINT_STOREW => { + execute_e1_hint_storew_tco_handler:: + } + }; + + Ok(fn_ptr) + } } impl MeteredExecutor for NativeLoadStoreCoreExecutor @@ -114,8 +136,33 @@ where Ok(fn_ptr) } + + #[cfg(feature = "tco")] + fn metered_handler( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let pre_compute: &mut E2PreCompute> = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + let local_opcode = self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + + let fn_ptr = match local_opcode { + NativeLoadStoreOpcode::LOADW => execute_e2_loadw_tco_handler::, + NativeLoadStoreOpcode::STOREW => execute_e2_storew_tco_handler::, + NativeLoadStoreOpcode::HINT_STOREW => { + execute_e2_hint_storew_tco_handler:: + } + }; + + Ok(fn_ptr) + } } +#[create_tco_handler] unsafe fn execute_e1_loadw( pre_compute: &[u8], vm_state: &mut VmExecState, @@ -124,6 +171,7 @@ unsafe fn execute_e1_loadw(pre_compute, vm_state); } +#[create_tco_handler] unsafe fn execute_e1_storew( pre_compute: &[u8], vm_state: &mut VmExecState, @@ -132,6 +180,7 @@ unsafe fn execute_e1_storew(pre_compute, vm_state); } +#[create_tco_handler] unsafe fn execute_e1_hint_storew< F: PrimeField32, CTX: ExecutionCtxTrait, @@ -144,6 +193,7 @@ unsafe fn execute_e1_hint_storew< execute_e12_hint_storew::<_, _, NUM_CELLS>(pre_compute, 
vm_state); } +#[create_tco_handler] unsafe fn execute_e2_loadw< F: PrimeField32, CTX: MeteredExecutionCtxTrait, @@ -159,6 +209,7 @@ unsafe fn execute_e2_loadw< execute_e12_loadw::<_, _, NUM_CELLS>(&pre_compute.data, vm_state); } +#[create_tco_handler] unsafe fn execute_e2_storew< F: PrimeField32, CTX: MeteredExecutionCtxTrait, @@ -174,6 +225,7 @@ unsafe fn execute_e2_storew< execute_e12_storew::<_, _, NUM_CELLS>(&pre_compute.data, vm_state); } +#[create_tco_handler] unsafe fn execute_e2_hint_storew< F: PrimeField32, CTX: MeteredExecutionCtxTrait, diff --git a/extensions/native/circuit/src/poseidon2/execution.rs b/extensions/native/circuit/src/poseidon2/execution.rs index 661d8e10cc..5dcc356c55 100644 --- a/extensions/native/circuit/src/poseidon2/execution.rs +++ b/extensions/native/circuit/src/poseidon2/execution.rs @@ -136,6 +136,33 @@ impl<'a, F: PrimeField32, const SBOX_REGISTERS: usize> NativePoseidon2Executor { + if $opcode == PERM_POS2.global_opcode() || $opcode == COMP_POS2.global_opcode() { + let pos2_data: &mut Pos2PreCompute = $data.borrow_mut(); + $executor.pre_compute_pos2_impl($pc, $inst, pos2_data)?; + if $opcode == PERM_POS2.global_opcode() { + Ok($execute_pos2_impl::<_, _, SBOX_REGISTERS, true>) + } else { + Ok($execute_pos2_impl::<_, _, SBOX_REGISTERS, false>) + } + } else { + let verify_batch_data: &mut VerifyBatchPreCompute = + $data.borrow_mut(); + $executor.pre_compute_verify_batch_impl($pc, $inst, verify_batch_data)?; + Ok($execute_verify_batch_impl::<_, _, SBOX_REGISTERS>) + } + }; +} + impl Executor for NativePoseidon2Executor { @@ -154,25 +181,67 @@ impl Executor inst: &Instruction, data: &mut [u8], ) -> Result, StaticProgramError> { - let &Instruction { opcode, .. 
} = inst; + dispatch1!( + execute_pos2_e1_impl, + execute_verify_batch_e1_impl, + self, + inst.opcode, + pc, + inst, + data + ) + } - let is_pos2 = opcode == PERM_POS2.global_opcode() || opcode == COMP_POS2.global_opcode(); + #[cfg(feature = "tco")] + fn handler( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + dispatch1!( + execute_pos2_e1_tco_handler, + execute_verify_batch_e1_tco_handler, + self, + inst.opcode, + pc, + inst, + data + ) + } +} + +macro_rules! dispatch2 { + ( + $execute_pos2_impl:ident, + $execute_verify_batch_impl:ident, + $executor:ident, + $opcode:expr, + $chip_idx:ident, + $pc:ident, + $inst:ident, + $data:ident + ) => { + if $opcode == PERM_POS2.global_opcode() || $opcode == COMP_POS2.global_opcode() { + let pre_compute: &mut E2PreCompute> = + $data.borrow_mut(); + pre_compute.chip_idx = $chip_idx as u32; - if is_pos2 { - let pos2_data: &mut Pos2PreCompute = data.borrow_mut(); - self.pre_compute_pos2_impl(pc, inst, pos2_data)?; - if opcode == PERM_POS2.global_opcode() { - Ok(execute_pos2_e1_impl::<_, _, SBOX_REGISTERS, true>) + $executor.pre_compute_pos2_impl($pc, $inst, &mut pre_compute.data)?; + if $opcode == PERM_POS2.global_opcode() { + Ok($execute_pos2_impl::<_, _, SBOX_REGISTERS, true>) } else { - Ok(execute_pos2_e1_impl::<_, _, SBOX_REGISTERS, false>) + Ok($execute_pos2_impl::<_, _, SBOX_REGISTERS, false>) } } else { - let verify_batch_data: &mut VerifyBatchPreCompute = - data.borrow_mut(); - self.pre_compute_verify_batch_impl(pc, inst, verify_batch_data)?; - Ok(execute_verify_batch_e1_impl::<_, _, SBOX_REGISTERS>) + let pre_compute: &mut E2PreCompute> = + $data.borrow_mut(); + pre_compute.chip_idx = $chip_idx as u32; + + $executor.pre_compute_verify_batch_impl($pc, $inst, &mut pre_compute.data)?; + Ok($execute_verify_batch_impl::<_, _, SBOX_REGISTERS>) } - } + }; } impl MeteredExecutor @@ -194,32 +263,40 @@ impl MeteredExecutor inst: &Instruction, data: &mut [u8], ) -> Result, 
StaticProgramError> { - let &Instruction { opcode, .. } = inst; - - let is_pos2 = opcode == PERM_POS2.global_opcode() || opcode == COMP_POS2.global_opcode(); - - if is_pos2 { - let pre_compute: &mut E2PreCompute> = - data.borrow_mut(); - pre_compute.chip_idx = chip_idx as u32; - - self.pre_compute_pos2_impl(pc, inst, &mut pre_compute.data)?; - if opcode == PERM_POS2.global_opcode() { - Ok(execute_pos2_e2_impl::<_, _, SBOX_REGISTERS, true>) - } else { - Ok(execute_pos2_e2_impl::<_, _, SBOX_REGISTERS, false>) - } - } else { - let pre_compute: &mut E2PreCompute> = - data.borrow_mut(); - pre_compute.chip_idx = chip_idx as u32; + dispatch2!( + execute_pos2_e2_impl, + execute_verify_batch_e2_impl, + self, + inst.opcode, + chip_idx, + pc, + inst, + data + ) + } - self.pre_compute_verify_batch_impl(pc, inst, &mut pre_compute.data)?; - Ok(execute_verify_batch_e2_impl::<_, _, SBOX_REGISTERS>) - } + #[cfg(feature = "tco")] + fn metered_handler( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + dispatch2!( + execute_pos2_e2_tco_handler, + execute_verify_batch_e2_tco_handler, + self, + inst.opcode, + chip_idx, + pc, + inst, + data + ) } } +#[create_tco_handler] unsafe fn execute_pos2_e1_impl< F: PrimeField32, CTX: ExecutionCtxTrait, @@ -233,6 +310,7 @@ unsafe fn execute_pos2_e1_impl< execute_pos2_e12_impl::<_, _, SBOX_REGISTERS, IS_PERM>(pre_compute, vm_state); } +#[create_tco_handler] unsafe fn execute_pos2_e2_impl< F: PrimeField32, CTX: MeteredExecutionCtxTrait, @@ -250,6 +328,7 @@ unsafe fn execute_pos2_e2_impl< .on_height_change(pre_compute.chip_idx as usize, height); } +#[create_tco_handler] unsafe fn execute_verify_batch_e1_impl< F: PrimeField32, CTX: ExecutionCtxTrait, @@ -263,6 +342,7 @@ unsafe fn execute_verify_batch_e1_impl< execute_verify_batch_e12_impl::<_, _, SBOX_REGISTERS, true>(pre_compute, vm_state); } +#[create_tco_handler] unsafe fn execute_verify_batch_e2_impl< F: PrimeField32, CTX: 
MeteredExecutionCtxTrait, diff --git a/extensions/pairing/circuit/Cargo.toml b/extensions/pairing/circuit/Cargo.toml index a44afff0f8..46565dd58e 100644 --- a/extensions/pairing/circuit/Cargo.toml +++ b/extensions/pairing/circuit/Cargo.toml @@ -49,3 +49,7 @@ openvm-pairing-guest = { workspace = true, features = [ "bls12_381", "bn254", ] } + +[features] +default = [] +tco = ["openvm-rv32im-circuit/tco", "openvm-ecc-circuit/tco"] diff --git a/extensions/pairing/circuit/src/lib.rs b/extensions/pairing/circuit/src/lib.rs index 7edefa5490..58e6527345 100644 --- a/extensions/pairing/circuit/src/lib.rs +++ b/extensions/pairing/circuit/src/lib.rs @@ -1,3 +1,6 @@ +#![cfg_attr(feature = "tco", allow(incomplete_features))] +#![cfg_attr(feature = "tco", feature(explicit_tail_calls))] + pub use openvm_pairing_guest::{ bls12_381::{BLS12_381_COMPLEX_STRUCT_NAME, BLS12_381_ECC_STRUCT_NAME}, bn254::BN254_COMPLEX_STRUCT_NAME, diff --git a/extensions/rv32im/circuit/Cargo.toml b/extensions/rv32im/circuit/Cargo.toml index 9f6bbb6824..1b6d021f35 100644 --- a/extensions/rv32im/circuit/Cargo.toml +++ b/extensions/rv32im/circuit/Cargo.toml @@ -36,6 +36,7 @@ test-case.workspace = true default = ["parallel", "jemalloc"] parallel = ["openvm-circuit/parallel"] test-utils = ["openvm-circuit/test-utils", "dep:openvm-stark-sdk"] +tco = ["openvm-circuit/tco"] # performance features: mimalloc = ["openvm-circuit/mimalloc"] jemalloc = ["openvm-circuit/jemalloc"] diff --git a/extensions/rv32im/circuit/src/auipc/execution.rs b/extensions/rv32im/circuit/src/auipc/execution.rs index c9269613a1..454172a06f 100644 --- a/extensions/rv32im/circuit/src/auipc/execution.rs +++ b/extensions/rv32im/circuit/src/auipc/execution.rs @@ -3,13 +3,7 @@ use std::{ mem::size_of, }; -use openvm_circuit::{ - arch::{ - E2PreCompute, ExecuteFunc, ExecutionCtxTrait, Executor, MeteredExecutionCtxTrait, - MeteredExecutor, StaticProgramError, VmExecState, - }, - system::memory::online::GuestMemory, -}; +use 
openvm_circuit::{arch::*, system::memory::online::GuestMemory}; use openvm_circuit_primitives_derive::AlignedBytesBorrow; use openvm_instructions::{ instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_AS, @@ -66,6 +60,21 @@ where self.pre_compute_impl(pc, inst, data)?; Ok(execute_e1_impl) } + + #[cfg(feature = "tco")] + fn handler( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let data: &mut AuiPcPreCompute = data.borrow_mut(); + self.pre_compute_impl(pc, inst, data)?; + Ok(execute_e1_tco_handler) + } } impl MeteredExecutor for Rv32AuipcExecutor @@ -91,6 +100,23 @@ where self.pre_compute_impl(pc, inst, &mut data.data)?; Ok(execute_e2_impl) } + + #[cfg(feature = "tco")] + fn metered_handler( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: MeteredExecutionCtxTrait, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + self.pre_compute_impl(pc, inst, &mut data.data)?; + Ok(execute_e2_tco_handler) + } } #[inline(always)] @@ -105,6 +131,7 @@ unsafe fn execute_e12_impl( vm_state.instret += 1; } +#[create_tco_handler] unsafe fn execute_e1_impl( pre_compute: &[u8], vm_state: &mut VmExecState, @@ -113,6 +140,7 @@ unsafe fn execute_e1_impl( execute_e12_impl(pre_compute, vm_state); } +#[create_tco_handler] unsafe fn execute_e2_impl( pre_compute: &[u8], vm_state: &mut VmExecState, diff --git a/extensions/rv32im/circuit/src/base_alu/execution.rs b/extensions/rv32im/circuit/src/base_alu/execution.rs index acbbf12844..f2447e4f4e 100644 --- a/extensions/rv32im/circuit/src/base_alu/execution.rs +++ b/extensions/rv32im/circuit/src/base_alu/execution.rs @@ -3,13 +3,7 @@ use std::{ mem::size_of, }; -use openvm_circuit::{ - arch::{ - E2PreCompute, ExecuteFunc, ExecutionCtxTrait, Executor, MeteredExecutionCtxTrait, - MeteredExecutor, StaticProgramError, 
VmExecState, - }, - system::memory::online::GuestMemory, -}; +use openvm_circuit::{arch::*, system::memory::online::GuestMemory}; use openvm_circuit_primitives_derive::AlignedBytesBorrow; use openvm_instructions::{ instruction::Instruction, @@ -61,6 +55,28 @@ impl BaseAluExecutor { + Ok( + match ( + $is_imm, + BaseAluOpcode::from_usize($opcode.local_opcode_idx($offset)), + ) { + (true, BaseAluOpcode::ADD) => $execute_impl::<_, _, true, AddOp>, + (false, BaseAluOpcode::ADD) => $execute_impl::<_, _, false, AddOp>, + (true, BaseAluOpcode::SUB) => $execute_impl::<_, _, true, SubOp>, + (false, BaseAluOpcode::SUB) => $execute_impl::<_, _, false, SubOp>, + (true, BaseAluOpcode::XOR) => $execute_impl::<_, _, true, XorOp>, + (false, BaseAluOpcode::XOR) => $execute_impl::<_, _, false, XorOp>, + (true, BaseAluOpcode::OR) => $execute_impl::<_, _, true, OrOp>, + (false, BaseAluOpcode::OR) => $execute_impl::<_, _, false, OrOp>, + (true, BaseAluOpcode::AND) => $execute_impl::<_, _, true, AndOp>, + (false, BaseAluOpcode::AND) => $execute_impl::<_, _, false, AndOp>, + }, + ) + }; +} + impl Executor for BaseAluExecutor where @@ -71,7 +87,6 @@ where size_of::() } - #[inline(always)] fn pre_compute( &self, pc: u32, @@ -83,24 +98,24 @@ where { let data: &mut BaseAluPreCompute = data.borrow_mut(); let is_imm = self.pre_compute_impl(pc, inst, data)?; - let opcode = inst.opcode; - - let fn_ptr = match ( - is_imm, - BaseAluOpcode::from_usize(opcode.local_opcode_idx(self.offset)), - ) { - (true, BaseAluOpcode::ADD) => execute_e1_impl::<_, _, true, AddOp>, - (false, BaseAluOpcode::ADD) => execute_e1_impl::<_, _, false, AddOp>, - (true, BaseAluOpcode::SUB) => execute_e1_impl::<_, _, true, SubOp>, - (false, BaseAluOpcode::SUB) => execute_e1_impl::<_, _, false, SubOp>, - (true, BaseAluOpcode::XOR) => execute_e1_impl::<_, _, true, XorOp>, - (false, BaseAluOpcode::XOR) => execute_e1_impl::<_, _, false, XorOp>, - (true, BaseAluOpcode::OR) => execute_e1_impl::<_, _, true, OrOp>, - (false, 
BaseAluOpcode::OR) => execute_e1_impl::<_, _, false, OrOp>, - (true, BaseAluOpcode::AND) => execute_e1_impl::<_, _, true, AndOp>, - (false, BaseAluOpcode::AND) => execute_e1_impl::<_, _, false, AndOp>, - }; - Ok(fn_ptr) + + dispatch!(execute_e1_impl, is_imm, inst.opcode, self.offset) + } + + #[cfg(feature = "tco")] + fn handler( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let data: &mut BaseAluPreCompute = data.borrow_mut(); + let is_imm = self.pre_compute_impl(pc, inst, data)?; + + dispatch!(execute_e1_tco_handler, is_imm, inst.opcode, self.offset) } } @@ -114,7 +129,6 @@ where size_of::>() } - #[inline(always)] fn metered_pre_compute( &self, chip_idx: usize, @@ -128,24 +142,26 @@ where let data: &mut E2PreCompute = data.borrow_mut(); data.chip_idx = chip_idx as u32; let is_imm = self.pre_compute_impl(pc, inst, &mut data.data)?; - let opcode = inst.opcode; - - let fn_ptr = match ( - is_imm, - BaseAluOpcode::from_usize(opcode.local_opcode_idx(self.offset)), - ) { - (true, BaseAluOpcode::ADD) => execute_e2_impl::<_, _, true, AddOp>, - (false, BaseAluOpcode::ADD) => execute_e2_impl::<_, _, false, AddOp>, - (true, BaseAluOpcode::SUB) => execute_e2_impl::<_, _, true, SubOp>, - (false, BaseAluOpcode::SUB) => execute_e2_impl::<_, _, false, SubOp>, - (true, BaseAluOpcode::XOR) => execute_e2_impl::<_, _, true, XorOp>, - (false, BaseAluOpcode::XOR) => execute_e2_impl::<_, _, false, XorOp>, - (true, BaseAluOpcode::OR) => execute_e2_impl::<_, _, true, OrOp>, - (false, BaseAluOpcode::OR) => execute_e2_impl::<_, _, false, OrOp>, - (true, BaseAluOpcode::AND) => execute_e2_impl::<_, _, true, AndOp>, - (false, BaseAluOpcode::AND) => execute_e2_impl::<_, _, false, AndOp>, - }; - Ok(fn_ptr) + + dispatch!(execute_e2_impl, is_imm, inst.opcode, self.offset) + } + + #[cfg(feature = "tco")] + fn metered_handler( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> 
Result, StaticProgramError> + where + Ctx: MeteredExecutionCtxTrait, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + let is_imm = self.pre_compute_impl(pc, inst, &mut data.data)?; + + dispatch!(execute_e2_tco_handler, is_imm, inst.opcode, self.offset) } } @@ -174,6 +190,7 @@ unsafe fn execute_e12_impl< vm_state.instret += 1; } +#[create_tco_handler] #[inline(always)] unsafe fn execute_e1_impl< F: PrimeField32, @@ -188,6 +205,7 @@ unsafe fn execute_e1_impl< execute_e12_impl::(pre_compute, vm_state); } +#[create_tco_handler] #[inline(always)] unsafe fn execute_e2_impl< F: PrimeField32, diff --git a/extensions/rv32im/circuit/src/branch_eq/execution.rs b/extensions/rv32im/circuit/src/branch_eq/execution.rs index dba0d8cddb..70b9dc4d67 100644 --- a/extensions/rv32im/circuit/src/branch_eq/execution.rs +++ b/extensions/rv32im/circuit/src/branch_eq/execution.rs @@ -3,13 +3,7 @@ use std::{ mem::size_of, }; -use openvm_circuit::{ - arch::{ - E2PreCompute, ExecuteFunc, ExecutionCtxTrait, Executor, MeteredExecutionCtxTrait, - MeteredExecutor, StaticProgramError, VmExecState, - }, - system::memory::online::GuestMemory, -}; +use openvm_circuit::{arch::*, system::memory::online::GuestMemory}; use openvm_circuit_primitives_derive::AlignedBytesBorrow; use openvm_instructions::{ instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_AS, LocalOpcode, @@ -59,6 +53,16 @@ impl BranchEqualExecutor { } } +macro_rules! 
dispatch { + ($execute_impl:ident, $is_bne:ident) => { + if $is_bne { + Ok($execute_impl::<_, _, true>) + } else { + Ok($execute_impl::<_, _, false>) + } + }; +} + impl Executor for BranchEqualExecutor where F: PrimeField32, @@ -77,12 +81,22 @@ where ) -> Result, StaticProgramError> { let data: &mut BranchEqualPreCompute = data.borrow_mut(); let is_bne = self.pre_compute_impl(pc, inst, data)?; - let fn_ptr = if is_bne { - execute_e1_impl::<_, _, true> - } else { - execute_e1_impl::<_, _, false> - }; - Ok(fn_ptr) + dispatch!(execute_e1_impl, is_bne) + } + + #[cfg(feature = "tco")] + fn handler( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let data: &mut BranchEqualPreCompute = data.borrow_mut(); + let is_bne = self.pre_compute_impl(pc, inst, data)?; + dispatch!(execute_e1_tco_handler, is_bne) } } @@ -107,12 +121,24 @@ where let data: &mut E2PreCompute = data.borrow_mut(); data.chip_idx = chip_idx as u32; let is_bne = self.pre_compute_impl(pc, inst, &mut data.data)?; - let fn_ptr = if is_bne { - execute_e2_impl::<_, _, true> - } else { - execute_e2_impl::<_, _, false> - }; - Ok(fn_ptr) + dispatch!(execute_e2_impl, is_bne) + } + + #[cfg(feature = "tco")] + fn metered_handler( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: MeteredExecutionCtxTrait, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + let is_bne = self.pre_compute_impl(pc, inst, &mut data.data)?; + dispatch!(execute_e2_tco_handler, is_bne) } } @@ -131,6 +157,7 @@ unsafe fn execute_e12_impl( pre_compute: &[u8], vm_state: &mut VmExecState, @@ -139,6 +166,7 @@ unsafe fn execute_e1_impl(pre_compute, vm_state); } +#[create_tco_handler] unsafe fn execute_e2_impl( pre_compute: &[u8], vm_state: &mut VmExecState, diff --git a/extensions/rv32im/circuit/src/branch_lt/execution.rs 
b/extensions/rv32im/circuit/src/branch_lt/execution.rs index 206a49e4a1..b555973030 100644 --- a/extensions/rv32im/circuit/src/branch_lt/execution.rs +++ b/extensions/rv32im/circuit/src/branch_lt/execution.rs @@ -3,13 +3,7 @@ use std::{ mem::size_of, }; -use openvm_circuit::{ - arch::{ - E2PreCompute, ExecuteFunc, ExecutionCtxTrait, Executor, MeteredExecutionCtxTrait, - MeteredExecutor, StaticProgramError, VmExecState, - }, - system::memory::online::GuestMemory, -}; +use openvm_circuit::{arch::*, system::memory::online::GuestMemory}; use openvm_circuit_primitives_derive::AlignedBytesBorrow; use openvm_instructions::{ instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_AS, LocalOpcode, @@ -27,6 +21,17 @@ struct BranchLePreCompute { b: u8, } +macro_rules! dispatch { + ($execute_impl:ident, $local_opcode:ident) => { + match $local_opcode { + BranchLessThanOpcode::BLT => Ok($execute_impl::<_, _, BltOp>), + BranchLessThanOpcode::BLTU => Ok($execute_impl::<_, _, BltuOp>), + BranchLessThanOpcode::BGE => Ok($execute_impl::<_, _, BgeOp>), + BranchLessThanOpcode::BGEU => Ok($execute_impl::<_, _, BgeuOp>), + } + }; +} + impl BranchLessThanExecutor { @@ -78,13 +83,22 @@ where ) -> Result, StaticProgramError> { let data: &mut BranchLePreCompute = data.borrow_mut(); let local_opcode = self.pre_compute_impl(pc, inst, data)?; - let fn_ptr = match local_opcode { - BranchLessThanOpcode::BLT => execute_e1_impl::<_, _, BltOp>, - BranchLessThanOpcode::BLTU => execute_e1_impl::<_, _, BltuOp>, - BranchLessThanOpcode::BGE => execute_e1_impl::<_, _, BgeOp>, - BranchLessThanOpcode::BGEU => execute_e1_impl::<_, _, BgeuOp>, - }; - Ok(fn_ptr) + dispatch!(execute_e1_impl, local_opcode) + } + + #[cfg(feature = "tco")] + fn handler( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let data: &mut BranchLePreCompute = data.borrow_mut(); + let local_opcode = self.pre_compute_impl(pc, inst, 
data)?; + dispatch!(execute_e1_tco_handler, local_opcode) } } @@ -110,13 +124,24 @@ where let data: &mut E2PreCompute = data.borrow_mut(); data.chip_idx = chip_idx as u32; let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?; - let fn_ptr = match local_opcode { - BranchLessThanOpcode::BLT => execute_e2_impl::<_, _, BltOp>, - BranchLessThanOpcode::BLTU => execute_e2_impl::<_, _, BltuOp>, - BranchLessThanOpcode::BGE => execute_e2_impl::<_, _, BgeOp>, - BranchLessThanOpcode::BGEU => execute_e2_impl::<_, _, BgeuOp>, - }; - Ok(fn_ptr) + dispatch!(execute_e2_impl, local_opcode) + } + + #[cfg(feature = "tco")] + fn metered_handler( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: MeteredExecutionCtxTrait, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?; + dispatch!(execute_e2_tco_handler, local_opcode) } } @@ -136,6 +161,7 @@ unsafe fn execute_e12_impl( pre_compute: &[u8], vm_state: &mut VmExecState, @@ -144,6 +170,7 @@ unsafe fn execute_e1_impl(pre_compute, vm_state); } +#[create_tco_handler] unsafe fn execute_e2_impl( pre_compute: &[u8], vm_state: &mut VmExecState, diff --git a/extensions/rv32im/circuit/src/divrem/execution.rs b/extensions/rv32im/circuit/src/divrem/execution.rs index dd87de540b..68280f6c2b 100644 --- a/extensions/rv32im/circuit/src/divrem/execution.rs +++ b/extensions/rv32im/circuit/src/divrem/execution.rs @@ -3,13 +3,7 @@ use std::{ mem::size_of, }; -use openvm_circuit::{ - arch::{ - E2PreCompute, ExecuteFunc, ExecutionCtxTrait, Executor, MeteredExecutionCtxTrait, - MeteredExecutor, StaticProgramError, VmExecState, - }, - system::memory::online::GuestMemory, -}; +use openvm_circuit::{arch::*, system::memory::online::GuestMemory}; use openvm_circuit_primitives_derive::AlignedBytesBorrow; use openvm_instructions::{ instruction::Instruction, @@ -55,6 
+49,17 @@ impl DivRemExecutor { + match $local_opcode { + DivRemOpcode::DIV => Ok($execute_impl::<_, _, DivOp>), + DivRemOpcode::DIVU => Ok($execute_impl::<_, _, DivuOp>), + DivRemOpcode::REM => Ok($execute_impl::<_, _, RemOp>), + DivRemOpcode::REMU => Ok($execute_impl::<_, _, RemuOp>), + } + }; +} + impl Executor for DivRemExecutor where @@ -74,13 +79,22 @@ where ) -> Result, StaticProgramError> { let data: &mut DivRemPreCompute = data.borrow_mut(); let local_opcode = self.pre_compute_impl(pc, inst, data)?; - let fn_ptr = match local_opcode { - DivRemOpcode::DIV => execute_e1_impl::<_, _, DivOp>, - DivRemOpcode::DIVU => execute_e1_impl::<_, _, DivuOp>, - DivRemOpcode::REM => execute_e1_impl::<_, _, RemOp>, - DivRemOpcode::REMU => execute_e1_impl::<_, _, RemuOp>, - }; - Ok(fn_ptr) + dispatch!(execute_e1_impl, local_opcode) + } + + #[cfg(feature = "tco")] + fn handler( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let data: &mut DivRemPreCompute = data.borrow_mut(); + let local_opcode = self.pre_compute_impl(pc, inst, data)?; + dispatch!(execute_e1_tco_handler, local_opcode) } } @@ -106,13 +120,24 @@ where let data: &mut E2PreCompute = data.borrow_mut(); data.chip_idx = chip_idx as u32; let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?; - let fn_ptr = match local_opcode { - DivRemOpcode::DIV => execute_e2_impl::<_, _, DivOp>, - DivRemOpcode::DIVU => execute_e2_impl::<_, _, DivuOp>, - DivRemOpcode::REM => execute_e2_impl::<_, _, RemOp>, - DivRemOpcode::REMU => execute_e2_impl::<_, _, RemuOp>, - }; - Ok(fn_ptr) + dispatch!(execute_e2_impl, local_opcode) + } + + #[cfg(feature = "tco")] + fn metered_handler( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: MeteredExecutionCtxTrait, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + let 
local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?; + dispatch!(execute_e2_tco_handler, local_opcode) } } @@ -128,6 +153,7 @@ unsafe fn execute_e12_impl( pre_compute: &[u8], vm_state: &mut VmExecState, @@ -136,6 +162,7 @@ unsafe fn execute_e1_impl execute_e12_impl::(pre_compute, vm_state); } +#[create_tco_handler] unsafe fn execute_e2_impl( pre_compute: &[u8], vm_state: &mut VmExecState, diff --git a/extensions/rv32im/circuit/src/hintstore/execution.rs b/extensions/rv32im/circuit/src/hintstore/execution.rs index 2e87cc9cd9..41ab992243 100644 --- a/extensions/rv32im/circuit/src/hintstore/execution.rs +++ b/extensions/rv32im/circuit/src/hintstore/execution.rs @@ -1,4 +1,7 @@ -use std::borrow::{Borrow, BorrowMut}; +use std::{ + borrow::{Borrow, BorrowMut}, + mem::size_of, +}; use openvm_circuit::{arch::*, system::memory::online::GuestMemory}; use openvm_circuit_primitives_derive::AlignedBytesBorrow; @@ -57,6 +60,15 @@ impl Rv32HintStoreExecutor { } } +macro_rules! dispatch { + ($execute_impl:ident, $local_opcode:ident) => { + match $local_opcode { + HINT_STOREW => Ok($execute_impl::<_, _, true>), + HINT_BUFFER => Ok($execute_impl::<_, _, false>), + } + }; +} + impl Executor for Rv32HintStoreExecutor where F: PrimeField32, @@ -74,11 +86,22 @@ where ) -> Result, StaticProgramError> { let pre_compute: &mut HintStorePreCompute = data.borrow_mut(); let local_opcode = self.pre_compute_impl(pc, inst, pre_compute)?; - let fn_ptr = match local_opcode { - HINT_STOREW => execute_e1_impl::<_, _, true>, - HINT_BUFFER => execute_e1_impl::<_, _, false>, - }; - Ok(fn_ptr) + dispatch!(execute_e1_impl, local_opcode) + } + + #[cfg(feature = "tco")] + fn handler( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let pre_compute: &mut HintStorePreCompute = data.borrow_mut(); + let local_opcode = self.pre_compute_impl(pc, inst, pre_compute)?; + dispatch!(execute_e1_tco_handler, local_opcode) 
} } @@ -103,11 +126,24 @@ where let pre_compute: &mut E2PreCompute = data.borrow_mut(); pre_compute.chip_idx = chip_idx as u32; let local_opcode = self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; - let fn_ptr = match local_opcode { - HINT_STOREW => execute_e2_impl::<_, _, true>, - HINT_BUFFER => execute_e2_impl::<_, _, false>, - }; - Ok(fn_ptr) + dispatch!(execute_e2_impl, local_opcode) + } + + #[cfg(feature = "tco")] + fn metered_handler( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: MeteredExecutionCtxTrait, + { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + let local_opcode = self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + dispatch!(execute_e2_tco_handler, local_opcode) } } @@ -154,6 +190,7 @@ unsafe fn execute_e12_impl( pre_compute: &[u8], vm_state: &mut VmExecState, @@ -162,6 +199,7 @@ unsafe fn execute_e1_impl(pre_compute, vm_state); } +#[create_tco_handler] unsafe fn execute_e2_impl< F: PrimeField32, CTX: MeteredExecutionCtxTrait, diff --git a/extensions/rv32im/circuit/src/jal_lui/execution.rs b/extensions/rv32im/circuit/src/jal_lui/execution.rs index 129fe32202..6f61b5d0e4 100644 --- a/extensions/rv32im/circuit/src/jal_lui/execution.rs +++ b/extensions/rv32im/circuit/src/jal_lui/execution.rs @@ -3,13 +3,7 @@ use std::{ mem::size_of, }; -use openvm_circuit::{ - arch::{ - E2PreCompute, ExecuteFunc, ExecutionCtxTrait, Executor, MeteredExecutionCtxTrait, - MeteredExecutor, StaticProgramError, VmExecState, - }, - system::memory::online::GuestMemory, -}; +use openvm_circuit::{arch::*, system::memory::online::GuestMemory}; use openvm_circuit_primitives_derive::AlignedBytesBorrow; use openvm_instructions::{ instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_AS, LocalOpcode, @@ -49,6 +43,17 @@ impl Rv32JalLuiExecutor { } } +macro_rules! 
dispatch { + ($execute_impl:ident, $is_jal:ident, $enabled:ident) => { + match ($is_jal, $enabled) { + (true, true) => Ok($execute_impl::<_, _, true, true>), + (true, false) => Ok($execute_impl::<_, _, true, false>), + (false, true) => Ok($execute_impl::<_, _, false, true>), + (false, false) => Ok($execute_impl::<_, _, false, false>), + } + }; +} + impl Executor for Rv32JalLuiExecutor where F: PrimeField32, @@ -66,13 +71,22 @@ where ) -> Result, StaticProgramError> { let data: &mut JalLuiPreCompute = data.borrow_mut(); let (is_jal, enabled) = self.pre_compute_impl(inst, data)?; - let fn_ptr = match (is_jal, enabled) { - (true, true) => execute_e1_impl::<_, _, true, true>, - (true, false) => execute_e1_impl::<_, _, true, false>, - (false, true) => execute_e1_impl::<_, _, false, true>, - (false, false) => execute_e1_impl::<_, _, false, false>, - }; - Ok(fn_ptr) + dispatch!(execute_e1_impl, is_jal, enabled) + } + + #[cfg(feature = "tco")] + fn handler( + &self, + _pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let data: &mut JalLuiPreCompute = data.borrow_mut(); + let (is_jal, enabled) = self.pre_compute_impl(inst, data)?; + dispatch!(execute_e1_tco_handler, is_jal, enabled) } } @@ -97,13 +111,24 @@ where let data: &mut E2PreCompute = data.borrow_mut(); data.chip_idx = chip_idx as u32; let (is_jal, enabled) = self.pre_compute_impl(inst, &mut data.data)?; - let fn_ptr = match (is_jal, enabled) { - (true, true) => execute_e2_impl::<_, _, true, true>, - (true, false) => execute_e2_impl::<_, _, true, false>, - (false, true) => execute_e2_impl::<_, _, false, true>, - (false, false) => execute_e2_impl::<_, _, false, false>, - }; - Ok(fn_ptr) + dispatch!(execute_e2_impl, is_jal, enabled) + } + + #[cfg(feature = "tco")] + fn metered_handler( + &self, + chip_idx: usize, + _pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: MeteredExecutionCtxTrait, + { 
+ let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + let (is_jal, enabled) = self.pre_compute_impl(inst, &mut data.data)?; + dispatch!(execute_e2_tco_handler, is_jal, enabled) } } @@ -138,6 +163,7 @@ unsafe fn execute_e12_impl< vm_state.instret += 1; } +#[create_tco_handler] unsafe fn execute_e1_impl< F: PrimeField32, CTX: ExecutionCtxTrait, @@ -151,6 +177,7 @@ unsafe fn execute_e1_impl< execute_e12_impl::(pre_compute, vm_state); } +#[create_tco_handler] unsafe fn execute_e2_impl< F: PrimeField32, CTX: MeteredExecutionCtxTrait, diff --git a/extensions/rv32im/circuit/src/jalr/execution.rs b/extensions/rv32im/circuit/src/jalr/execution.rs index 8eb09de03c..e84e200eec 100644 --- a/extensions/rv32im/circuit/src/jalr/execution.rs +++ b/extensions/rv32im/circuit/src/jalr/execution.rs @@ -3,13 +3,7 @@ use std::{ mem::size_of, }; -use openvm_circuit::{ - arch::{ - E2PreCompute, ExecuteFunc, ExecutionCtxTrait, Executor, MeteredExecutionCtxTrait, - MeteredExecutor, StaticProgramError, VmExecState, - }, - system::memory::online::GuestMemory, -}; +use openvm_circuit::{arch::*, system::memory::online::GuestMemory}; use openvm_circuit_primitives_derive::AlignedBytesBorrow; use openvm_instructions::{ instruction::Instruction, @@ -50,6 +44,16 @@ impl Rv32JalrExecutor { } } +macro_rules! 
dispatch { + ($execute_impl:ident, $enabled:ident) => { + if $enabled { + Ok($execute_impl::<_, _, true>) + } else { + Ok($execute_impl::<_, _, false>) + } + }; +} + impl Executor for Rv32JalrExecutor where F: PrimeField32, @@ -67,12 +71,22 @@ where ) -> Result, StaticProgramError> { let data: &mut JalrPreCompute = data.borrow_mut(); let enabled = self.pre_compute_impl(pc, inst, data)?; - let fn_ptr = if enabled { - execute_e1_impl::<_, _, true> - } else { - execute_e1_impl::<_, _, false> - }; - Ok(fn_ptr) + dispatch!(execute_e1_impl, enabled) + } + + #[cfg(feature = "tco")] + fn handler( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let data: &mut JalrPreCompute = data.borrow_mut(); + let enabled = self.pre_compute_impl(pc, inst, data)?; + dispatch!(execute_e1_tco_handler, enabled) } } @@ -97,12 +111,24 @@ where let data: &mut E2PreCompute = data.borrow_mut(); data.chip_idx = chip_idx as u32; let enabled = self.pre_compute_impl(pc, inst, &mut data.data)?; - let fn_ptr = if enabled { - execute_e2_impl::<_, _, true> - } else { - execute_e2_impl::<_, _, false> - }; - Ok(fn_ptr) + dispatch!(execute_e2_impl, enabled) + } + + #[cfg(feature = "tco")] + fn metered_handler( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: MeteredExecutionCtxTrait, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + let enabled = self.pre_compute_impl(pc, inst, &mut data.data)?; + dispatch!(execute_e2_tco_handler, enabled) } } @@ -126,6 +152,7 @@ unsafe fn execute_e12_impl( pre_compute: &[u8], vm_state: &mut VmExecState, @@ -134,6 +161,7 @@ unsafe fn execute_e1_impl(pre_compute, vm_state); } +#[create_tco_handler] unsafe fn execute_e2_impl( pre_compute: &[u8], vm_state: &mut VmExecState, diff --git a/extensions/rv32im/circuit/src/less_than/execution.rs 
b/extensions/rv32im/circuit/src/less_than/execution.rs index 16c11377e5..7f3560f1e1 100644 --- a/extensions/rv32im/circuit/src/less_than/execution.rs +++ b/extensions/rv32im/circuit/src/less_than/execution.rs @@ -3,13 +3,7 @@ use std::{ mem::size_of, }; -use openvm_circuit::{ - arch::{ - E2PreCompute, ExecuteFunc, ExecutionCtxTrait, Executor, MeteredExecutionCtxTrait, - MeteredExecutor, StaticProgramError, VmExecState, - }, - system::memory::online::GuestMemory, -}; +use openvm_circuit::{arch::*, system::memory::online::GuestMemory}; use openvm_circuit_primitives_derive::AlignedBytesBorrow; use openvm_instructions::{ instruction::Instruction, @@ -71,6 +65,17 @@ impl LessThanExecutor { + match ($is_imm, $is_sltu) { + (true, true) => Ok($execute_impl::<_, _, true, true>), + (true, false) => Ok($execute_impl::<_, _, true, false>), + (false, true) => Ok($execute_impl::<_, _, false, true>), + (false, false) => Ok($execute_impl::<_, _, false, false>), + } + }; +} + impl Executor for LessThanExecutor where @@ -90,13 +95,22 @@ where ) -> Result, StaticProgramError> { let pre_compute: &mut LessThanPreCompute = data.borrow_mut(); let (is_imm, is_sltu) = self.pre_compute_impl(pc, inst, pre_compute)?; - let fn_ptr = match (is_imm, is_sltu) { - (true, true) => execute_e1_impl::<_, _, true, true>, - (true, false) => execute_e1_impl::<_, _, true, false>, - (false, true) => execute_e1_impl::<_, _, false, true>, - (false, false) => execute_e1_impl::<_, _, false, false>, - }; - Ok(fn_ptr) + dispatch!(execute_e1_impl, is_imm, is_sltu) + } + + #[cfg(feature = "tco")] + fn handler( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let pre_compute: &mut LessThanPreCompute = data.borrow_mut(); + let (is_imm, is_sltu) = self.pre_compute_impl(pc, inst, pre_compute)?; + dispatch!(execute_e1_tco_handler, is_imm, is_sltu) } } @@ -122,13 +136,24 @@ where let pre_compute: &mut E2PreCompute = data.borrow_mut(); 
pre_compute.chip_idx = chip_idx as u32; let (is_imm, is_sltu) = self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; - let fn_ptr = match (is_imm, is_sltu) { - (true, true) => execute_e2_impl::<_, _, true, true>, - (true, false) => execute_e2_impl::<_, _, true, false>, - (false, true) => execute_e2_impl::<_, _, false, true>, - (false, false) => execute_e2_impl::<_, _, false, false>, - }; - Ok(fn_ptr) + dispatch!(execute_e2_impl, is_imm, is_sltu) + } + + #[cfg(feature = "tco")] + fn metered_handler( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: MeteredExecutionCtxTrait, + { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + let (is_imm, is_sltu) = self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + dispatch!(execute_e2_tco_handler, is_imm, is_sltu) } } @@ -160,6 +185,7 @@ unsafe fn execute_e12_impl< vm_state.instret += 1; } +#[create_tco_handler] unsafe fn execute_e1_impl< F: PrimeField32, CTX: ExecutionCtxTrait, @@ -173,6 +199,7 @@ unsafe fn execute_e1_impl< execute_e12_impl::(pre_compute, vm_state); } +#[create_tco_handler] unsafe fn execute_e2_impl< F: PrimeField32, CTX: MeteredExecutionCtxTrait, diff --git a/extensions/rv32im/circuit/src/lib.rs b/extensions/rv32im/circuit/src/lib.rs index 6224c0450a..38a6c55747 100644 --- a/extensions/rv32im/circuit/src/lib.rs +++ b/extensions/rv32im/circuit/src/lib.rs @@ -1,3 +1,5 @@ +#![cfg_attr(feature = "tco", allow(incomplete_features))] +#![cfg_attr(feature = "tco", feature(explicit_tail_calls))] use openvm_circuit::{ arch::{ AirInventory, ChipInventoryError, InitFileGenerator, MatrixRecordArena, SystemConfig, diff --git a/extensions/rv32im/circuit/src/load_sign_extend/execution.rs b/extensions/rv32im/circuit/src/load_sign_extend/execution.rs index 43f11a33a7..f8d5686c48 100644 --- a/extensions/rv32im/circuit/src/load_sign_extend/execution.rs +++ 
b/extensions/rv32im/circuit/src/load_sign_extend/execution.rs @@ -5,10 +5,7 @@ use std::{ }; use openvm_circuit::{ - arch::{ - E2PreCompute, ExecuteFunc, ExecutionCtxTrait, ExecutionError, Executor, - MeteredExecutionCtxTrait, MeteredExecutor, StaticProgramError, VmExecState, - }, + arch::*, system::memory::{online::GuestMemory, POINTER_MAX_BITS}, }; use openvm_circuit_primitives_derive::AlignedBytesBorrow; @@ -80,6 +77,17 @@ impl LoadSignExtendExecutor { + match ($is_loadb, $enabled) { + (true, true) => Ok($execute_impl::<_, _, true, true>), + (true, false) => Ok($execute_impl::<_, _, true, false>), + (false, true) => Ok($execute_impl::<_, _, false, true>), + (false, false) => Ok($execute_impl::<_, _, false, false>), + } + }; +} + impl Executor for LoadSignExtendExecutor where @@ -98,13 +106,22 @@ where ) -> Result, StaticProgramError> { let pre_compute: &mut LoadSignExtendPreCompute = data.borrow_mut(); let (is_loadb, enabled) = self.pre_compute_impl(pc, inst, pre_compute)?; - let fn_ptr = match (is_loadb, enabled) { - (true, true) => execute_e1_impl::<_, _, true, true>, - (true, false) => execute_e1_impl::<_, _, true, false>, - (false, true) => execute_e1_impl::<_, _, false, true>, - (false, false) => execute_e1_impl::<_, _, false, false>, - }; - Ok(fn_ptr) + dispatch!(execute_e1_impl, is_loadb, enabled) + } + + #[cfg(feature = "tco")] + fn handler( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let pre_compute: &mut LoadSignExtendPreCompute = data.borrow_mut(); + let (is_loadb, enabled) = self.pre_compute_impl(pc, inst, pre_compute)?; + dispatch!(execute_e1_tco_handler, is_loadb, enabled) } } @@ -130,13 +147,24 @@ where let pre_compute: &mut E2PreCompute = data.borrow_mut(); pre_compute.chip_idx = chip_idx as u32; let (is_loadb, enabled) = self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; - let fn_ptr = match (is_loadb, enabled) { - (true, true) => 
execute_e2_impl::<_, _, true, true>, - (true, false) => execute_e2_impl::<_, _, true, false>, - (false, true) => execute_e2_impl::<_, _, false, true>, - (false, false) => execute_e2_impl::<_, _, false, false>, - }; - Ok(fn_ptr) + dispatch!(execute_e2_impl, is_loadb, enabled) + } + + #[cfg(feature = "tco")] + fn metered_handler( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: MeteredExecutionCtxTrait, + { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + let (is_loadb, enabled) = self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + dispatch!(execute_e2_tco_handler, is_loadb, enabled) } } @@ -185,6 +213,7 @@ unsafe fn execute_e12_impl< vm_state.instret += 1; } +#[create_tco_handler] unsafe fn execute_e1_impl< F: PrimeField32, CTX: ExecutionCtxTrait, @@ -198,6 +227,7 @@ unsafe fn execute_e1_impl< execute_e12_impl::(pre_compute, vm_state); } +#[create_tco_handler] unsafe fn execute_e2_impl< F: PrimeField32, CTX: MeteredExecutionCtxTrait, diff --git a/extensions/rv32im/circuit/src/loadstore/execution.rs b/extensions/rv32im/circuit/src/loadstore/execution.rs index 4d718c579e..c79c9beb32 100644 --- a/extensions/rv32im/circuit/src/loadstore/execution.rs +++ b/extensions/rv32im/circuit/src/loadstore/execution.rs @@ -1,6 +1,7 @@ use std::{ borrow::{Borrow, BorrowMut}, fmt::Debug, + mem::size_of, }; use openvm_circuit::{ @@ -82,6 +83,32 @@ impl LoadStoreExecutor { } } +macro_rules! 
dispatch { + ($execute_impl:ident, $local_opcode:ident, $enabled:ident, $is_native_store:ident) => { + match ($local_opcode, $enabled, $is_native_store) { + (LOADW, true, _) => Ok($execute_impl::<_, _, U8, LoadWOp, true>), + (LOADW, false, _) => Ok($execute_impl::<_, _, U8, LoadWOp, false>), + (LOADHU, true, _) => Ok($execute_impl::<_, _, U8, LoadHUOp, true>), + (LOADHU, false, _) => Ok($execute_impl::<_, _, U8, LoadHUOp, false>), + (LOADBU, true, _) => Ok($execute_impl::<_, _, U8, LoadBUOp, true>), + (LOADBU, false, _) => Ok($execute_impl::<_, _, U8, LoadBUOp, false>), + (STOREW, true, false) => Ok($execute_impl::<_, _, U8, StoreWOp, true>), + (STOREW, false, false) => Ok($execute_impl::<_, _, U8, StoreWOp, false>), + (STOREW, true, true) => Ok($execute_impl::<_, _, F, StoreWOp, true>), + (STOREW, false, true) => Ok($execute_impl::<_, _, F, StoreWOp, false>), + (STOREH, true, false) => Ok($execute_impl::<_, _, U8, StoreHOp, true>), + (STOREH, false, false) => Ok($execute_impl::<_, _, U8, StoreHOp, false>), + (STOREH, true, true) => Ok($execute_impl::<_, _, F, StoreHOp, true>), + (STOREH, false, true) => Ok($execute_impl::<_, _, F, StoreHOp, false>), + (STOREB, true, false) => Ok($execute_impl::<_, _, U8, StoreBOp, true>), + (STOREB, false, false) => Ok($execute_impl::<_, _, U8, StoreBOp, false>), + (STOREB, true, true) => Ok($execute_impl::<_, _, F, StoreBOp, true>), + (STOREB, false, true) => Ok($execute_impl::<_, _, F, StoreBOp, false>), + (_, _, _) => unreachable!(), + } + }; +} + impl Executor for LoadStoreExecutor where F: PrimeField32, @@ -101,28 +128,28 @@ where let pre_compute: &mut LoadStorePreCompute = data.borrow_mut(); let (local_opcode, enabled, is_native_store) = self.pre_compute_impl(pc, inst, pre_compute)?; - let fn_ptr = match (local_opcode, enabled, is_native_store) { - (LOADW, true, _) => execute_e1_impl::<_, _, U8, LoadWOp, true>, - (LOADW, false, _) => execute_e1_impl::<_, _, U8, LoadWOp, false>, - (LOADHU, true, _) => execute_e1_impl::<_, _, 
U8, LoadHUOp, true>, - (LOADHU, false, _) => execute_e1_impl::<_, _, U8, LoadHUOp, false>, - (LOADBU, true, _) => execute_e1_impl::<_, _, U8, LoadBUOp, true>, - (LOADBU, false, _) => execute_e1_impl::<_, _, U8, LoadBUOp, false>, - (STOREW, true, false) => execute_e1_impl::<_, _, U8, StoreWOp, true>, - (STOREW, false, false) => execute_e1_impl::<_, _, U8, StoreWOp, false>, - (STOREW, true, true) => execute_e1_impl::<_, _, F, StoreWOp, true>, - (STOREW, false, true) => execute_e1_impl::<_, _, F, StoreWOp, false>, - (STOREH, true, false) => execute_e1_impl::<_, _, U8, StoreHOp, true>, - (STOREH, false, false) => execute_e1_impl::<_, _, U8, StoreHOp, false>, - (STOREH, true, true) => execute_e1_impl::<_, _, F, StoreHOp, true>, - (STOREH, false, true) => execute_e1_impl::<_, _, F, StoreHOp, false>, - (STOREB, true, false) => execute_e1_impl::<_, _, U8, StoreBOp, true>, - (STOREB, false, false) => execute_e1_impl::<_, _, U8, StoreBOp, false>, - (STOREB, true, true) => execute_e1_impl::<_, _, F, StoreBOp, true>, - (STOREB, false, true) => execute_e1_impl::<_, _, F, StoreBOp, false>, - (_, _, _) => unreachable!(), - }; - Ok(fn_ptr) + dispatch!(execute_e1_impl, local_opcode, enabled, is_native_store) + } + + #[cfg(feature = "tco")] + fn handler( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let pre_compute: &mut LoadStorePreCompute = data.borrow_mut(); + let (local_opcode, enabled, is_native_store) = + self.pre_compute_impl(pc, inst, pre_compute)?; + dispatch!( + execute_e1_tco_handler, + local_opcode, + enabled, + is_native_store + ) } } @@ -148,28 +175,30 @@ where pre_compute.chip_idx = chip_idx as u32; let (local_opcode, enabled, is_native_store) = self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; - let fn_ptr = match (local_opcode, enabled, is_native_store) { - (LOADW, true, _) => execute_e2_impl::<_, _, U8, LoadWOp, true>, - (LOADW, false, _) => execute_e2_impl::<_, _, U8, 
LoadWOp, false>, - (LOADHU, true, _) => execute_e2_impl::<_, _, U8, LoadHUOp, true>, - (LOADHU, false, _) => execute_e2_impl::<_, _, U8, LoadHUOp, false>, - (LOADBU, true, _) => execute_e2_impl::<_, _, U8, LoadBUOp, true>, - (LOADBU, false, _) => execute_e2_impl::<_, _, U8, LoadBUOp, false>, - (STOREW, true, false) => execute_e2_impl::<_, _, U8, StoreWOp, true>, - (STOREW, false, false) => execute_e2_impl::<_, _, U8, StoreWOp, false>, - (STOREW, true, true) => execute_e2_impl::<_, _, F, StoreWOp, true>, - (STOREW, false, true) => execute_e2_impl::<_, _, F, StoreWOp, false>, - (STOREH, true, false) => execute_e2_impl::<_, _, U8, StoreHOp, true>, - (STOREH, false, false) => execute_e2_impl::<_, _, U8, StoreHOp, false>, - (STOREH, true, true) => execute_e2_impl::<_, _, F, StoreHOp, true>, - (STOREH, false, true) => execute_e2_impl::<_, _, F, StoreHOp, false>, - (STOREB, true, false) => execute_e2_impl::<_, _, U8, StoreBOp, true>, - (STOREB, false, false) => execute_e2_impl::<_, _, U8, StoreBOp, false>, - (STOREB, true, true) => execute_e2_impl::<_, _, F, StoreBOp, true>, - (STOREB, false, true) => execute_e2_impl::<_, _, F, StoreBOp, false>, - (_, _, _) => unreachable!(), - }; - Ok(fn_ptr) + dispatch!(execute_e2_impl, local_opcode, enabled, is_native_store) + } + + #[cfg(feature = "tco")] + fn metered_handler( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: MeteredExecutionCtxTrait, + { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + let (local_opcode, enabled, is_native_store) = + self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + dispatch!( + execute_e2_tco_handler, + local_opcode, + enabled, + is_native_store + ) } } @@ -226,6 +255,7 @@ unsafe fn execute_e12_impl< vm_state.instret += 1; } +#[create_tco_handler] unsafe fn execute_e1_impl< F: PrimeField32, CTX: ExecutionCtxTrait, @@ -240,6 +270,7 @@ unsafe fn execute_e1_impl< 
execute_e12_impl::(pre_compute, vm_state); } +#[create_tco_handler] unsafe fn execute_e2_impl< F: PrimeField32, CTX: MeteredExecutionCtxTrait, diff --git a/extensions/rv32im/circuit/src/mul/execution.rs b/extensions/rv32im/circuit/src/mul/execution.rs index 73376d8f98..e254434780 100644 --- a/extensions/rv32im/circuit/src/mul/execution.rs +++ b/extensions/rv32im/circuit/src/mul/execution.rs @@ -3,13 +3,7 @@ use std::{ mem::size_of, }; -use openvm_circuit::{ - arch::{ - E2PreCompute, ExecuteFunc, ExecutionCtxTrait, Executor, MeteredExecutionCtxTrait, - MeteredExecutor, StaticProgramError, VmExecState, - }, - system::memory::online::GuestMemory, -}; +use openvm_circuit::{arch::*, system::memory::online::GuestMemory}; use openvm_circuit_primitives_derive::AlignedBytesBorrow; use openvm_instructions::{ instruction::Instruction, @@ -75,6 +69,21 @@ where self.pre_compute_impl(pc, inst, pre_compute)?; Ok(execute_e1_impl) } + + #[cfg(feature = "tco")] + fn handler( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let pre_compute: &mut MultiPreCompute = data.borrow_mut(); + self.pre_compute_impl(pc, inst, pre_compute)?; + Ok(execute_e1_tco_handler) + } } impl MeteredExecutor @@ -101,6 +110,23 @@ where self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; Ok(execute_e2_impl) } + + #[cfg(feature = "tco")] + fn metered_handler( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: MeteredExecutionCtxTrait, + { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + Ok(execute_e2_tco_handler) + } } #[inline(always)] @@ -121,6 +147,7 @@ unsafe fn execute_e12_impl( vm_state.instret += 1; } +#[create_tco_handler] unsafe fn execute_e1_impl( pre_compute: &[u8], vm_state: &mut VmExecState, @@ -129,6 +156,7 @@ unsafe 
fn execute_e1_impl( execute_e12_impl(pre_compute, vm_state); } +#[create_tco_handler] unsafe fn execute_e2_impl( pre_compute: &[u8], vm_state: &mut VmExecState, diff --git a/extensions/rv32im/circuit/src/mulh/execution.rs b/extensions/rv32im/circuit/src/mulh/execution.rs index 1818a63080..03f127e93c 100644 --- a/extensions/rv32im/circuit/src/mulh/execution.rs +++ b/extensions/rv32im/circuit/src/mulh/execution.rs @@ -3,13 +3,7 @@ use std::{ mem::size_of, }; -use openvm_circuit::{ - arch::{ - E2PreCompute, ExecuteFunc, ExecutionCtxTrait, Executor, MeteredExecutionCtxTrait, - MeteredExecutor, StaticProgramError, VmExecState, - }, - system::memory::online::GuestMemory, -}; +use openvm_circuit::{arch::*, system::memory::online::GuestMemory}; use openvm_circuit_primitives_derive::AlignedBytesBorrow; use openvm_instructions::{ instruction::Instruction, @@ -48,6 +42,16 @@ impl MulHExecutor { + match $local_opcode { + MulHOpcode::MULH => Ok($execute_impl::<_, _, MulHOp>), + MulHOpcode::MULHSU => Ok($execute_impl::<_, _, MulHSuOp>), + MulHOpcode::MULHU => Ok($execute_impl::<_, _, MulHUOp>), + } + }; +} + impl Executor for MulHExecutor where @@ -67,12 +71,22 @@ where ) -> Result, StaticProgramError> { let pre_compute: &mut MulHPreCompute = data.borrow_mut(); let local_opcode = self.pre_compute_impl(inst, pre_compute)?; - let fn_ptr = match local_opcode { - MulHOpcode::MULH => execute_e1_impl::<_, _, MulHOp>, - MulHOpcode::MULHSU => execute_e1_impl::<_, _, MulHSuOp>, - MulHOpcode::MULHU => execute_e1_impl::<_, _, MulHUOp>, - }; - Ok(fn_ptr) + dispatch!(execute_e1_impl, local_opcode) + } + + #[cfg(feature = "tco")] + fn handler( + &self, + _pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let pre_compute: &mut MulHPreCompute = data.borrow_mut(); + let local_opcode = self.pre_compute_impl(inst, pre_compute)?; + dispatch!(execute_e1_tco_handler, local_opcode) } } @@ -98,12 +112,24 @@ where let 
pre_compute: &mut E2PreCompute = data.borrow_mut(); pre_compute.chip_idx = chip_idx as u32; let local_opcode = self.pre_compute_impl(inst, &mut pre_compute.data)?; - let fn_ptr = match local_opcode { - MulHOpcode::MULH => execute_e2_impl::<_, _, MulHOp>, - MulHOpcode::MULHSU => execute_e2_impl::<_, _, MulHSuOp>, - MulHOpcode::MULHU => execute_e2_impl::<_, _, MulHUOp>, - }; - Ok(fn_ptr) + dispatch!(execute_e2_impl, local_opcode) + } + + #[cfg(feature = "tco")] + fn metered_handler( + &self, + chip_idx: usize, + _pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: MeteredExecutionCtxTrait, + { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + let local_opcode = self.pre_compute_impl(inst, &mut pre_compute.data)?; + dispatch!(execute_e2_tco_handler, local_opcode) } } @@ -123,6 +149,7 @@ unsafe fn execute_e12_impl( pre_compute: &[u8], vm_state: &mut VmExecState, @@ -131,6 +158,7 @@ unsafe fn execute_e1_impl(pre_compute, vm_state); } +#[create_tco_handler] unsafe fn execute_e2_impl( pre_compute: &[u8], vm_state: &mut VmExecState, diff --git a/extensions/rv32im/circuit/src/shift/execution.rs b/extensions/rv32im/circuit/src/shift/execution.rs index b756f8b768..cacf3b9f2e 100644 --- a/extensions/rv32im/circuit/src/shift/execution.rs +++ b/extensions/rv32im/circuit/src/shift/execution.rs @@ -3,13 +3,7 @@ use std::{ mem::size_of, }; -use openvm_circuit::{ - arch::{ - E2PreCompute, ExecuteFunc, ExecutionCtxTrait, Executor, MeteredExecutionCtxTrait, - MeteredExecutor, StaticProgramError, VmExecState, - }, - system::memory::online::GuestMemory, -}; +use openvm_circuit::{arch::*, system::memory::online::GuestMemory}; use openvm_circuit_primitives_derive::AlignedBytesBorrow; use openvm_instructions::{ instruction::Instruction, @@ -65,6 +59,19 @@ impl ShiftExecutor { + match ($is_imm, $shift_opcode) { + (true, ShiftOpcode::SLL) => Ok($execute_impl::<_, _, true, SllOp>), + (false, 
ShiftOpcode::SLL) => Ok($execute_impl::<_, _, false, SllOp>), + (true, ShiftOpcode::SRL) => Ok($execute_impl::<_, _, true, SrlOp>), + (false, ShiftOpcode::SRL) => Ok($execute_impl::<_, _, false, SrlOp>), + (true, ShiftOpcode::SRA) => Ok($execute_impl::<_, _, true, SraOp>), + (false, ShiftOpcode::SRA) => Ok($execute_impl::<_, _, false, SraOp>), + } + }; +} + impl Executor for ShiftExecutor where @@ -74,7 +81,6 @@ where size_of::() } - #[inline(always)] fn pre_compute( &self, pc: u32, @@ -84,15 +90,23 @@ where let data: &mut ShiftPreCompute = data.borrow_mut(); let (is_imm, shift_opcode) = self.pre_compute_impl(pc, inst, data)?; // `d` is always expected to be RV32_REGISTER_AS. - let fn_ptr = match (is_imm, shift_opcode) { - (true, ShiftOpcode::SLL) => execute_e1_impl::<_, _, true, SllOp>, - (false, ShiftOpcode::SLL) => execute_e1_impl::<_, _, false, SllOp>, - (true, ShiftOpcode::SRL) => execute_e1_impl::<_, _, true, SrlOp>, - (false, ShiftOpcode::SRL) => execute_e1_impl::<_, _, false, SrlOp>, - (true, ShiftOpcode::SRA) => execute_e1_impl::<_, _, true, SraOp>, - (false, ShiftOpcode::SRA) => execute_e1_impl::<_, _, false, SraOp>, - }; - Ok(fn_ptr) + dispatch!(execute_e1_impl, is_imm, shift_opcode) + } + + #[cfg(feature = "tco")] + fn handler( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let data: &mut ShiftPreCompute = data.borrow_mut(); + let (is_imm, shift_opcode) = self.pre_compute_impl(pc, inst, data)?; + // `d` is always expected to be RV32_REGISTER_AS. + dispatch!(execute_e1_tco_handler, is_imm, shift_opcode) } } @@ -105,7 +119,6 @@ where size_of::>() } - #[inline(always)] fn metered_pre_compute( &self, chip_idx: usize, @@ -117,15 +130,22 @@ where data.chip_idx = chip_idx as u32; let (is_imm, shift_opcode) = self.pre_compute_impl(pc, inst, &mut data.data)?; // `d` is always expected to be RV32_REGISTER_AS. 
- let fn_ptr = match (is_imm, shift_opcode) { - (true, ShiftOpcode::SLL) => execute_e2_impl::<_, _, true, SllOp>, - (false, ShiftOpcode::SLL) => execute_e2_impl::<_, _, false, SllOp>, - (true, ShiftOpcode::SRL) => execute_e2_impl::<_, _, true, SrlOp>, - (false, ShiftOpcode::SRL) => execute_e2_impl::<_, _, false, SrlOp>, - (true, ShiftOpcode::SRA) => execute_e2_impl::<_, _, true, SraOp>, - (false, ShiftOpcode::SRA) => execute_e2_impl::<_, _, false, SraOp>, - }; - Ok(fn_ptr) + dispatch!(execute_e2_impl, is_imm, shift_opcode) + } + + #[cfg(feature = "tco")] + fn metered_handler( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + let (is_imm, shift_opcode) = self.pre_compute_impl(pc, inst, &mut data.data)?; + // `d` is always expected to be RV32_REGISTER_AS. + dispatch!(execute_e2_tco_handler, is_imm, shift_opcode) } } @@ -155,6 +175,7 @@ unsafe fn execute_e12_impl< state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); } +#[create_tco_handler] unsafe fn execute_e1_impl< F: PrimeField32, CTX: ExecutionCtxTrait, @@ -168,6 +189,7 @@ unsafe fn execute_e1_impl< execute_e12_impl::(pre_compute, state); } +#[create_tco_handler] unsafe fn execute_e2_impl< F: PrimeField32, CTX: MeteredExecutionCtxTrait, diff --git a/extensions/rv32im/tests/Cargo.toml b/extensions/rv32im/tests/Cargo.toml index 45eb4c1654..412c8cffb0 100644 --- a/extensions/rv32im/tests/Cargo.toml +++ b/extensions/rv32im/tests/Cargo.toml @@ -25,3 +25,4 @@ strum.workspace = true [features] default = ["parallel"] parallel = ["openvm-circuit/parallel"] +tco = ["openvm-rv32im-circuit/tco"] diff --git a/extensions/sha256/circuit/Cargo.toml b/extensions/sha256/circuit/Cargo.toml index 413265b622..5cdfb143c0 100644 --- a/extensions/sha256/circuit/Cargo.toml +++ b/extensions/sha256/circuit/Cargo.toml @@ -32,6 +32,7 @@ openvm-circuit = { workspace = true, features = 
["test-utils"] } default = ["parallel", "jemalloc"] parallel = ["openvm-circuit/parallel"] test-utils = ["openvm-circuit/test-utils"] +tco = ["openvm-rv32im-circuit/tco"] # performance features: mimalloc = ["openvm-circuit/mimalloc"] jemalloc = ["openvm-circuit/jemalloc"] diff --git a/extensions/sha256/circuit/src/lib.rs b/extensions/sha256/circuit/src/lib.rs index 741cf3ec9d..7e1676702d 100644 --- a/extensions/sha256/circuit/src/lib.rs +++ b/extensions/sha256/circuit/src/lib.rs @@ -1,3 +1,6 @@ +#![cfg_attr(feature = "tco", allow(incomplete_features))] +#![cfg_attr(feature = "tco", feature(explicit_tail_calls))] + use std::result::Result; use openvm_circuit::{ diff --git a/extensions/sha256/circuit/src/sha256_chip/execution.rs b/extensions/sha256/circuit/src/sha256_chip/execution.rs index befbb25f41..33b40a59c3 100644 --- a/extensions/sha256/circuit/src/sha256_chip/execution.rs +++ b/extensions/sha256/circuit/src/sha256_chip/execution.rs @@ -23,6 +23,21 @@ struct ShaPreCompute { } impl Executor for Sha256VmExecutor { + #[cfg(feature = "tco")] + fn handler( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let data: &mut ShaPreCompute = data.borrow_mut(); + self.pre_compute_impl(pc, inst, data)?; + Ok(execute_e1_tco_handler::<_, _>) + } + fn pre_compute_size(&self) -> usize { size_of::() } @@ -61,6 +76,23 @@ impl MeteredExecutor for Sha256VmExecutor { self.pre_compute_impl(pc, inst, &mut data.data)?; Ok(execute_e2_impl::<_, _>) } + + #[cfg(feature = "tco")] + fn metered_handler( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: MeteredExecutionCtxTrait, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + self.pre_compute_impl(pc, inst, &mut data.data)?; + Ok(execute_e2_tco_handler::<_, _>) + } } unsafe fn execute_e12_impl( @@ -105,6 +137,7 @@ unsafe fn 
execute_e12_impl( pre_compute: &[u8], vm_state: &mut VmExecState, @@ -112,6 +145,7 @@ unsafe fn execute_e1_impl( let pre_compute: &ShaPreCompute = pre_compute.borrow(); execute_e12_impl::(pre_compute, vm_state); } +#[create_tco_handler] unsafe fn execute_e2_impl( pre_compute: &[u8], vm_state: &mut VmExecState, diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 8825102061..651e7fa7e6 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,3 +1,5 @@ [toolchain] channel = "1.86.0" +# To use the "tco" feature, switch to Rust nightly: +# channel = "nightly-2025-08-19" components = ["clippy", "rustfmt"]