Skip to content

Commit 7e5ef70

Browse files
committed
Remove unnecessary specializations, simplify_constants, and add fold_constants
1 parent 17b0a95 commit 7e5ef70

File tree

3 files changed

+22
-16
lines changed

3 files changed

+22
-16
lines changed

src/structural_transformation/utils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,8 @@ function find_eq_solvables!(state::TearingState, ieq, to_rm = Int[], coeffs = no
179179
a, b, islinear = linear_expansion(term, var)
180180
a, b = unwrap(a), unwrap(b)
181181
islinear || (all_int_vars = false; continue)
182+
a = ModelingToolkit.fold_constants(a)
183+
b = ModelingToolkit.fold_constants(b)
182184
if a isa Symbolic
183185
all_int_vars = false
184186
if !allow_symbolic

src/utils.jl

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ end
1919
function detime_dvs(op)
2020
if !istree(op)
2121
op
22-
elseif operation(op) isa Sym
22+
elseif issym(operation(op))
2323
Sym{Real}(nameof(operation(op)))
2424
else
2525
similarterm(op, operation(op), detime_dvs.(arguments(op)))
@@ -60,7 +60,7 @@ function states_to_sym(states::Set)
6060
elseif istree(O)
6161
op = operation(O)
6262
args = arguments(O)
63-
if op isa Sym
63+
if issym(op)
6464
O in states && return tosymbol(O)
6565
# dependent variables
6666
return build_expr(:call, Any[nameof(op); _states_to_sym.(args)])
@@ -511,7 +511,7 @@ function collect_constants(x)
511511
return constants
512512
end
513513

514-
function collect_constants!(constants, arr::AbstractArray{T}) where {T}
514+
function collect_constants!(constants, arr::AbstractArray)
515515
for el in arr
516516
collect_constants!(constants, el)
517517
end
@@ -526,8 +526,8 @@ collect_constants!(constants, x::Num) = collect_constants!(constants, unwrap(x))
526526
collect_constants!(constants, x::Real) = nothing
527527
collect_constants(n::Nothing) = Symbolics.Sym[]
528528

529-
function collect_constants!(constants, expr::Symbolics.Symbolic{T}) where {T}
530-
if expr isa Sym && isconstant(expr)
529+
function collect_constants!(constants, expr::Symbolics.Symbolic)
530+
if issym(expr) && isconstant(expr)
531531
push!(constants, expr)
532532
else
533533
evars = vars(expr)
@@ -542,8 +542,7 @@ function collect_constants!(constants, expr::Symbolics.Symbolic{T}) where {T}
542542
end
543543

544544
""" Replace symbolic constants with their literal values """
545-
function eliminate_constants(eqs::AbstractArray{<:Union{Equation, Symbolic}},
546-
cs::Vector{Sym})
545+
function eliminate_constants(eqs, cs)
547546
cmap = Dict(x => getdefault(x) for x in cs)
548547
return substitute(eqs, cmap)
549548
end
@@ -807,6 +806,17 @@ function jacobian_wrt_vars(pf::F, p, input_idxs, chunk::C) where {F, C}
807806
ForwardDiff.jacobian(p_closure, p_small, cfg, Val(false))
808807
end
809808

809+
function fold_constants(ex)
810+
if istree(ex)
811+
similarterm(ex, operations(ex), map(fold_constants, arguments(ex)),
812+
symtype(expr); metadata = metadata(expr))
813+
elseif issym(ex) && isconstant(ex)
814+
getdefault(ex)
815+
else
816+
ex
817+
end
818+
end
819+
810820
# Symbolics needs to call unwrap on the substitution rules, but most of the time
811821
# we don't want to do that in MTK.
812822
function fast_substitute(eq::Equation, subs)

test/constants.jl

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,9 @@ newsys = MT.eliminate_constants(sys)
1919
eqs = [D(x) ~ 1,
2020
w ~ a]
2121
@named sys = ODESystem(eqs)
22-
simp = structural_simplify(sys, simplify_constants = false);
23-
@test isequal(simp.substitutions.subs[1], eqs[2])
24-
@test isequal(equations(simp)[1], eqs[1])
25-
prob = ODEProblem(simp, [0], [0.0, 1.0], [])
26-
sol = solve(prob, Tsit5())
27-
@test sol[w][1] == 1
2822
# Now eliminate the constants first
29-
simp = structural_simplify(sys, simplify_constants = true);
30-
@test isequal(simp.substitutions.subs[1], w ~ 1)
23+
simp = structural_simplify(sys)
24+
@test equations(simp) == [D(x) ~ 1.0]
3125

3226
#Constant with units
3327
@constants β=1 [unit = u"m/s"]
@@ -43,4 +37,4 @@ sys = ODESystem(eqs, name = :sys)
4337
simp = structural_simplify(sys)
4438
@test_throws MT.ValidationError MT.check_units(simp.eqs...)
4539

46-
@test MT.collect_constants(nothing) == Symbolics.Sym[]
40+
@test isempty(MT.collect_constants(nothing))

0 commit comments

Comments
 (0)