Skip to content

Commit 75ee54e

Browse files
committed
Introduces Memo
and `Signal` constructs to handle reactive patterns and dependency tracking. Simplified macro implementations to utilize the new caching mechanism.
1 parent 40fcf58 commit 75ee54e

File tree

7 files changed

+152
-63
lines changed

7 files changed

+152
-63
lines changed

cache/src/cache.rs

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,35 @@
11
use lru::LruCache;
22
use std::{any::Any, num::NonZeroUsize, rc::Rc};
33

4-
use crate::OperatorFunc;
4+
use crate::Observable;
5+
6+
type CacheKey = *const dyn Observable;
57

68
const CACHE_CAP: usize = 128;
79

8-
static mut CACHE: Option<LruCache<OperatorFunc, Rc<dyn Any>>> = None;
10+
static mut CACHE: Option<LruCache<CacheKey, Rc<dyn Any>>> = None;
911

10-
fn cache() -> &'static mut LruCache<OperatorFunc, Rc<dyn Any>> {
12+
fn cache() -> &'static mut LruCache<CacheKey, Rc<dyn Any>> {
1113
#[allow(static_mut_refs)]
1214
unsafe {
1315
CACHE.get_or_insert_with(|| LruCache::new(NonZeroUsize::new(CACHE_CAP).unwrap()))
1416
}
1517
}
1618

17-
pub fn touch<T: 'static>(key: OperatorFunc) -> Option<Rc<T>> {
19+
pub fn touch<T: 'static>(key: &'static dyn Observable) -> Option<Rc<T>> {
1820
cache()
19-
.get(&key)
21+
.get(&(key as _))
2022
.map(Rc::clone)
2123
.filter(|rc| rc.is::<T>())
2224
.map(|rc| unsafe { Rc::from_raw(Rc::into_raw(rc) as *const T) })
2325
}
2426

25-
pub fn store_in_cache<T: 'static>(key: OperatorFunc, val: T) -> Rc<T> {
27+
pub fn store_in_cache<T: 'static>(key: &'static dyn Observable, val: T) -> Rc<T> {
2628
let rc = Rc::new(val);
2729
cache().put(key, Rc::clone(&rc) as _);
2830
rc
2931
}
3032

31-
pub fn remove_from_cache(key: OperatorFunc) {
32-
cache().pop(&key);
33-
}
33+
pub fn remove_from_cache(key: &'static dyn Observable) {
34+
cache().pop(&(key as _));
35+
}

cache/src/call_stack.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,22 @@
1-
use crate::OperatorFunc;
1+
use crate::Observable;
22

3-
static mut CALL_STACK: Option<Vec<OperatorFunc>> = None;
3+
static mut CALL_STACK: Option<Vec<&'static dyn Observable>> = None;
44

5-
fn call_stack() -> &'static mut Vec<OperatorFunc> {
5+
fn call_stack() -> &'static mut Vec<&'static dyn Observable> {
66
#[allow(static_mut_refs)]
77
unsafe {
88
CALL_STACK.get_or_insert_with(Vec::new)
99
}
1010
}
1111

12-
pub fn push(op: OperatorFunc) {
12+
pub fn push(op: &'static dyn Observable) {
1313
call_stack().push(op)
1414
}
1515

16-
pub fn last() -> Option<&'static OperatorFunc> {
16+
pub fn last() -> Option<&'static &'static dyn Observable> {
1717
call_stack().last()
1818
}
1919

20-
pub fn pop() -> Option<OperatorFunc> {
20+
pub fn pop() -> Option<&'static dyn Observable> {
2121
call_stack().pop()
22-
}
22+
}

cache/src/lib.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,15 @@ pub enum MemoOperator {
88
Pop,
99
}
1010

11-
pub type OperatorFunc = fn(MemoOperator);
12-
1311
pub mod cache;
1412
pub mod call_stack;
13+
pub mod memo;
14+
pub mod signal;
1515

1616
pub use cache::{remove_from_cache, store_in_cache, touch};
17+
pub use memo::Memo;
18+
pub use signal::Signal;
19+
20+
pub trait Observable {
21+
fn invalidate(&'static self);
22+
}

cache/src/memo.rs

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
use crate::{Observable, call_stack, remove_from_cache, store_in_cache, touch};
2+
3+
pub struct Memo<T, F>
4+
where
5+
T: Clone,
6+
F: Fn() -> T,
7+
{
8+
f: F,
9+
dependents: Vec<&'static dyn Observable>,
10+
}
11+
12+
impl<T, F> Observable for Memo<T, F>
13+
where
14+
T: Clone,
15+
F: Fn() -> T,
16+
{
17+
fn invalidate(&'static self) {
18+
remove_from_cache(self);
19+
self.dependents
20+
.iter()
21+
.for_each(|dependent| dependent.invalidate());
22+
}
23+
}
24+
25+
impl<T, F> Memo<T, F>
26+
where
27+
T: Clone,
28+
F: Fn() -> T,
29+
{
30+
pub fn new(f: F) -> Self {
31+
Memo {
32+
f,
33+
dependents: vec![],
34+
}
35+
}
36+
37+
pub fn get(&'static mut self) -> T {
38+
if let Some(last) = call_stack::last() {
39+
self.dependents.push(*last);
40+
}
41+
call_stack::push(self);
42+
43+
let rc = if let Some(rc) = touch(self) {
44+
rc
45+
} else {
46+
let result: T = (self.f)();
47+
store_in_cache(self, result)
48+
};
49+
50+
call_stack::pop();
51+
52+
(*rc).clone()
53+
}
54+
}

cache/src/signal.rs

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
use crate::{Observable, call_stack};
2+
3+
pub struct Signal<T, F>
4+
where
5+
T: Eq + Clone,
6+
F: Fn() -> T,
7+
{
8+
value: T,
9+
f: F,
10+
dependents: Vec<&'static dyn Observable>,
11+
}
12+
13+
impl<T, F> Observable for Signal<T, F>
14+
where
15+
T: Eq + Clone,
16+
F: Fn() -> T,
17+
{
18+
fn invalidate(&'static self) {
19+
self.dependents
20+
.iter()
21+
.for_each(|dependent| dependent.invalidate());
22+
}
23+
}
24+
25+
impl<T, F> Signal<T, F>
26+
where
27+
T: Eq + Clone,
28+
F: Fn() -> T,
29+
{
30+
pub fn new(f: F) -> Self {
31+
let value = f();
32+
Signal {
33+
value,
34+
f,
35+
dependents: vec![],
36+
}
37+
}
38+
39+
pub fn get(&'static mut self) -> T {
40+
if let Some(last) = call_stack::last() {
41+
self.dependents.push(*last);
42+
}
43+
44+
let result: T = (self.f)();
45+
self.set(result.clone());
46+
47+
result
48+
}
49+
50+
pub fn set(&'static mut self, value: T) -> bool {
51+
if self.value == value {
52+
return false;
53+
}
54+
55+
self.value = value;
56+
57+
self.invalidate();
58+
59+
true
60+
}
61+
}

macros/src/lib.rs

Lines changed: 6 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use proc_macro::TokenStream;
22
use quote::{format_ident, quote};
3-
use syn::{ItemFn, ReturnType, parse_macro_input, parse_quote};
3+
use syn::{ItemFn, ReturnType, parse_macro_input};
44

55
#[proc_macro_attribute]
66
pub fn memo(_attr: TokenStream, item: TokenStream) -> TokenStream {
@@ -29,53 +29,16 @@ pub fn memo(_attr: TokenStream, item: TokenStream) -> TokenStream {
2929
.into();
3030
}
3131

32-
let op_ident = format_ident!("{}_op", ident);
33-
let mut op_sig = sig.clone();
34-
op_sig.ident = op_ident.clone();
35-
op_sig
36-
.inputs
37-
.insert(0, parse_quote! { op: cache::MemoOperator });
38-
op_sig.output = parse_quote! { -> () };
32+
let memo_ident = format_ident!("{}", ident.to_string().to_uppercase());
3933

4034
let expanded = quote! {
35+
static mut #memo_ident: Option<cache::Memo<#output_ty, fn() -> #output_ty>> = None;
36+
4137
#vis #sig
4238
where #output_ty: Clone + 'static
4339
{
44-
#op_ident(cache::MemoOperator::Memo(cache::Trace::Push));
45-
46-
let key: cache::OperatorFunc = #op_ident;
47-
let rc = if let Some(rc) = cache::touch(key) {
48-
rc
49-
} else {
50-
let result: #output_ty = (|| #block)();
51-
cache::store_in_cache(key, result)
52-
};
53-
54-
#op_ident(cache::MemoOperator::Memo(cache::Trace::Pop));
55-
56-
(*rc).clone()
57-
}
58-
59-
#vis #op_sig
60-
{
61-
static mut dependents: Vec<cache::OperatorFunc> = Vec::new();
62-
match op {
63-
cache::MemoOperator::Memo(cache::Trace::Push) => {
64-
if let Some(last) = cache::call_stack::last() {
65-
unsafe { dependents.push(last.clone()) };
66-
}
67-
cache::call_stack::push(#op_ident);
68-
},
69-
cache::MemoOperator::Memo(cache::Trace::Pop) => {
70-
cache::call_stack::pop();
71-
},
72-
cache::MemoOperator::Pop => {
73-
for dependent in unsafe { dependents.iter() } {
74-
cache::remove_from_cache(*dependent);
75-
dependent(cache::MemoOperator::Pop);
76-
}
77-
},
78-
}
40+
#[allow(static_mut_refs)]
41+
unsafe { &mut #memo_ident }.get_or_insert_with(|| cache::Memo::new(|| #block)).get()
7942
}
8043
};
8144

macros/tests/dep.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use cache::Observable as _;
12
use macros::memo;
23

34
static mut SOURCE_A_CALLED: bool = false;
@@ -70,8 +71,10 @@ fn complex_dependency_memo_test() {
7071
assert_eq!(d2, d1);
7172
assert_eq!(e2, e1);
7273

73-
cache::remove_from_cache(source_a_op);
74-
source_a_op(cache::MemoOperator::Pop);
74+
#[allow(static_mut_refs)]
75+
if let Some(memo) = unsafe { &SOURCE_A } {
76+
memo.invalidate();
77+
}
7578

7679
unsafe { SOURCE_A_CALLED = false };
7780
unsafe { SOURCE_C_CALLED = false };

0 commit comments

Comments
 (0)