Skip to content

Commit 1ef69c6

Browse files
Merge pull request #711 from AayushSabharwal/as/fix-sorted-arguments
fix: fix caching breakages
2 parents 0c59cc9 + 777367b commit 1ef69c6

File tree

5 files changed

+19
-26
lines changed

5 files changed

+19
-26
lines changed

.github/workflows/Downstream.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ jobs:
2323
- {user: SciML, repo: NeuralPDE.jl, group: NNPDE}
2424
- {user: SciML, repo: DataDrivenDiffEq.jl, group: Core}
2525
- {user: JuliaSymbolics, repo: Symbolics.jl, group: All}
26+
- {user: JuliaSymbolics, repo: Symbolics.jl, group: Downstream}
2627
- {user: SciML, repo: ModelOrderReduction.jl, group: All}
2728

2829
steps:

src/SymbolicUtils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ using WeakValueDicts: WeakValueDict
2424
import ExproniconLite as EL
2525
import TaskLocalValues: TaskLocalValue
2626

27+
include("cache.jl")
2728
Base.@deprecate istree iscall
2829
export istree, operation, arguments, sorted_arguments, iscall
2930
# Sym, Term,
@@ -76,6 +77,5 @@ include("code.jl")
7677
# Adjoints
7778
include("adjoints.jl")
7879

79-
include("cache.jl")
8080

8181
end # module

src/cache.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ macro cache(args...)
215215
fn = EL.JLFunction(fnexpr)
216216
name = fn.name
217217
# this will now be an inner workhorse function, which the cached function will call
218-
fn.name = gensym(name)
218+
fn.name = gensym(Symbol(name))
219219
# the name of the global constant cache
220220
cachename = Symbol("cacheof($name)")
221221
# conditions for performing caching. At the very least, need hashconsing enabled and

src/types.jl

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ sdict(kv...) = Dict{Any, Any}(kv...)
1818
using Base: RefValue
1919
const EMPTY_ARGS = []
2020
const EMPTY_HASH = RefValue(UInt(0))
21-
const NOT_SORTED = RefValue(false)
2221
const EMPTY_DICT = sdict()
2322
const EMPTY_DICT_T = typeof(EMPTY_DICT)
2423
const ENABLE_HASHCONSING = Ref(true)
@@ -42,15 +41,13 @@ const ENABLE_HASHCONSING = Ref(true)
4241
hash::RefValue{UInt} = EMPTY_HASH
4342
hash2::RefValue{UInt} = EMPTY_HASH
4443
arguments::Vector{Any} = EMPTY_ARGS
45-
issorted::RefValue{Bool} = NOT_SORTED
4644
end
4745
mutable struct Add{T} <: BasicSymbolic{T}
4846
coeff::Any = 0 # exp/den if Pow
4947
dict::EMPTY_DICT_T = EMPTY_DICT
5048
hash::RefValue{UInt} = EMPTY_HASH
5149
hash2::RefValue{UInt} = EMPTY_HASH
5250
arguments::Vector{Any} = EMPTY_ARGS
53-
issorted::RefValue{Bool} = NOT_SORTED
5451
end
5552
mutable struct Div{T} <: BasicSymbolic{T}
5653
num::Any = 1
@@ -150,25 +147,19 @@ end
150147

151148
@inline head(x::BasicSymbolic) = operation(x)
152149

153-
function TermInterface.sorted_arguments(x::BasicSymbolic)
154-
args = arguments(x)
150+
@cache function TermInterface.sorted_arguments(x::BasicSymbolic)::Vector{Any}
151+
args = copy(arguments(x))
155152
@compactified x::BasicSymbolic begin
156153
Add => @goto ADD
157154
Mul => @goto MUL
158155
_ => return args
159156
end
160157
@label MUL
161-
if !x.issorted[]
162-
sort!(args, by=get_degrees)
163-
x.issorted[] = true
164-
end
158+
sort!(args, by=get_degrees)
165159
return args
166160

167161
@label ADD
168-
if !x.issorted[]
169-
sort!(args, lt = monomial_lt, by=get_degrees)
170-
x.issorted[] = true
171-
end
162+
sort!(args, lt = monomial_lt, by=get_degrees)
172163
return args
173164
end
174165

@@ -603,7 +594,7 @@ function Add(::Type{T}, coeff, dict; metadata=NO_METADATA, kw...) where T
603594
end
604595
end
605596

606-
s = Add{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), metadata, arguments=[], issorted=RefValue(false), kw...)
597+
s = Add{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), metadata, arguments=[], kw...)
607598
BasicSymbolic(s)
608599
end
609600

@@ -621,7 +612,7 @@ function Mul(T, a, b; metadata=NO_METADATA, kw...)
621612
else
622613
coeff = a
623614
dict = b
624-
s = Mul{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), metadata, arguments=[], issorted=RefValue(false), kw...)
615+
s = Mul{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), metadata, arguments=[], kw...)
625616
BasicSymbolic(s)
626617
end
627618
end
@@ -642,7 +633,7 @@ ratio(x::Rat,y::Rat) = x//y
642633
function maybe_intcoeff(x)
643634
if ismul(x)
644635
if x.coeff isa Rational && isone(x.coeff.den)
645-
Mul{symtype(x)}(; coeff=x.coeff.num, dict=x.dict, x.metadata, arguments=[], issorted=RefValue(false))
636+
Mul{symtype(x)}(; coeff=x.coeff.num, dict=x.dict, x.metadata, arguments=[])
646637
else
647638
x
648639
end
@@ -1011,7 +1002,7 @@ end
10111002
function remove_minus(t)
10121003
!iscall(t) && return -t
10131004
@assert operation(t) == (*)
1014-
args = arguments(t)
1005+
args = sorted_arguments(t)
10151006
@assert args[1] < 0
10161007
Any[-args[1], args[2:end]...]
10171008
end

test/cache_macro.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ using SymbolicUtils
22
using SymbolicUtils: BasicSymbolic, @cache, associated_cache, set_limit!, get_limit,
33
clear_cache!, SymbolicKey, metadata, maketerm
44
using OhMyThreads: tmap
5+
using Random
56

67
@cache function f1(x::BasicSymbolic)::BasicSymbolic
78
return 2x + 1
@@ -29,6 +30,8 @@ end
2930

3031
set_limit!(f1, 10)
3132
@test get_limit(f1) == 10
33+
SymbolicUtils.set_retain_fraction!(f1, 0.1)
34+
@test SymbolicUtils.get_retain_fraction(f1) == 0.1
3235
for i in 1:8
3336
xx = setmetadata(xx, Int, i)
3437
f1(xx)
@@ -55,8 +58,6 @@ end
5558
@test length(cache) == 0
5659
stats = SymbolicUtils.get_stats(f1)
5760
@test stats.hits == stats.misses == stats.clears == 0
58-
SymbolicUtils.set_retain_fraction!(f1, 0.1)
59-
@test SymbolicUtils.get_retain_fraction(f1) == 0.1
6061
@test SymbolicUtils.is_caching_enabled(f1)
6162
SymbolicUtils.toggle_caching!(f1, false)
6263
@test !SymbolicUtils.is_caching_enabled(f1)
@@ -129,7 +130,7 @@ end
129130
@test stats.hits == stats.misses == stats.clears == 0
130131
end
131132

132-
@cache function f4(x::Union{BasicSymbolic, Int})::Union{BasicSymbolic, Int}
133+
@cache function f4(x::Union{BasicSymbolic, Number})::Union{BasicSymbolic, Number}
133134
x isa Number && return x
134135
if iscall(x)
135136
return maketerm(typeof(x), operation(x), map(f4, arguments(x)), metadata(x))
@@ -140,17 +141,17 @@ end
140141
@testset "Threading" begin
141142
@syms x y z
142143
@test isequal(f4(2x + 1), 2(2x + 1) + 1)
143-
144+
rng = Xoshiro(3)
144145
function build_rand_expr(vars, depth, maxdepth)
145146
if depth < maxdepth
146147
v = build_rand_expr(vars, depth + 1, maxdepth)
147148
else
148-
v = rand(vars)
149+
v = rand(rng, vars)
149150
end
150151
if isodd(depth)
151-
return v + rand([1:3; vars])
152+
return v + rand(rng, [1:3; vars])
152153
else
153-
return v * rand([1:3; vars])
154+
return v * rand(rng, [1:3; vars])
154155
end
155156
end
156157

0 commit comments

Comments
 (0)