Skip to content

Commit 7e86beb

Browse files
fix: handle case when argument to cached function can be UInt
1 parent 140b617 commit 7e86beb

File tree

1 file changed

+16
-7
lines changed

1 file changed

+16
-7
lines changed

src/cache.jl

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,17 @@ Sentinel value used for a cache miss, since cached functions may return `nothing
55
"""
66
struct CacheSentinel end
77

8+
"""
9+
$(TYPEDEF)
10+
11+
Struct wrapping the `objectid` of a `BasicSymbolic`, since arguments annotated
12+
`::Union{BasicSymbolic, UInt}` would not be able to differentiate between looking
13+
up a symbolic or a `UInt`.
14+
"""
15+
struct SymbolicKey
16+
id::UInt
17+
end
18+
819
"""
920
associated_cache(fn)
1021
@@ -226,7 +237,7 @@ macro cache(args...)
226237
end
227238
if !Meta.isexpr(arg, :(::))
228239
# if the type is `Any`, branch on it being a `BasicSymbolic`
229-
push!(keyexprs, :($arg isa BasicSymbolic ? objectid($arg) : $arg))
240+
push!(keyexprs, :($arg isa BasicSymbolic ? $SymbolicKey(objectid($arg)) : $arg))
230241
push!(argexprs, arg)
231242
push!(keytypes, Any)
232243
continue
@@ -238,19 +249,17 @@ macro cache(args...)
238249
if Meta.isexpr(Texpr, :curly) && Texpr.args[1] == :Union
239250
Texprs = Texpr.args[2:end]
240251
Ts = map(Base.Fix1(Base.eval, __module__), Texprs)
241-
keyTs = map(x -> x <: BasicSymbolic ? UInt64 : x, Ts)
242-
if any(x -> x <: BasicSymbolic, Ts)
243-
end
252+
keyTs = map(x -> x <: BasicSymbolic ? SymbolicKey : x, Ts)
244253
push!(keytypes, Union{keyTs...})
245-
push!(keyexprs, :($argname isa BasicSymbolic ? objectid($argname) : $argname))
254+
push!(keyexprs, :($argname isa BasicSymbolic ? $SymbolicKey(objectid($argname)) : $argname))
246255
continue
247256
end
248257

249258
# use `eval` to get the type because we need to know if it's a `BasicSymbolic`
250259
T = Base.eval(__module__, Texpr)
251260
if T <: BasicSymbolic
252-
push!(keytypes, UInt64)
253-
push!(keyexprs, :(objectid($argname)))
261+
push!(keytypes, SymbolicKey)
262+
push!(keyexprs, :($SymbolicKey(objectid($argname))))
254263
else
255264
push!(keytypes, T)
256265
push!(keyexprs, argname)

0 commit comments

Comments
 (0)