Skip to content

Commit d61478b

Browse files
test: test @cache macro
1 parent 7e86beb commit d61478b

File tree

3 files changed

+161
-1
lines changed

3 files changed

+161
-1
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ ExproniconLite = "0.10.14"
5151
LabelledArrays = "1.5"
5252
MultivariatePolynomials = "0.5"
5353
NaNMath = "0.3, 1.1.2"
54+
OhMyThreads = "0.7"
5455
ReverseDiff = "1"
5556
Setfield = "0.7, 0.8, 1"
5657
SpecialFunctions = "0.10, 1.0, 2"
@@ -67,6 +68,7 @@ julia = "1.10"
6768
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
6869
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
6970
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
71+
OhMyThreads = "67456a42-1dca-4109-a031-0a68de7e3ad5"
7072
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
7173
PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d"
7274
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -77,4 +79,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
7779
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
7880

7981
[targets]
80-
test = ["BenchmarkTools", "Documenter", "LabelledArrays", "Pkg", "PkgBenchmark", "Random", "ReferenceTests", "ReverseDiff", "SafeTestsets", "Test", "Zygote"]
82+
test = ["BenchmarkTools", "Documenter", "LabelledArrays", "Pkg", "PkgBenchmark", "Random", "ReferenceTests", "ReverseDiff", "SafeTestsets", "Test", "Zygote", "OhMyThreads"]

test/cache_macro.jl

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
using SymbolicUtils
2+
using SymbolicUtils: BasicSymbolic, @cache, associated_cache, set_limit!, get_limit,
3+
clear_cache!, SymbolicKey, metadata, maketerm
4+
using OhMyThreads: tmap
5+
6+
@cache function f1(x::BasicSymbolic)::BasicSymbolic
7+
return 2x + 1
8+
end
9+
10+
@testset "::BasicSymbolic" begin
11+
@syms x
12+
val = f1(x)
13+
@test isequal(val, 2x + 1)
14+
cachestruct = associated_cache(f1)
15+
cache, stats = cachestruct.tlv[]
16+
@test cache isa Dict{Tuple{SymbolicKey}, BasicSymbolic}
17+
@test length(cache) == 1
18+
@test cache[(SymbolicKey(objectid(x)),)] === val
19+
@test stats.hits == 0
20+
@test stats.misses == 1
21+
f1(x)
22+
@test stats.hits == 1
23+
@test stats.misses == 1
24+
25+
xx = setmetadata(x, Int, 0)
26+
val = f1(xx)
27+
@test length(cache) == 2
28+
@test stats.misses == 2
29+
30+
set_limit!(f1, 10)
31+
@test get_limit(f1) == 10
32+
for i in 1:8
33+
xx = setmetadata(xx, Int, i)
34+
f1(xx)
35+
@test length(cache) == i + 2
36+
end
37+
xx = setmetadata(xx, Int, 9)
38+
f1(xx)
39+
@test length(cache) < 10
40+
@test stats.clears == 1
41+
42+
hits = stats.hits
43+
misses = stats.misses
44+
len = length(cache)
45+
46+
@syms x::Float64 # different symtype
47+
val = f1(x)
48+
@test length(cache) == len + 1
49+
@test stats.hits == hits
50+
@test stats.misses == misses + 1
51+
@test f1(x) === val
52+
@test stats.hits == hits + 1
53+
54+
clear_cache!(f1)
55+
@test length(cache) == 0
56+
stats = SymbolicUtils.get_stats(f1)
57+
@test stats.hits == stats.misses == stats.clears == 0
58+
SymbolicUtils.set_retain_fraction!(f1, 0.1)
59+
@test SymbolicUtils.get_retain_fraction(f1) == 0.1
60+
@test SymbolicUtils.is_caching_enabled(f1)
61+
SymbolicUtils.toggle_caching!(f1, false)
62+
@test !SymbolicUtils.is_caching_enabled(f1)
63+
f1(x)
64+
@test isempty(cache)
65+
@test stats.hits == stats.misses == stats.clears == 0
66+
end
67+
68+
@cache function f2(x::Union{BasicSymbolic, UInt})::Union{BasicSymbolic, UInt}
69+
return 2x + 1
70+
end
71+
72+
@testset "::Union (with `UInt`)" begin
73+
@syms x
74+
val = f2(x)
75+
@test isequal(val, 2x + 1)
76+
cachestruct = associated_cache(f2)
77+
cache, stats = cachestruct.tlv[]
78+
@test cache isa Dict{Tuple{Union{SymbolicKey, UInt}}, Union{BasicSymbolic, UInt}}
79+
@test length(cache) == 1
80+
@test cache[(SymbolicKey(objectid(x)),)] === val
81+
@test stats.hits == 0
82+
@test stats.misses == 1
83+
f2(x)
84+
@test stats.hits == 1
85+
@test stats.misses == 1
86+
87+
y = objectid(x)
88+
val = f2(y)
89+
@test val == 2y + 1
90+
@test length(cache) == 2
91+
@test cache[(y,)] == val
92+
@test stats.misses == 2
93+
94+
clear_cache!(f2)
95+
@test length(cache) == 0
96+
@test stats.hits == stats.misses == stats.clears == 0
97+
end
98+
99+
@cache function f3(x)::Union{BasicSymbolic, Int}
100+
return 2x + 1
101+
end
102+
103+
@testset "::Any" begin
104+
@syms x
105+
val = f3(x)
106+
@test isequal(val, 2x + 1)
107+
cachestruct = associated_cache(f3)
108+
cache, stats = cachestruct.tlv[]
109+
@test cache isa Dict{Tuple{Any}, Union{BasicSymbolic, Int}}
110+
@test length(cache) == 1
111+
@test cache[(SymbolicKey(objectid(x)),)] === val
112+
@test stats.hits == 0
113+
@test stats.misses == 1
114+
f3(x)
115+
@test stats.hits == 1
116+
@test stats.misses == 1
117+
118+
val = f3(3)
119+
@test val == 7
120+
@test length(cache) == 2
121+
@test stats.misses == 2
122+
123+
clear_cache!(f3)
124+
@test length(cache) == 0
125+
@test stats.hits == stats.misses == stats.clears == 0
126+
end
127+
128+
@cache function f4(x::Union{BasicSymbolic, Int})::Union{BasicSymbolic, Int}
129+
x isa Number && return x
130+
if iscall(x)
131+
return maketerm(typeof(x), operation(x), map(f4, arguments(x)), metadata(x))
132+
end
133+
return f3(x)
134+
end
135+
136+
@testset "Threading" begin
137+
@syms x y z
138+
@test isequal(f4(2x + 1), 2(2x + 1) + 1)
139+
140+
function build_rand_expr(vars, depth, maxdepth)
141+
if depth < maxdepth
142+
v = build_rand_expr(vars, depth + 1, maxdepth)
143+
else
144+
v = rand(vars)
145+
end
146+
if isodd(depth)
147+
return v + rand([1:3; vars])
148+
else
149+
return v * rand([1:3; vars])
150+
end
151+
end
152+
153+
exprs = [build_rand_expr([x, y, z], 0, 100) for _ in 1:1000]
154+
result = tmap(f4, exprs)
155+
truevals = map(f4, exprs)
156+
@test isequal(result, truevals)
157+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,6 @@ using Pkg, Test, SafeTestsets
1717
@safetestset "Fuzz" begin include("fuzz.jl") end
1818
@safetestset "Adjoints" begin include("adjoints.jl") end
1919
@safetestset "Hash Consing" begin include("hash_consing.jl") end
20+
@safetestset "Cache macro" begin include("cache_macro.jl") end
2021
end
2122
end

0 commit comments

Comments
 (0)