|
1 | 1 | use proc_macro::TokenStream; |
2 | 2 | use quote::{format_ident, quote}; |
3 | | -use syn::{ItemFn, ReturnType, parse_macro_input, parse_quote}; |
| 3 | +use syn::{ItemFn, ReturnType, parse_macro_input}; |
4 | 4 |
|
5 | 5 | #[proc_macro_attribute] |
6 | 6 | pub fn memo(_attr: TokenStream, item: TokenStream) -> TokenStream { |
@@ -29,53 +29,16 @@ pub fn memo(_attr: TokenStream, item: TokenStream) -> TokenStream { |
29 | 29 | .into(); |
30 | 30 | } |
31 | 31 |
|
32 | | - let op_ident = format_ident!("{}_op", ident); |
33 | | - let mut op_sig = sig.clone(); |
34 | | - op_sig.ident = op_ident.clone(); |
35 | | - op_sig |
36 | | - .inputs |
37 | | - .insert(0, parse_quote! { op: cache::MemoOperator }); |
38 | | - op_sig.output = parse_quote! { -> () }; |
| 32 | + let memo_ident = format_ident!("{}", ident.to_string().to_uppercase()); |
39 | 33 |
|
40 | 34 | let expanded = quote! { |
| 35 | + static mut #memo_ident: Option<cache::Memo<#output_ty, fn() -> #output_ty>> = None; |
| 36 | + |
41 | 37 | #vis #sig |
42 | 38 | where #output_ty: Clone + 'static |
43 | 39 | { |
44 | | - #op_ident(cache::MemoOperator::Memo(cache::Trace::Push)); |
45 | | - |
46 | | - let key: cache::OperatorFunc = #op_ident; |
47 | | - let rc = if let Some(rc) = cache::touch(key) { |
48 | | - rc |
49 | | - } else { |
50 | | - let result: #output_ty = (|| #block)(); |
51 | | - cache::store_in_cache(key, result) |
52 | | - }; |
53 | | - |
54 | | - #op_ident(cache::MemoOperator::Memo(cache::Trace::Pop)); |
55 | | - |
56 | | - (*rc).clone() |
57 | | - } |
58 | | - |
59 | | - #vis #op_sig |
60 | | - { |
61 | | - static mut dependents: Vec<cache::OperatorFunc> = Vec::new(); |
62 | | - match op { |
63 | | - cache::MemoOperator::Memo(cache::Trace::Push) => { |
64 | | - if let Some(last) = cache::call_stack::last() { |
65 | | - unsafe { dependents.push(last.clone()) }; |
66 | | - } |
67 | | - cache::call_stack::push(#op_ident); |
68 | | - }, |
69 | | - cache::MemoOperator::Memo(cache::Trace::Pop) => { |
70 | | - cache::call_stack::pop(); |
71 | | - }, |
72 | | - cache::MemoOperator::Pop => { |
73 | | - for dependent in unsafe { dependents.iter() } { |
74 | | - cache::remove_from_cache(*dependent); |
75 | | - dependent(cache::MemoOperator::Pop); |
76 | | - } |
77 | | - }, |
78 | | - } |
| 40 | + #[allow(static_mut_refs)] |
| 41 | + unsafe { &mut #memo_ident }.get_or_insert_with(|| cache::Memo::new(|| #block)).get() |
79 | 42 | } |
80 | 43 | }; |
81 | 44 |
|
|
0 commit comments