Skip to content

Commit 13cc6f2

Browse files
committed
feat: add macro to generate tco handler and update interpreter for tco
1 parent 4d9e261 commit 13cc6f2

File tree

14 files changed

+364
-13
lines changed

14 files changed

+364
-13
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ dashmap = "6.1.0"
229229
memmap2 = "0.9.5"
230230
libc = "0.2.175"
231231
tracing-subscriber = { version = "0.3.17", features = ["std", "env-filter"] }
232+
paste = "1.0.15"
232233

233234
# default-features = false for no_std for use in guest programs
234235
itertools = { version = "0.14.0", default-features = false }

crates/vm/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ openvm-native-compiler.workspace = true
5050
openvm-rv32im-transpiler.workspace = true
5151

5252
[features]
53-
default = ["parallel", "jemalloc"]
53+
default = ["parallel", "jemalloc", "tco"]
5454
parallel = [
5555
"openvm-stark-backend/parallel",
5656
"dashmap/rayon",

crates/vm/derive/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ license.workspace = true
1010
proc-macro = true
1111

1212
[dependencies]
13-
syn = { version = "2.0", features = ["parsing"] }
13+
syn = { version = "2.0", features = ["parsing", "full"] }
1414
quote = "1.0"
1515
proc-macro2 = "1.0"
1616
itertools = { workspace = true }

crates/vm/derive/src/lib.rs

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ use syn::{
99
GenericParam, Ident, Meta, Token,
1010
};
1111

12+
mod tco;
13+
1214
#[proc_macro_derive(PreflightExecutor)]
1315
pub fn preflight_executor_derive(input: TokenStream) -> TokenStream {
1416
let ast: syn::DeriveInput = syn::parse(input).unwrap();
@@ -172,6 +174,18 @@ pub fn executor_derive(input: TokenStream) -> TokenStream {
172174
Ctx: ::openvm_circuit::arch::execution_mode::ExecutionCtxTrait, {
173175
self.0.pre_compute(pc, inst, data)
174176
}
177+
178+
#[cfg(feature = "tco")]
179+
fn handler<Ctx>(
180+
&self,
181+
pc: u32,
182+
inst: &::openvm_circuit::arch::instructions::instruction::Instruction<F>,
183+
data: &mut [u8],
184+
) -> Result<::openvm_circuit::arch::Handler<F, Ctx>, ::openvm_circuit::arch::StaticProgramError>
185+
where
186+
Ctx: ::openvm_circuit::arch::execution_mode::ExecutionCtxTrait, {
187+
self.0.handler(pc, inst, data)
188+
}
175189
}
176190
}
177191
.into()
@@ -205,18 +219,21 @@ pub fn executor_derive(input: TokenStream) -> TokenStream {
205219
});
206220
// Use full path ::openvm_circuit... so it can be used either within or outside the vm
207221
// crate. Assume F is already generic of the field.
208-
let (pre_compute_size_arms, pre_compute_arms, where_predicates): (Vec<_>, Vec<_>, Vec<_>) = multiunzip(variants.iter().map(|(variant_name, field)| {
222+
let (pre_compute_size_arms, pre_compute_arms, handler_arms, where_predicates): (Vec<_>, Vec<_>, Vec<_>, Vec<_>) = multiunzip(variants.iter().map(|(variant_name, field)| {
209223
let field_ty = &field.ty;
210224
let pre_compute_size_arm = quote! {
211225
#name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::Executor<#first_ty_generic>>::pre_compute_size(x)
212226
};
213227
let pre_compute_arm = quote! {
214228
#name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::Executor<#first_ty_generic>>::pre_compute(x, pc, instruction, data)
215229
};
230+
let handler_arm = quote! {
231+
#name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::Executor<#first_ty_generic>>::handler(x, pc, instruction, data)
232+
};
216233
let where_predicate = syn::parse_quote! {
217234
#field_ty: ::openvm_circuit::arch::Executor<#first_ty_generic>
218235
};
219-
(pre_compute_size_arm, pre_compute_arm, where_predicate)
236+
(pre_compute_size_arm, pre_compute_arm, handler_arm, where_predicate)
220237
}));
221238
let where_clause = new_generics.make_where_clause();
222239
for predicate in where_predicates {
@@ -247,6 +264,20 @@ pub fn executor_derive(input: TokenStream) -> TokenStream {
247264
#(#pre_compute_arms,)*
248265
}
249266
}
267+
268+
#[cfg(feature = "tco")]
269+
fn handler<Ctx>(
270+
&self,
271+
pc: u32,
272+
instruction: &::openvm_circuit::arch::instructions::instruction::Instruction<F>,
273+
data: &mut [u8],
274+
) -> Result<::openvm_circuit::arch::Handler<F, Ctx>, ::openvm_circuit::arch::StaticProgramError>
275+
where
276+
Ctx: ::openvm_circuit::arch::execution_mode::ExecutionCtxTrait, {
277+
match self {
278+
#(#handler_arms,)*
279+
}
280+
}
250281
}
251282
}
252283
.into()
@@ -501,7 +532,7 @@ fn generate_config_traits_impl(name: &Ident, inner: &DataStruct) -> syn::Result<
501532
.iter()
502533
.filter(|f| f.attrs.iter().any(|attr| attr.path().is_ident("config")))
503534
.exactly_one()
504-
.clone()
535+
.ok()
505536
.expect("Exactly one field must have the #[config] attribute");
506537
let (source_name, source_name_upper) =
507538
gen_name_with_uppercase_idents(source_field.ident.as_ref().unwrap());
@@ -700,3 +731,30 @@ fn parse_executor_type(
700731
})
701732
}
702733
}
734+
735+
/// An attribute procedural macro for creating TCO (Tail Call Optimization) handlers.
736+
///
737+
/// This macro generates a handler function that wraps an execute implementation
738+
/// with tail call optimization using the `become` keyword. It extracts the generics
739+
/// and where clauses from the original function.
740+
///
741+
/// # Usage
742+
///
743+
/// Place this attribute above a function definition:
744+
/// ```
745+
/// #[create_tco_handler = "handler_name"]
746+
/// unsafe fn execute_e1_impl<F: PrimeField32, CTX, const B_IS_IMM: bool>(
747+
/// pre_compute: &[u8],
748+
/// state: &mut VmExecState<F, GuestMemory, CTX>,
749+
/// ) where
750+
/// CTX: ExecutionCtxTrait,
751+
/// {
752+
/// // function body
753+
/// }
754+
/// ```
755+
///
756+
/// This will generate a TCO handler function with the same generics and where clauses.
757+
#[proc_macro_attribute]
758+
pub fn create_tco_handler(_attr: TokenStream, item: TokenStream) -> TokenStream {
759+
tco::tco_impl(item)
760+
}

crates/vm/derive/src/tco.rs

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
use proc_macro::TokenStream;
2+
use quote::{format_ident, quote};
3+
use syn::{parse_macro_input, ItemFn};
4+
5+
/// Implementation of the TCO handler generation logic.
6+
/// This is called from the proc macro attribute in lib.rs.
7+
pub fn tco_impl(item: TokenStream) -> TokenStream {
8+
// Parse the input function
9+
let input_fn = parse_macro_input!(item as ItemFn);
10+
11+
// Extract information from the function
12+
let fn_name = &input_fn.sig.ident;
13+
let generics = &input_fn.sig.generics;
14+
let where_clause = &generics.where_clause;
15+
16+
// Extract the first two generic type parameters (F and CTX)
17+
let (f_type, ctx_type) = extract_f_and_ctx_types(generics);
18+
// Derive new function name:
19+
// If original ends with `_impl`, replace with `_tco_handler`, else append suffix.
20+
let new_name_str = fn_name
21+
.to_string()
22+
.strip_suffix("_impl")
23+
.map(|base| format!("{base}_tco_handler"))
24+
.unwrap_or_else(|| format!("{fn_name}_tco_handler"));
25+
let handler_name = format_ident!("{}", new_name_str);
26+
27+
// Build the generic parameters for the handler, preserving all original generics
28+
let handler_generics = generics.clone();
29+
30+
// Build the function call with all the generics
31+
let generic_args = build_generic_args(generics);
32+
let execute_call = if generic_args.is_empty() {
33+
quote! { #fn_name(pre_compute, exec_state) }
34+
} else {
35+
quote! { #fn_name::<#(#generic_args),*>(pre_compute, exec_state) }
36+
};
37+
38+
// Generate the TCO handler function
39+
let handler_fn = quote! {
40+
#[cfg(feature = "tco")]
41+
#[inline(never)]
42+
unsafe fn #handler_name #handler_generics (
43+
interpreter: &::openvm_circuit::arch::interpreter::InterpretedInstance<#f_type, #ctx_type>,
44+
exec_state: &mut ::openvm_circuit::arch::VmExecState<
45+
#f_type,
46+
::openvm_circuit::system::memory::online::GuestMemory,
47+
#ctx_type,
48+
>,
49+
) -> Result<(), ::openvm_circuit::arch::ExecutionError>
50+
#where_clause
51+
{
52+
let pre_compute = interpreter.get_pre_compute(exec_state.pc);
53+
#execute_call;
54+
55+
if std::hint::unlikely(exec_state.exit_code.is_err()) {
56+
return Err(::openvm_circuit::arch::ExecutionError::ExecStateError);
57+
}
58+
if std::hint::unlikely(exec_state.exit_code.as_ref().unwrap().is_some()) {
59+
// terminate
60+
return Ok(());
61+
}
62+
// exec_state.pc should have been updated by execute_impl at this point
63+
let next_handler = interpreter.get_handler(exec_state.pc)?;
64+
become next_handler(interpreter, exec_state)
65+
}
66+
};
67+
68+
// Return both the original function and the new handler
69+
let output = quote! {
70+
#input_fn
71+
72+
#handler_fn
73+
};
74+
75+
TokenStream::from(output)
76+
}
77+
78+
fn extract_f_and_ctx_types(generics: &syn::Generics) -> (syn::Ident, syn::Ident) {
79+
let mut type_params = generics.params.iter().filter_map(|param| {
80+
if let syn::GenericParam::Type(type_param) = param {
81+
Some(&type_param.ident)
82+
} else {
83+
None
84+
}
85+
});
86+
87+
let f_type = type_params
88+
.next()
89+
.expect("Function must have at least one type parameter (F)")
90+
.clone();
91+
let ctx_type = type_params
92+
.next()
93+
.expect("Function must have at least two type parameters (F and CTX)")
94+
.clone();
95+
96+
(f_type, ctx_type)
97+
}
98+
99+
fn build_generic_args(generics: &syn::Generics) -> Vec<proc_macro2::TokenStream> {
100+
generics
101+
.params
102+
.iter()
103+
.map(|param| match param {
104+
syn::GenericParam::Type(type_param) => {
105+
let ident = &type_param.ident;
106+
quote! { #ident }
107+
}
108+
syn::GenericParam::Lifetime(lifetime) => {
109+
let lifetime = &lifetime.lifetime;
110+
quote! { #lifetime }
111+
}
112+
syn::GenericParam::Const(const_param) => {
113+
let ident = &const_param.ident;
114+
quote! { #ident }
115+
}
116+
})
117+
.collect()
118+
}

crates/vm/src/arch/execution.rs

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ use serde::{Deserialize, Serialize};
1212
use thiserror::Error;
1313

1414
use super::{execution_mode::ExecutionCtxTrait, Streams, VmExecState};
15+
#[cfg(feature = "tco")]
16+
use crate::arch::interpreter::InterpretedInstance;
1517
#[cfg(feature = "metrics")]
1618
use crate::metrics::VmMetrics;
1719
use crate::{
@@ -72,6 +74,9 @@ pub enum ExecutionError {
7274
Inventory(#[from] ExecutorInventoryError),
7375
#[error("static program error: {0}")]
7476
Static(#[from] StaticProgramError),
77+
// Placeholder error type for tco
78+
#[error("error in VmExecState")]
79+
ExecStateError,
7580
}
7681

7782
/// Errors in the program that can be statically analyzed before runtime.
@@ -91,7 +96,20 @@ pub enum StaticProgramError {
9196
/// The `pre_compute: &[u8]` is a pre-computed buffer of data corresponding to a single instruction.
9297
/// The contents of `pre_compute` are determined from the program code as specified by the
9398
/// [Executor] and [MeteredExecutor] traits.
94-
pub type ExecuteFunc<F, CTX> = unsafe fn(&[u8], &mut VmExecState<F, GuestMemory, CTX>);
99+
pub type ExecuteFunc<F, CTX> =
100+
unsafe fn(pre_compute: &[u8], exec_state: &mut VmExecState<F, GuestMemory, CTX>);
101+
102+
/// Handler for tail call elimination. The `CTX` is assumed to contain pointers to the pre-computed
103+
/// buffer and the function handler table.
104+
///
105+
/// - `pre_compute_buf` is the starting pointer of the pre-computed buffer.
106+
/// - `handlers` is the starting pointer of the table of function pointers of `Handler` type. The
107+
/// pointer is typeless to avoid self-referential types.
108+
#[cfg(feature = "tco")]
109+
pub type Handler<F, CTX> = unsafe fn(
110+
interpreter: &InterpretedInstance<F, CTX>,
111+
exec_state: &mut VmExecState<F, GuestMemory, CTX>,
112+
) -> Result<(), ExecutionError>;
95113

96114
/// Trait for pure execution via a host interpreter. The trait methods provide the methods to
97115
/// pre-process the program code into function pointers which operate on `pre_compute` instruction
@@ -108,6 +126,20 @@ pub trait Executor<F> {
108126
) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
109127
where
110128
Ctx: ExecutionCtxTrait;
129+
130+
/// Returns a function pointer with tail call optimization. The handler function assumes that
131+
/// the pre-compute buffer it receives is the populated `data`.
132+
// NOTE: we could have used `pre_compute` above to populate `data`, but the implementations were
133+
// simpler to keep `handler` entirely separate from `pre_compute`.
134+
#[cfg(feature = "tco")]
135+
fn handler<Ctx>(
136+
&self,
137+
pc: u32,
138+
inst: &Instruction<F>,
139+
data: &mut [u8],
140+
) -> Result<Handler<F, Ctx>, StaticProgramError>
141+
where
142+
Ctx: ExecutionCtxTrait;
111143
}
112144

113145
/// Trait for metered execution via a host interpreter. The trait methods provide the methods to

crates/vm/src/arch/execution_mode/pure.rs

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,6 @@ impl ExecutionCtx {
1919
}
2020
}
2121

22-
impl Default for ExecutionCtx {
23-
fn default() -> Self {
24-
Self::new(None)
25-
}
26-
}
27-
2822
impl ExecutionCtxTrait for ExecutionCtx {
2923
#[inline(always)]
3024
fn on_memory_operation(&mut self, _address_space: u32, _ptr: u32, _size: u32) {}

0 commit comments

Comments
 (0)