Skip to content

Commit 391f683

Browse files
Merge pull request #709 from AayushSabharwal/as/cache-fix
fix: handle explicit `::Any` annotation in `@cache` macro
2 parents 30b0487 + d243b9f commit 391f683

File tree

3 files changed

+136
-20
lines changed

3 files changed

+136
-20
lines changed

src/cache.jl

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,14 @@ macro cache(args...)
229229
keytypes = []
230230
# The arguments of the workhorse function
231231
argexprs = []
232+
# The name of the variable storing the result of looking up the cache
233+
cache_value_name = :val
234+
# The condition for a cache hit
235+
cache_hit_condition = :(!($cache_value_name isa $CacheSentinel))
236+
# Type of additional data stored with cached result. Used to compare
237+
# equality of `BasicSymbolic` arguments, since `objectid` is a hash.
238+
cache_additional_types = []
239+
cache_additional_values = []
232240

233241
for arg in fn.args
234242
# handle arguments with defaults
@@ -240,18 +248,39 @@ macro cache(args...)
240248
push!(keyexprs, :($arg isa BasicSymbolic ? $SymbolicKey(objectid($arg)) : $arg))
241249
push!(argexprs, arg)
242250
push!(keytypes, Any)
251+
push!(cache_additional_types, Any)
252+
push!(cache_additional_values, arg)
253+
cache_hit_condition = :($cache_hit_condition && (!($arg isa BasicSymbolic) || $arg === $cache_value_name[$(length(cache_additional_values))]))
243254
continue
244255
end
245256
argname, Texpr = arg.args
246257
push!(argexprs, argname)
247258

259+
if Texpr == :Any
260+
# if the type is `Any`, branch on it being a `BasicSymbolic`
261+
push!(keyexprs, :($argname isa BasicSymbolic ? $SymbolicKey(objectid($argname)) : $argname))
262+
push!(keytypes, Any)
263+
push!(cache_additional_types, Any)
264+
push!(cache_additional_values, argname)
265+
cache_hit_condition = :($cache_hit_condition && (!($argname isa BasicSymbolic) || $argname === $cache_value_name[$(length(cache_additional_values))]))
266+
continue
267+
end
268+
248269
# handle Union types that may contain a `BasicSymbolic`
249270
if Meta.isexpr(Texpr, :curly) && Texpr.args[1] == :Union
250271
Texprs = Texpr.args[2:end]
251272
Ts = map(Base.Fix1(Base.eval, __module__), Texprs)
252273
keyTs = map(x -> x <: BasicSymbolic ? SymbolicKey : x, Ts)
274+
maybe_basicsymbolic = any(x -> x <: BasicSymbolic, Ts)
253275
push!(keytypes, Union{keyTs...})
254-
push!(keyexprs, :($argname isa BasicSymbolic ? $SymbolicKey(objectid($argname)) : $argname))
276+
if maybe_basicsymbolic
277+
push!(keyexprs, :($argname isa BasicSymbolic ? $SymbolicKey(objectid($argname)) : $argname))
278+
push!(cache_additional_types, Texpr)
279+
push!(cache_additional_values, argname)
280+
cache_hit_condition = :($cache_hit_condition && (!($argname isa BasicSymbolic) || $argname === $cache_value_name[$(length(cache_additional_values))]))
281+
else
282+
push!(keyexprs, argname)
283+
end
255284
continue
256285
end
257286

@@ -260,6 +289,9 @@ macro cache(args...)
260289
if T <: BasicSymbolic
261290
push!(keytypes, SymbolicKey)
262291
push!(keyexprs, :($SymbolicKey(objectid($argname))))
292+
push!(cache_additional_types, T)
293+
push!(cache_additional_values, argname)
294+
cache_hit_condition = :($cache_hit_condition && $argname === $cache_value_name[$(length(cache_additional_values))])
263295
else
264296
push!(keytypes, T)
265297
push!(keyexprs, argname)
@@ -280,8 +312,11 @@ macro cache(args...)
280312
# construct an expression for the type of the cache keys
281313
keyT = Expr(:curly, Tuple)
282314
append!(keyT.args, keytypes)
315+
valT = Expr(:curly, Tuple)
316+
append!(valT.args, cache_additional_types)
317+
push!(valT.args, rettype)
283318
# the type of the cache
284-
cacheT = :(Dict{$keyT, $rettype})
319+
cacheT = :(Dict{$keyT, $valT})
285320
# type of the `TaskLocalValue`
286321
tlvT = :($(TaskLocalValue){Tuple{$cacheT, $CacheStats}})
287322
# the name of the cache struct
@@ -324,11 +359,11 @@ macro cache(args...)
324359
# look it up
325360
# we use a custom sentinel value since `nothing` is a valid return value
326361
# which we might want to cache
327-
val = $(get)(cachedict, key, $(CacheSentinel)())
328-
if !(val isa $CacheSentinel)
362+
$cache_value_name = $(get)(cachedict, key, $(CacheSentinel)())
363+
if $cache_hit_condition
329364
# cache hit
330365
cachestats.hits += 1
331-
return val
366+
return $cache_value_name[end]
332367
end
333368
# cache miss
334369
cachestats.misses += 1
@@ -339,7 +374,7 @@ macro cache(args...)
339374
$(filter!)($cachename, cachedict)
340375
end
341376
# add to cache
342-
cachedict[key] = val
377+
cachedict[key] = ($(cache_additional_values...), val)
343378
return val
344379
end
345380

src/types.jl

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,14 @@ function SymbolicIndexingInterface.symbolic_type(::Type{<:BasicSymbolic})
6969
ScalarSymbolic()
7070
end
7171

72+
"""
73+
$(TYPEDSIGNATURES)
74+
75+
Return the inner `Symbolic` wrapped in a non-symbolic subtype. Defaults to
76+
returning the input as-is.
77+
"""
78+
unwrap(x) = x
79+
7280
function exprtype(x::BasicSymbolic)
7381
@compactified x::BasicSymbolic begin
7482
Term => TERM
@@ -538,10 +546,17 @@ function Sym{T}(name::Symbol; kw...) where {T}
538546
BasicSymbolic(s)
539547
end
540548

549+
function unwrap_arr!(arr)
550+
for i in eachindex(arr)
551+
arr[i] = unwrap(arr[i])
552+
end
553+
end
554+
541555
function Term{T}(f, args; kw...) where T
542556
if eltype(args) !== Any
543557
args = convert(Vector{Any}, args)
544558
end
559+
unwrap_arr!(args)
545560

546561
s = Term{T}(;f=f, arguments=args, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), kw...)
547562
BasicSymbolic(s)
@@ -551,7 +566,16 @@ function Term(f, args; metadata=NO_METADATA)
551566
Term{_promote_symtype(f, args)}(f, args, metadata=metadata)
552567
end
553568

569+
function unwrap_dict(dict)
570+
if any(k -> unwrap(k) !== k, keys(dict))
571+
return typeof(dict)(unwrap(k) => v for (k, v) in dict)
572+
end
573+
return dict
574+
end
575+
554576
function Add(::Type{T}, coeff, dict; metadata=NO_METADATA, kw...) where T
577+
coeff = unwrap(coeff)
578+
dict = unwrap_dict(dict)
555579
if isempty(dict)
556580
return coeff
557581
elseif _iszero(coeff) && length(dict) == 1
@@ -569,6 +593,8 @@ function Add(::Type{T}, coeff, dict; metadata=NO_METADATA, kw...) where T
569593
end
570594

571595
function Mul(T, a, b; metadata=NO_METADATA, kw...)
596+
a = unwrap(a)
597+
b = unwrap_dict(b)
572598
isempty(b) && return a
573599
if _isone(a) && length(b) == 1
574600
pair = first(b)
@@ -613,6 +639,8 @@ function maybe_intcoeff(x)
613639
end
614640

615641
function Div{T}(n, d, simplified=false; metadata=nothing, kwargs...) where {T}
642+
n = unwrap(n)
643+
d = unwrap(d)
616644
if T<:Number && !(T<:SafeReal)
617645
n, d = quick_cancel(n, d)
618646
end
@@ -662,6 +690,8 @@ end
662690
@inline denominators(x) = isdiv(x) ? numerators(x.den) : Any[1]
663691

664692
function Pow{T}(a, b; metadata=NO_METADATA, kwargs...) where {T}
693+
a = unwrap(a)
694+
b = unwrap(b)
665695
_iszero(b) && return 1
666696
_isone(b) && return a
667697
s = Pow{T}(; base=a, exp=b, arguments=[], metadata)
@@ -760,6 +790,7 @@ function makepow(a, b)
760790
end
761791

762792
function term(f, args...; type = nothing)
793+
args = map(unwrap, args)
763794
if type === nothing
764795
T = _promote_symtype(f, args)
765796
else
@@ -894,6 +925,7 @@ Base.ImmutableDict(d::ImmutableDict{K,V}, x, y) where {K, V} = ImmutableDict{K,
894925

895926
assocmeta(d::Dict, ctx, val) = (d=copy(d); d[ctx] = val; d)
896927
function assocmeta(d::Base.ImmutableDict, ctx, val)::ImmutableDict{DataType,Any}
928+
val = unwrap(val)
897929
# optimizations
898930
# If using upto 3 contexts, things stay compact
899931
if isdefined(d, :parent)
@@ -915,7 +947,7 @@ function setmetadata(s::Symbolic, ctx::DataType, val)
915947
@set s.metadata = assocmeta(s.metadata, ctx, val)
916948
else
917949
# fresh Dict
918-
@set s.metadata = Base.ImmutableDict{DataType, Any}(ctx, val)
950+
@set s.metadata = Base.ImmutableDict{DataType, Any}(ctx, unwrap(val))
919951
end
920952
end
921953

@@ -1333,6 +1365,10 @@ function +(a::SN, b::SN)
13331365
end
13341366

13351367
function +(a::Number, b::SN)
1368+
tmp = unwrap(a)
1369+
if tmp !== a
1370+
return tmp + b
1371+
end
13361372
!issafecanon(+, b) && return term(+, a, b) # Don't flatten if args have metadata
13371373
iszero(a) && return b
13381374
if isadd(b)
@@ -1399,6 +1435,10 @@ function *(a::SN, b::SN)
13991435
end
14001436

14011437
function *(a::Number, b::SN)
1438+
tmp = unwrap(a)
1439+
if tmp !== a
1440+
return tmp * b
1441+
end
14021442
!issafecanon(*, b) && return term(*, a, b)
14031443
if iszero(a)
14041444
a
@@ -1439,6 +1479,7 @@ end
14391479
###
14401480

14411481
function ^(a::SN, b)
1482+
b = unwrap(b)
14421483
!issafecanon(^, a,b) && return Pow(a, b)
14431484
if b isa Number && iszero(b)
14441485
# fast path

test/cache_macro.jl

Lines changed: 53 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ end
1313
@test isequal(val, 2x + 1)
1414
cachestruct = associated_cache(f1)
1515
cache, stats = cachestruct.tlv[]
16-
@test cache isa Dict{Tuple{SymbolicKey}, BasicSymbolic}
16+
@test cache isa Dict{Tuple{SymbolicKey}, Tuple{BasicSymbolic, BasicSymbolic}}
1717
@test length(cache) == 1
18-
@test cache[(SymbolicKey(objectid(x)),)] === val
18+
@test cache[(SymbolicKey(objectid(x)),)][end] === val
1919
@test stats.hits == 0
2020
@test stats.misses == 1
2121
f1(x)
@@ -75,9 +75,9 @@ end
7575
@test isequal(val, 2x + 1)
7676
cachestruct = associated_cache(f2)
7777
cache, stats = cachestruct.tlv[]
78-
@test cache isa Dict{Tuple{Union{SymbolicKey, UInt}}, Union{BasicSymbolic, UInt}}
78+
@test cache isa Dict{Tuple{Union{SymbolicKey, UInt}}, NTuple{2, Union{BasicSymbolic, UInt}}}
7979
@test length(cache) == 1
80-
@test cache[(SymbolicKey(objectid(x)),)] === val
80+
@test cache[(SymbolicKey(objectid(x)),)][end] === val
8181
@test stats.hits == 0
8282
@test stats.misses == 1
8383
f2(x)
@@ -88,7 +88,7 @@ end
8888
val = f2(y)
8989
@test val == 2y + 1
9090
@test length(cache) == 2
91-
@test cache[(y,)] == val
91+
@test cache[(y,)][end] == val
9292
@test stats.misses == 2
9393

9494
clear_cache!(f2)
@@ -100,27 +100,31 @@ end
100100
return 2x + 1
101101
end
102102

103-
@testset "::Any" begin
103+
@cache function f3_2(x::Any)::Union{BasicSymbolic, Int}
104+
return 2x + 1
105+
end
106+
107+
@testset "$name" for (name, fn) in [("implicit ::Any", f3), ("explicit ::Any", f3_2)]
104108
@syms x
105-
val = f3(x)
109+
val = fn(x)
106110
@test isequal(val, 2x + 1)
107-
cachestruct = associated_cache(f3)
111+
cachestruct = associated_cache(fn)
108112
cache, stats = cachestruct.tlv[]
109-
@test cache isa Dict{Tuple{Any}, Union{BasicSymbolic, Int}}
113+
@test cache isa Dict{Tuple{Any}, Tuple{Any, Union{BasicSymbolic, Int}}}
110114
@test length(cache) == 1
111-
@test cache[(SymbolicKey(objectid(x)),)] === val
115+
@test cache[(SymbolicKey(objectid(x)),)][end] === val
112116
@test stats.hits == 0
113117
@test stats.misses == 1
114-
f3(x)
118+
fn(x)
115119
@test stats.hits == 1
116120
@test stats.misses == 1
117121

118-
val = f3(3)
122+
val = fn(3)
119123
@test val == 7
120124
@test length(cache) == 2
121125
@test stats.misses == 2
122126

123-
clear_cache!(f3)
127+
clear_cache!(fn)
124128
@test length(cache) == 0
125129
@test stats.hits == stats.misses == stats.clears == 0
126130
end
@@ -155,3 +159,39 @@ end
155159
truevals = map(f4, exprs)
156160
@test isequal(result, truevals)
157161
end
162+
163+
@cache function f5(x::BasicSymbolic, y::Union{BasicSymbolic, Int}, z)::BasicSymbolic
164+
return x + y + z
165+
end
166+
167+
# temporary defintion to induce objectid collisions
168+
Base.objectid(x::BasicSymbolic) = 0x42
169+
170+
@testset "`objectid` collision handling" begin
171+
@syms x y z
172+
@test objectid(x) == objectid(y) == objectid(z) == 0x42
173+
cachestruct = associated_cache(f5)
174+
cache, stats = cachestruct.tlv[]
175+
val = f5(x, 1, 2)
176+
@test isequal(val, x + 3)
177+
@test length(cache) == 1
178+
@test stats.misses == 1
179+
val2 = f5(y, 1, 2)
180+
@test isequal(val2, y + 3)
181+
@test length(cache) == 1
182+
@test stats.misses == 2
183+
184+
clear_cache!(f5)
185+
val = f5(x, y, z)
186+
@test isequal(val, x + y + z)
187+
@test length(cache) == 1
188+
@test stats.misses == 1
189+
val2 = f5(y, 2z, x)
190+
@test isequal(val2, x + y + 2z)
191+
@test length(cache) == 1
192+
@test stats.misses == 2
193+
end
194+
195+
Base.delete_method(only(methods(objectid, @__MODULE__)))
196+
@syms x
197+
@test objectid(x) != 0x42

0 commit comments

Comments
 (0)