Skip to content

Commit 2856c67

Browse files
fix: handle collision of objectid in @cache macro
1 parent 0d6338e commit 2856c67

File tree

3 files changed

+112
-7
lines changed

3 files changed

+112
-7
lines changed

src/cache.jl

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,14 @@ macro cache(args...)
229229
keytypes = []
230230
# The arguments of the workhorse function
231231
argexprs = []
232+
# The name of the variable storing the result of looking up the cache
233+
cache_value_name = :val
234+
# The condition for a cache hit
235+
cache_hit_condition = :(!($cache_value_name isa $CacheSentinel))
236+
# Type of additional data stored with cached result. Used to compare
237+
# equality of `BasicSymbolic` arguments, since `objectid` is a hash.
238+
cache_additional_types = []
239+
cache_additional_values = []
232240

233241
for arg in fn.args
234242
# handle arguments with defaults
@@ -240,6 +248,9 @@ macro cache(args...)
240248
push!(keyexprs, :($arg isa BasicSymbolic ? $SymbolicKey(objectid($arg)) : $arg))
241249
push!(argexprs, arg)
242250
push!(keytypes, Any)
251+
push!(cache_additional_types, Any)
252+
push!(cache_additional_values, arg)
253+
cache_hit_condition = :($cache_hit_condition && (!($arg isa BasicSymbolic) || $arg === $cache_value_name[$(length(cache_additional_values))]))
243254
continue
244255
end
245256
argname, Texpr = arg.args
@@ -249,6 +260,9 @@ macro cache(args...)
249260
# if the type is `Any`, branch on it being a `BasicSymbolic`
250261
push!(keyexprs, :($argname isa BasicSymbolic ? $SymbolicKey(objectid($argname)) : $argname))
251262
push!(keytypes, Any)
263+
push!(cache_additional_types, Any)
264+
push!(cache_additional_values, argname)
265+
cache_hit_condition = :($cache_hit_condition && (!($argname isa BasicSymbolic) || $argname === $cache_value_name[$(length(cache_additional_values))]))
252266
continue
253267
end
254268

@@ -257,8 +271,16 @@ macro cache(args...)
257271
Texprs = Texpr.args[2:end]
258272
Ts = map(Base.Fix1(Base.eval, __module__), Texprs)
259273
keyTs = map(x -> x <: BasicSymbolic ? SymbolicKey : x, Ts)
274+
maybe_basicsymbolic = any(x -> x <: BasicSymbolic, Ts)
260275
push!(keytypes, Union{keyTs...})
261-
push!(keyexprs, :($argname isa BasicSymbolic ? $SymbolicKey(objectid($argname)) : $argname))
276+
if maybe_basicsymbolic
277+
push!(keyexprs, :($argname isa BasicSymbolic ? $SymbolicKey(objectid($argname)) : $argname))
278+
push!(cache_additional_types, Texpr)
279+
push!(cache_additional_values, argname)
280+
cache_hit_condition = :($cache_hit_condition && (!($argname isa BasicSymbolic) || $argname === $cache_value_name[$(length(cache_additional_values))]))
281+
else
282+
push!(keyexprs, argname)
283+
end
262284
continue
263285
end
264286

@@ -267,6 +289,9 @@ macro cache(args...)
267289
if T <: BasicSymbolic
268290
push!(keytypes, SymbolicKey)
269291
push!(keyexprs, :($SymbolicKey(objectid($argname))))
292+
push!(cache_additional_types, T)
293+
push!(cache_additional_values, argname)
294+
cache_hit_condition = :($cache_hit_condition && $argname === $cache_value_name[$(length(cache_additional_values))])
270295
else
271296
push!(keytypes, T)
272297
push!(keyexprs, argname)
@@ -287,8 +312,11 @@ macro cache(args...)
287312
# construct an expression for the type of the cache keys
288313
keyT = Expr(:curly, Tuple)
289314
append!(keyT.args, keytypes)
315+
valT = Expr(:curly, Tuple)
316+
append!(valT.args, cache_additional_types)
317+
push!(valT.args, rettype)
290318
# the type of the cache
291-
cacheT = :(Dict{$keyT, $rettype})
319+
cacheT = :(Dict{$keyT, $valT})
292320
# type of the `TaskLocalValue`
293321
tlvT = :($(TaskLocalValue){Tuple{$cacheT, $CacheStats}})
294322
# the name of the cache struct
@@ -331,11 +359,11 @@ macro cache(args...)
331359
# look it up
332360
# we use a custom sentinel value since `nothing` is a valid return value
333361
# which we might want to cache
334-
val = $(get)(cachedict, key, $(CacheSentinel)())
335-
if !(val isa $CacheSentinel)
362+
$cache_value_name = $(get)(cachedict, key, $(CacheSentinel)())
363+
if $cache_hit_condition
336364
# cache hit
337365
cachestats.hits += 1
338-
return val
366+
return $cache_value_name[end]
339367
end
340368
# cache miss
341369
cachestats.misses += 1
@@ -346,7 +374,7 @@ macro cache(args...)
346374
$(filter!)($cachename, cachedict)
347375
end
348376
# add to cache
349-
cachedict[key] = val
377+
cachedict[key] = ($(cache_additional_values...), val)
350378
return val
351379
end
352380

src/types.jl

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,14 @@ function SymbolicIndexingInterface.symbolic_type(::Type{<:BasicSymbolic})
6969
ScalarSymbolic()
7070
end
7171

72+
"""
73+
$(TYPEDSIGNATURES)
74+
75+
Return the inner `Symbolic` wrapped in a non-symbolic subtype. Defaults to
76+
returning the input as-is.
77+
"""
78+
unwrap(x) = x
79+
7280
function exprtype(x::BasicSymbolic)
7381
@compactified x::BasicSymbolic begin
7482
Term => TERM
@@ -538,10 +546,17 @@ function Sym{T}(name::Symbol; kw...) where {T}
538546
BasicSymbolic(s)
539547
end
540548

549+
function unwrap_arr!(arr)
550+
for i in eachindex(arr)
551+
arr[i] = unwrap(arr[i])
552+
end
553+
end
554+
541555
function Term{T}(f, args; kw...) where T
542556
if eltype(args) !== Any
543557
args = convert(Vector{Any}, args)
544558
end
559+
unwrap_arr!(args)
545560

546561
s = Term{T}(;f=f, arguments=args, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), kw...)
547562
BasicSymbolic(s)
@@ -551,7 +566,16 @@ function Term(f, args; metadata=NO_METADATA)
551566
Term{_promote_symtype(f, args)}(f, args, metadata=metadata)
552567
end
553568

569+
function unwrap_dict(dict)
570+
if any(k -> unwrap(k) !== k, keys(dict))
571+
return typeof(dict)(unwrap(k) => v for (k, v) in dict)
572+
end
573+
return dict
574+
end
575+
554576
function Add(::Type{T}, coeff, dict; metadata=NO_METADATA, kw...) where T
577+
coeff = unwrap(coeff)
578+
dict = unwrap_dict(dict)
555579
if isempty(dict)
556580
return coeff
557581
elseif _iszero(coeff) && length(dict) == 1
@@ -569,6 +593,8 @@ function Add(::Type{T}, coeff, dict; metadata=NO_METADATA, kw...) where T
569593
end
570594

571595
function Mul(T, a, b; metadata=NO_METADATA, kw...)
596+
a = unwrap(a)
597+
b = unwrap_dict(b)
572598
isempty(b) && return a
573599
if _isone(a) && length(b) == 1
574600
pair = first(b)
@@ -613,6 +639,8 @@ function maybe_intcoeff(x)
613639
end
614640

615641
function Div{T}(n, d, simplified=false; metadata=nothing, kwargs...) where {T}
642+
n = unwrap(n)
643+
d = unwrap(d)
616644
if T<:Number && !(T<:SafeReal)
617645
n, d = quick_cancel(n, d)
618646
end
@@ -662,6 +690,8 @@ end
662690
@inline denominators(x) = isdiv(x) ? numerators(x.den) : Any[1]
663691

664692
function Pow{T}(a, b; metadata=NO_METADATA, kwargs...) where {T}
693+
a = unwrap(a)
694+
b = unwrap(b)
665695
_iszero(b) && return 1
666696
_isone(b) && return a
667697
s = Pow{T}(; base=a, exp=b, arguments=[], metadata)
@@ -760,6 +790,7 @@ function makepow(a, b)
760790
end
761791

762792
function term(f, args...; type = nothing)
793+
args = map(unwrap, args)
763794
if type === nothing
764795
T = _promote_symtype(f, args)
765796
else
@@ -894,6 +925,7 @@ Base.ImmutableDict(d::ImmutableDict{K,V}, x, y) where {K, V} = ImmutableDict{K,
894925

895926
assocmeta(d::Dict, ctx, val) = (d=copy(d); d[ctx] = val; d)
896927
function assocmeta(d::Base.ImmutableDict, ctx, val)::ImmutableDict{DataType,Any}
928+
val = unwrap(val)
897929
# optimizations
898930
# If using upto 3 contexts, things stay compact
899931
if isdefined(d, :parent)
@@ -915,7 +947,7 @@ function setmetadata(s::Symbolic, ctx::DataType, val)
915947
@set s.metadata = assocmeta(s.metadata, ctx, val)
916948
else
917949
# fresh Dict
918-
@set s.metadata = Base.ImmutableDict{DataType, Any}(ctx, val)
950+
@set s.metadata = Base.ImmutableDict{DataType, Any}(ctx, unwrap(val))
919951
end
920952
end
921953

@@ -1333,6 +1365,10 @@ function +(a::SN, b::SN)
13331365
end
13341366

13351367
function +(a::Number, b::SN)
1368+
tmp = unwrap(a)
1369+
if tmp !== a
1370+
return tmp + b
1371+
end
13361372
!issafecanon(+, b) && return term(+, a, b) # Don't flatten if args have metadata
13371373
iszero(a) && return b
13381374
if isadd(b)
@@ -1399,6 +1435,10 @@ function *(a::SN, b::SN)
13991435
end
14001436

14011437
function *(a::Number, b::SN)
1438+
tmp = unwrap(a)
1439+
if tmp !== a
1440+
return tmp * b
1441+
end
14021442
!issafecanon(*, b) && return term(*, a, b)
14031443
if iszero(a)
14041444
a
@@ -1439,6 +1479,7 @@ end
14391479
###
14401480

14411481
function ^(a::SN, b)
1482+
b = unwrap(b)
14421483
!issafecanon(^, a,b) && return Pow(a, b)
14431484
if b isa Number && iszero(b)
14441485
# fast path

test/cache_macro.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,3 +159,39 @@ end
159159
truevals = map(f4, exprs)
160160
@test isequal(result, truevals)
161161
end
162+
163+
@cache function f5(x::BasicSymbolic, y::Union{BasicSymbolic, Int}, z)::BasicSymbolic
164+
return x + y + z
165+
end
166+
167+
# temporary defintion to induce objectid collisions
168+
Base.objectid(x::BasicSymbolic) = 0x42
169+
170+
@testset "`objectid` collision handling" begin
171+
@syms x y z
172+
@test objectid(x) == objectid(y) == objectid(z) == 0x42
173+
cachestruct = associated_cache(f5)
174+
cache, stats = cachestruct.tlv[]
175+
val = f5(x, 1, 2)
176+
@test isequal(val, x + 3)
177+
@test length(cache) == 1
178+
@test stats.misses == 1
179+
val2 = f5(y, 1, 2)
180+
@test isequal(val2, y + 3)
181+
@test length(cache) == 1
182+
@test stats.misses == 2
183+
184+
clear_cache!(f5)
185+
val = f5(x, y, z)
186+
@test isequal(val, x + y + z)
187+
@test length(cache) == 1
188+
@test stats.misses == 1
189+
val2 = f5(y, 2z, x)
190+
@test isequal(val2, x + y + 2z)
191+
@test length(cache) == 1
192+
@test stats.misses == 2
193+
end
194+
195+
Base.delete_method(only(methods(objectid, @__MODULE__)))
196+
@syms x
197+
@test objectid(x) != 0x42

0 commit comments

Comments
 (0)