Skip to content

Commit cf71d23

Browse files
Merge pull request #727 from AayushSabharwal/as/fix-cache
fix: fix keyword syntax for `@cache` options
2 parents bb0774d + 7ea4fad commit cf71d23

File tree

5 files changed

+103
-40
lines changed

5 files changed

+103
-40
lines changed

.github/workflows/Downstream.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,11 @@ jobs:
1818
julia-version: [1]
1919
os: [ubuntu-latest]
2020
package:
21-
- {user: SciML, repo: ModelingToolkit.jl, group: All}
21+
- {user: SciML, repo: ModelingToolkit.jl, group: InterfaceI}
22+
- {user: SciML, repo: ModelingToolkit.jl, group: InterfaceII}
23+
- {user: SciML, repo: ModelingToolkit.jl, group: Initialization}
24+
- {user: SciML, repo: ModelingToolkit.jl, group: SymbolicIndexingInterface}
25+
- {user: SciML, repo: ModelingToolkit.jl, group: FMI}
2226
- {user: SciML, repo: Catalyst.jl, group: All}
2327
- {user: SciML, repo: NeuralPDE.jl, group: NNPDE}
2428
- {user: SciML, repo: DataDrivenDiffEq.jl, group: Core}

src/cache.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ resets the stats.
102102
function clear_cache!(fn)
103103
dict, stats = associated_cache(fn).tlv[]
104104
empty!(dict)
105+
sizehint!(dict, get_limit(fn))
105106
reset_stats!(stats)
106107
end
107108

@@ -193,7 +194,7 @@ macro cache(args...)
193194
# parse configuration options
194195
config = Dict(:limit => 100_000, :retain_fraction => 0.5, :allow_any_return => false, :enabled => true)
195196
for carg in configargs
196-
if !Meta.isexpr(carg, :())
197+
if !Meta.isexpr(carg, :(=))
197198
throw(ArgumentError("Expected `key = value` syntax, got $carg"))
198199
end
199200
k, v = carg.args
@@ -339,7 +340,7 @@ macro cache(args...)
339340
end)
340341

341342
# instantiation of the TaskLocalValue
342-
tlvctor = :($tlvT(() -> ($cacheT(), $CacheStats())))
343+
tlvctor = :($tlvT(() -> ((dict = $cacheT(); sizehint!(dict, $get_limit($name)); dict), $CacheStats())))
343344
# instantiation expression for the constant value
344345
cachector = Expr(:call, structT, tlvctor, config[:limit], config[:retain_fraction], config[:enabled])
345346

src/small_array.jl

Lines changed: 62 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ GC'ed when removed.
3232
defaultval(::Type{T}) where {T <: Number} = zero(T)
3333
defaultval(::Type{Any}) = nothing
3434

35-
function Base.getindex(x::Backing, i::Int)
35+
Base.@propagate_inbounds function Base.getindex(x::Backing, i::Int)
3636
@boundscheck 1 <= i <= x.len
3737
if i == 1
3838
x.x1
@@ -43,7 +43,7 @@ function Base.getindex(x::Backing, i::Int)
4343
end
4444
end
4545

46-
function Base.setindex!(x::Backing, v, i::Int)
46+
Base.@propagate_inbounds function Base.setindex!(x::Backing, v, i::Int)
4747
@boundscheck 1 <= i <= x.len
4848
if i == 1
4949
setfield!(x, :x1, v)
@@ -54,20 +54,65 @@ function Base.setindex!(x::Backing, v, i::Int)
5454
end
5555
end
5656

57-
function Base.push!(x::Backing, v)
58-
x.len < 3 || throw(ArgumentError("`Backing` is full"))
57+
Base.@propagate_inbounds function Base.push!(x::Backing, v)
58+
@boundscheck x.len < 3
5959
x.len += 1
6060
x[x.len] = v
6161
end
6262

63-
function Base.pop!(x::Backing{T}) where {T}
64-
x.len > 0 || throw(ArgumentError("Array is empty"))
63+
Base.@propagate_inbounds function Base.pop!(x::Backing{T}) where {T}
64+
@boundscheck x.len > 0
6565
v = x[x.len]
6666
x[x.len] = defaultval(T)
6767
x.len -= 1
6868
v
6969
end
7070

71+
function Base.any(f::Function, x::Backing)
72+
if x.len == 0
73+
false
74+
elseif x.len == 1
75+
f(x.x1)
76+
elseif x.len == 2
77+
f(x.x1) || f(x.x2)
78+
elseif x.len == 3
79+
f(x.x1) || f(x.x2) || f(x.x3)
80+
end
81+
end
82+
83+
function Base.all(f::Function, x::Backing)
84+
if x.len == 0
85+
true
86+
elseif x.len == 1
87+
f(x.x1)
88+
elseif x.len == 2
89+
f(x.x1) && f(x.x2)
90+
elseif x.len == 3
91+
f(x.x1) && f(x.x2) && f(x.x3)
92+
end
93+
end
94+
95+
function Base.map(f, x::Backing{T}) where {T}
96+
if x.len == 0
97+
# StaticArrays does this, so we are only as bad as they are
98+
Backing{Core.Compiler.return_type(f, Tuple{T})}()
99+
elseif x.len == 1
100+
x1 = f(x.x1)
101+
Backing{typeof(x1)}(x1)
102+
elseif x.len == 2
103+
x1 = f(x.x1)
104+
x2 = f(x.x2)
105+
Backing{Base.promote_typejoin(typeof(x1), typeof(x2))}(x1, x2)
106+
elseif x.len == 3
107+
x1 = f(x.x1)
108+
x2 = f(x.x2)
109+
x3 = f(x.x3)
110+
_T = Base.promote_typejoin(typeof(x1), typeof(x2))
111+
_T = Base.promote_typejoin(_T, typeof(x3))
112+
Backing{_T}(x1, x2, x3)
113+
end
114+
end
115+
71116
"""
72117
$(TYPEDSIGNATURES)
73118
@@ -94,6 +139,10 @@ mutable struct SmallVec{T, V <: AbstractVector{T}} <: AbstractVector{T}
94139
end
95140
end
96141

142+
function SmallVec{T, V}(x::Backing{T}) where {T, V}
143+
new{T, V}(x)
144+
end
145+
97146
function SmallVec{T, V}() where {T, V}
98147
new{T, V}(Backing{T}())
99148
end
@@ -113,21 +162,24 @@ Base.convert(::Type{SmallVec{T, V}}, x::SmallVec{T, V}) where {T, V} = x
113162

114163
Base.size(x::SmallVec) = size(x.data)
115164
Base.isempty(x::SmallVec) = isempty(x.data)
116-
Base.getindex(x::SmallVec, i::Int) = x.data[i]
117-
Base.setindex!(x::SmallVec, v, i::Int) = setindex!(x.data, v, i)
165+
Base.@propagate_inbounds Base.getindex(x::SmallVec, i::Int) = x.data[i]
166+
Base.@propagate_inbounds Base.setindex!(x::SmallVec, v, i::Int) = setindex!(x.data, v, i)
118167

119-
function Base.push!(x::SmallVec{T, V}, v) where {T, V}
168+
Base.@propagate_inbounds function Base.push!(x::SmallVec{T, V}, v) where {T, V}
120169
buf = x.data
121170
buf isa Backing{T} || return push!(buf::V, v)
122171
isfull(buf) || return push!(buf::Backing{T}, v)
123172
x.data = V(buf)
124173
return push!(x.data::V, v)
125174
end
126175

127-
Base.pop!(x::SmallVec) = pop!(x.data)
176+
Base.@propagate_inbounds Base.pop!(x::SmallVec) = pop!(x.data)
128177

129178
function Base.sizehint!(x::SmallVec{T, V}, n; kwargs...) where {T, V}
130179
x.data isa Backing && return x
131180
sizehint!(x.data, n; kwargs...)
132181
x
133182
end
183+
184+
Base.any(f::Function, x::SmallVec) = any(f, x.data)
185+
Base.all(f::Function, x::SmallVec) = all(f, x.data)

src/types.jl

Lines changed: 24 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -332,14 +332,10 @@ function isequal_with_metadata(a::NamedTuple, b::NamedTuple)
332332
a === b && return true
333333
typeof(a) == typeof(b) || return false
334334

335-
for (k, v) in pairs(a)
336-
haskey(b, k) || return false
337-
isequal_with_metadata(v, b[k]) || return false
338-
end
339-
340-
for (k, v) in pairs(b)
341-
haskey(a, k) || return false
342-
isequal_with_metadata(v, a[k]) || return false
335+
# same type, so same keys and value types
336+
# either everything works or it fails and early exits
337+
for (av, bv) in zip(values(a), values(b))
338+
isequal_with_metadata(av, bv) || return false
343339
end
344340

345341
return true
@@ -350,27 +346,28 @@ function isequal_with_metadata(a::AbstractDict, b::AbstractDict)
350346
typeof(a) == typeof(b) || return false
351347
length(a) == length(b) || return false
352348

353-
akeys = collect(keys(a))
354-
avisited = falses(length(akeys))
355-
bkeys = collect(keys(b))
356-
bvisited = falses(length(bkeys))
349+
# they have same length, so either `b` has all the same keys
350+
# or this will fail. Can't use `get(b, k, nothing)` because if
351+
# `a[k] === nothing` it will result in a false positive.
352+
for (k, v) in a
353+
k2 = getkey(b, k, nothing)
354+
isequal_with_metadata(k, k2) || return false
355+
isequal_with_metadata(v, b[k2]) || return false
356+
end
357+
return true
358+
end
359+
360+
function isequal_with_metadata(a::Base.ImmutableDict, b::Base.ImmutableDict)
361+
a === b && return true
362+
typeof(a) == typeof(b) || return false
363+
length(a) == length(b) || return false
357364

358-
for k in akeys
359-
idx = findfirst(eachindex(bkeys)) do i
360-
!bvisited[i] && isequal_with_metadata(k, bkeys[i])
361-
end
362-
idx === nothing && return false
363-
bvisited[idx] = true
364-
isequal_with_metadata(a[k], b[bkeys[idx]]) || return false
365-
end
366-
for (j, k) in enumerate(bkeys)
367-
bvisited[j] && continue
368-
idx = findfirst(eachindex(akeys)) do i
369-
!avisited[i] && isequal_with_metadata(k, akeys[i])
365+
for (k, v) in a
366+
match = false
367+
for (k2, v2) in b
368+
match |= isequal_with_metadata(k, k2) && isequal_with_metadata(v, v2)
370369
end
371-
idx === nothing && return false
372-
avisited[idx] = true
373-
isequal_with_metadata(b[k], a[akeys[idx]]) || return false
370+
match || return false
374371
end
375372
return true
376373
end

test/cache_macro.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,3 +196,12 @@ end
196196
Base.delete_method(only(methods(objectid, @__MODULE__)))
197197
@syms x
198198
@test objectid(x) != 0x42
199+
200+
@cache limit = 10 retain_fraction = 0.1 function f6(x::BasicSymbolic, y::Union{BasicSymbolic, Int}, z)::BasicSymbolic
201+
return x + y + z
202+
end
203+
204+
@testset "Keyword argument syntax works" begin
205+
@test SymbolicUtils.get_limit(f6) == 10
206+
@test SymbolicUtils.get_retain_fraction(f6) 0.1
207+
end

0 commit comments

Comments
 (0)