1
1
use proc_macro:: TokenStream ;
2
- use quote:: { format_ident , quote} ;
2
+ use quote:: quote;
3
3
use syn:: { parse_macro_input, ItemFn } ;
4
4
5
+ use crate :: common:: {
6
+ build_generic_args, extract_f_and_ctx_types, handler_name_from_fn, returns_result_type,
7
+ } ;
8
+
5
9
/// Implementation of the TCO handler generation logic.
6
10
/// This is called from the proc macro attribute in lib.rs.
7
11
pub fn tco_impl ( item : TokenStream ) -> TokenStream {
@@ -13,16 +17,15 @@ pub fn tco_impl(item: TokenStream) -> TokenStream {
13
17
let generics = & input_fn. sig . generics ;
14
18
let where_clause = & generics. where_clause ;
15
19
20
+ // Check if function returns Result
21
+ let returns_result = returns_result_type ( & input_fn) ;
22
+
16
23
// Extract the first two generic type parameters (F and CTX)
17
24
let ( f_type, ctx_type) = extract_f_and_ctx_types ( generics) ;
25
+
18
26
// Derive new function name:
19
- // If original ends with `_impl`, replace with `_tco_handler`, else append suffix.
20
- let new_name_str = fn_name
21
- . to_string ( )
22
- . strip_suffix ( "_impl" )
23
- . map ( |base| format ! ( "{base}_tco_handler" ) )
24
- . unwrap_or_else ( || format ! ( "{fn_name}_tco_handler" ) ) ;
25
- let handler_name = format_ident ! ( "{}" , new_name_str) ;
27
+ // If original ends with `_impl`, replace with `_handler`, else append suffix.
28
+ let handler_name = handler_name_from_fn ( fn_name) ;
26
29
27
30
// Build the generic parameters for the handler, preserving all original generics
28
31
let handler_generics = generics. clone ( ) ;
@@ -35,6 +38,21 @@ pub fn tco_impl(item: TokenStream) -> TokenStream {
35
38
quote ! { #fn_name:: <#( #generic_args) , * >( pre_compute, & mut instret, & mut pc, arg, exec_state) }
36
39
} ;
37
40
41
+ // Generate the execute and exit check code based on return type
42
+ let execute_stmt = if returns_result {
43
+ quote ! {
44
+ // Call original impl and wire errors into exit_code.
45
+ let __ret = { #execute_call } ;
46
+ if let :: core:: result:: Result :: Err ( e) = __ret {
47
+ exec_state. set_instret_and_pc( instret, pc) ;
48
+ exec_state. exit_code = :: core:: result:: Result :: Err ( e) ;
49
+ return ;
50
+ }
51
+ }
52
+ } else {
53
+ quote ! { #execute_call; }
54
+ } ;
55
+
38
56
// Generate the TCO handler function
39
57
let handler_fn = quote ! {
40
58
#[ inline( never) ]
@@ -54,12 +72,8 @@ pub fn tco_impl(item: TokenStream) -> TokenStream {
54
72
use :: openvm_circuit:: arch:: ExecutionError ;
55
73
56
74
let pre_compute = interpreter. get_pre_compute( pc) ;
57
- #execute_call ;
75
+ #execute_stmt
58
76
59
- if :: core:: intrinsics:: unlikely( exec_state. exit_code. is_err( ) ) {
60
- exec_state. set_instret_and_pc( instret, pc) ;
61
- return ;
62
- }
63
77
if :: core:: intrinsics:: unlikely( #ctx_type:: should_suspend( instret, pc, arg, exec_state) ) {
64
78
exec_state. set_instret_and_pc( instret, pc) ;
65
79
return ;
@@ -89,45 +103,3 @@ pub fn tco_impl(item: TokenStream) -> TokenStream {
89
103
90
104
TokenStream :: from ( output)
91
105
}
92
-
93
- fn extract_f_and_ctx_types ( generics : & syn:: Generics ) -> ( syn:: Ident , syn:: Ident ) {
94
- let mut type_params = generics. params . iter ( ) . filter_map ( |param| {
95
- if let syn:: GenericParam :: Type ( type_param) = param {
96
- Some ( & type_param. ident )
97
- } else {
98
- None
99
- }
100
- } ) ;
101
-
102
- let f_type = type_params
103
- . next ( )
104
- . expect ( "Function must have at least one type parameter (F)" )
105
- . clone ( ) ;
106
- let ctx_type = type_params
107
- . next ( )
108
- . expect ( "Function must have at least two type parameters (F and CTX)" )
109
- . clone ( ) ;
110
-
111
- ( f_type, ctx_type)
112
- }
113
-
114
- fn build_generic_args ( generics : & syn:: Generics ) -> Vec < proc_macro2:: TokenStream > {
115
- generics
116
- . params
117
- . iter ( )
118
- . map ( |param| match param {
119
- syn:: GenericParam :: Type ( type_param) => {
120
- let ident = & type_param. ident ;
121
- quote ! { #ident }
122
- }
123
- syn:: GenericParam :: Lifetime ( lifetime) => {
124
- let lifetime = & lifetime. lifetime ;
125
- quote ! { #lifetime }
126
- }
127
- syn:: GenericParam :: Const ( const_param) => {
128
- let ident = & const_param. ident ;
129
- quote ! { #ident }
130
- }
131
- } )
132
- . collect ( )
133
- }
0 commit comments