Skip to content

Commit fca9af7

Browse files
committed
Enhances the memoization mechanism by introducing a traceable operator stack to manage dependencies and cache invalidation.
- Replaces simple integer keys with function pointers for better traceability. - Adds `MemoOperator` enums to handle stack operations. - Adds functionality to remove dependent entries from the cache when necessary.
1 parent b248693 commit fca9af7

File tree

3 files changed

+69
-17
lines changed

3 files changed

+69
-17
lines changed

cache/src/lib.rs

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,52 @@
11
use lru::LruCache;
22
use std::{any::Any, num::NonZeroUsize, rc::Rc};
33

4+
pub enum Trace {
5+
Push,
6+
Pop,
7+
}
8+
49
pub enum MemoOperator {
5-
Memo,
10+
Memo(Trace),
11+
Pop,
612
}
713

14+
pub type OperatorFunc = fn(MemoOperator);
15+
816
const CACHE_CAP: usize = 128;
917

10-
static mut CACHE: Option<LruCache<usize, Rc<dyn Any>>> = None;
18+
static mut CACHE: Option<LruCache<OperatorFunc, Rc<dyn Any>>> = None;
1119

12-
fn cache() -> &'static mut LruCache<usize, Rc<dyn Any>> {
20+
static mut CALL_STACK: Option<Vec<OperatorFunc>> = None;
21+
22+
fn cache() -> &'static mut LruCache<OperatorFunc, Rc<dyn Any>> {
1323
#[allow(static_mut_refs)]
1424
unsafe {
1525
CACHE.get_or_insert_with(|| LruCache::new(NonZeroUsize::new(CACHE_CAP).unwrap()))
1626
}
1727
}
1828

19-
pub fn touch<T: 'static>(key: usize) -> Option<Rc<T>> {
29+
pub fn call_stack() -> &'static mut Vec<OperatorFunc> {
30+
#[allow(static_mut_refs)]
31+
unsafe {
32+
CALL_STACK.get_or_insert_with(|| Vec::new())
33+
}
34+
}
35+
36+
pub fn touch<T: 'static>(key: OperatorFunc) -> Option<Rc<T>> {
2037
cache()
2138
.get(&key)
2239
.map(Rc::clone)
2340
.filter(|rc| rc.is::<T>())
2441
.map(|rc| unsafe { Rc::from_raw(Rc::into_raw(rc) as *const T) })
2542
}
2643

27-
pub fn store_in_cache<T: 'static>(key: usize, val: T) -> Rc<T> {
44+
pub fn store_in_cache<T: 'static>(key: OperatorFunc, val: T) -> Rc<T> {
2845
let rc = Rc::new(val);
2946
cache().put(key, Rc::clone(&rc) as _);
3047
rc
3148
}
49+
50+
pub fn remove_from_cache(key: OperatorFunc) {
51+
cache().pop(&key);
52+
}

macros/src/lib.rs

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ pub fn memo(_attr: TokenStream, item: TokenStream) -> TokenStream {
99
let vis = &func.vis;
1010
let sig = &func.sig;
1111
let block = &func.block;
12+
let ident = &func.sig.ident;
1213

1314
let output_ty = match &sig.output {
1415
ReturnType::Type(_, ty) => ty.clone(),
@@ -28,31 +29,53 @@ pub fn memo(_attr: TokenStream, item: TokenStream) -> TokenStream {
2829
.into();
2930
}
3031

31-
let _ident = format_ident!("_{}", sig.ident);
32-
let mut _sig = sig.clone();
33-
_sig.ident = _ident.clone();
34-
_sig.inputs
32+
let op_ident = format_ident!("{}_op", ident);
33+
let mut op_sig = sig.clone();
34+
op_sig.ident = op_ident.clone();
35+
op_sig
36+
.inputs
3537
.insert(0, parse_quote! { op: cache::MemoOperator });
38+
op_sig.output = parse_quote! { -> () };
3639

3740
let expanded = quote! {
3841
#vis #sig
3942
where #output_ty: Clone + 'static
4043
{
41-
#_ident(cache::MemoOperator::Memo)
42-
}
44+
#op_ident(cache::MemoOperator::Memo(cache::Trace::Push));
4345

44-
#vis #_sig
45-
where #output_ty: Clone + 'static
46-
{
47-
let key = #_ident as usize;
46+
let key: cache::OperatorFunc = #op_ident;
4847
let rc = if let Some(rc) = cache::touch(key) {
4948
rc
5049
} else {
5150
let result: #output_ty = (|| #block)();
5251
cache::store_in_cache(key, result)
5352
};
53+
54+
#op_ident(cache::MemoOperator::Memo(cache::Trace::Pop));
55+
5456
(*rc).clone()
5557
}
58+
59+
#vis #op_sig
60+
{
61+
static mut dependents: Vec<cache::OperatorFunc> = Vec::new();
62+
match op {
63+
cache::MemoOperator::Memo(cache::Trace::Push) => {
64+
if let Some(last) = cache::call_stack().peek_mut() {
65+
unsafe { dependents.push(*last) };
66+
}
67+
cache::call_stack().push(#op_ident);
68+
},
69+
cache::MemoOperator::Memo(cache::Trace::Pop) => {
70+
cache::call_stack().pop();
71+
},
72+
cache::MemoOperator::Pop => {
73+
for dependent in unsafe { dependents.iter() } {
74+
cache::remove_from_cache(*dependent);
75+
}
76+
},
77+
}
78+
}
5679
};
5780

5881
expanded.into()

macros/tests/basic.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,22 @@
1+
#![feature(vec_peek_mut)]
2+
13
use macros::memo;
24

35
#[memo]
46
pub fn get_number() -> i32 {
5-
println!("Calculate i32");
7+
static mut INVOKED: bool = false;
8+
assert!(!unsafe { INVOKED });
9+
unsafe { INVOKED = true };
10+
611
42
712
}
813

914
#[memo]
1015
pub fn get_text() -> String {
11-
println!("Calculate String");
16+
static mut INVOKED: bool = false;
17+
assert!(!unsafe { INVOKED });
18+
unsafe { INVOKED = true };
19+
1220
"hello".to_string()
1321
}
1422

0 commit comments

Comments
 (0)