Skip to content

Commit 38a1987

Browse files
committed
Fixes dependent cache handling in memoization
add complex_dependency_memo_test
1 parent e097bdf commit 38a1987

File tree

2 files changed

+90
-0
lines changed

2 files changed

+90
-0
lines changed

macros/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ pub fn memo(_attr: TokenStream, item: TokenStream) -> TokenStream {
7272
cache::MemoOperator::Pop => {
7373
for dependent in unsafe { dependents.iter() } {
7474
cache::remove_from_cache(*dependent);
75+
dependent(cache::MemoOperator::Pop);
7576
}
7677
},
7778
}

macros/tests/dep.rs

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
#![feature(vec_peek_mut)]
2+
3+
use macros::memo;
4+
5+
static mut SOURCE_A_CALLED: bool = false;
6+
static mut SOURCE_B_CALLED: bool = false;
7+
static mut SOURCE_C_CALLED: bool = false;
8+
static mut SOURCE_D_CALLED: bool = false;
9+
static mut SOURCE_E_CALLED: bool = false;
10+
11+
#[memo]
12+
pub fn source_a() -> i32 {
13+
assert!(!unsafe { SOURCE_A_CALLED });
14+
unsafe { SOURCE_A_CALLED = true };
15+
16+
10
17+
}
18+
19+
#[memo]
20+
pub fn source_b() -> i32 {
21+
assert!(!unsafe { SOURCE_B_CALLED });
22+
unsafe { SOURCE_B_CALLED = true };
23+
24+
5
25+
}
26+
27+
#[memo]
28+
pub fn derived_c() -> i32 {
29+
assert!(!unsafe { SOURCE_C_CALLED });
30+
unsafe { SOURCE_C_CALLED = true };
31+
32+
source_a() + source_b()
33+
}
34+
35+
#[memo]
36+
pub fn derived_d() -> i32 {
37+
assert!(!unsafe { SOURCE_D_CALLED });
38+
unsafe { SOURCE_D_CALLED = true };
39+
40+
derived_c() * 2
41+
}
42+
43+
#[memo]
44+
pub fn derived_e() -> i32 {
45+
assert!(!unsafe { SOURCE_E_CALLED });
46+
unsafe { SOURCE_E_CALLED = true };
47+
48+
source_b() - 3
49+
}
50+
51+
// source_a source_b
52+
// \ / \
53+
// derived_c derived_e
54+
// |
55+
// derived_d
56+
57+
#[test]
58+
fn complex_dependency_memo_test() {
59+
let e1 = derived_e();
60+
let d1 = derived_d();
61+
let c1 = derived_c();
62+
63+
assert_eq!(c1, 15); // 10 + 5
64+
assert_eq!(d1, 30); // 15 * 2
65+
assert_eq!(e1, 2); // 5 - 3
66+
67+
let e2 = derived_e();
68+
let d2 = derived_d();
69+
let c2 = derived_c();
70+
71+
assert_eq!(c2, c1);
72+
assert_eq!(d2, d1);
73+
assert_eq!(e2, e1);
74+
75+
cache::remove_from_cache(source_a_op);
76+
source_a_op(cache::MemoOperator::Pop);
77+
78+
unsafe { SOURCE_A_CALLED = false };
79+
unsafe { SOURCE_C_CALLED = false };
80+
unsafe { SOURCE_D_CALLED = false };
81+
82+
let e3 = derived_e();
83+
let d3 = derived_d();
84+
let c3 = derived_c();
85+
86+
assert_eq!(c3, 15);
87+
assert_eq!(d3, 30);
88+
assert_eq!(e3, e2);
89+
}

0 commit comments

Comments
 (0)