Skip to content

Commit 2813251

Browse files
Merge pull request #2992 from AayushSabharwal/as/bugs
fix: fix substitute duplicating equations
2 parents 677cf30 + d7c673a commit 2813251

File tree

3 files changed

+53
-9
lines changed

3 files changed

+53
-9
lines changed

src/systems/abstractsystem.jl

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2934,12 +2934,12 @@ function Symbolics.substitute(sys::AbstractSystem, rules::Union{Vector{<:Pair},
29342934
elseif sys isa ODESystem
29352935
rules = todict(map(r -> Symbolics.unwrap(r[1]) => Symbolics.unwrap(r[2]),
29362936
collect(rules)))
2937-
eqs = fast_substitute(equations(sys), rules)
2938-
pdeps = fast_substitute(parameter_dependencies(sys), rules)
2937+
eqs = fast_substitute(get_eqs(sys), rules)
2938+
pdeps = fast_substitute(get_parameter_dependencies(sys), rules)
29392939
defs = Dict(fast_substitute(k, rules) => fast_substitute(v, rules)
2940-
for (k, v) in defaults(sys))
2940+
for (k, v) in get_defaults(sys))
29412941
guess = Dict(fast_substitute(k, rules) => fast_substitute(v, rules)
2942-
for (k, v) in guesses(sys))
2942+
for (k, v) in get_guesses(sys))
29432943
subsys = map(s -> substitute(s, rules), get_systems(sys))
29442944
ODESystem(eqs, get_iv(sys); name = nameof(sys), defaults = defs,
29452945
guesses = guess, parameter_dependencies = pdeps, systems = subsys)
@@ -2948,14 +2948,34 @@ function Symbolics.substitute(sys::AbstractSystem, rules::Union{Vector{<:Pair},
29482948
end
29492949
end
29502950

2951+
struct InvalidParameterDependenciesType
2952+
got::Any
2953+
end
2954+
2955+
function Base.showerror(io::IO, err::InvalidParameterDependenciesType)
2956+
print(
2957+
io, "Parameter dependencies must be a `Dict`, or an array of `Pair` or `Equation`.")
2958+
if err.got !== nothing
2959+
print(io, " Got ", err.got)
2960+
end
2961+
end
2962+
29512963
function process_parameter_dependencies(pdeps, ps)
29522964
if pdeps === nothing || isempty(pdeps)
29532965
return Equation[], ps
2954-
elseif eltype(pdeps) <: Pair
2955-
pdeps = [lhs ~ rhs for (lhs, rhs) in pdeps]
29562966
end
2957-
if !(eltype(pdeps) <: Equation)
2958-
error("Parameter dependencies must be a `Dict`, `Vector{Pair}` or `Vector{Equation}`")
2967+
if pdeps isa Dict
2968+
pdeps = [k ~ v for (k, v) in pdeps]
2969+
else
2970+
pdeps isa AbstractArray || throw(InvalidParameterDependenciesType(pdeps))
2971+
pdeps = [if p isa Pair
2972+
p[1] ~ p[2]
2973+
elseif p isa Equation
2974+
p
2975+
else
2976+
error("Parameter dependencies must be a `Dict`, `Vector{Pair}` or `Vector{Equation}`")
2977+
end
2978+
for p in pdeps]
29592979
end
29602980
lhss = BasicSymbolic[]
29612981
for p in pdeps

test/dq_units.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,4 +223,4 @@ end
223223
@variables X(tt) [unit = u"L"]
224224
DD = Differential(tt)
225225
eqs = [DD(X) ~ p - d * X + d * X]
226-
@test ModelingToolkit.validate(eqs)
226+
@test ModelingToolkit.validate(eqs)

test/odesystem.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1331,3 +1331,27 @@ end
13311331
@test length(ModelingToolkit.guesses(sys2)) == 2
13321332
@test ModelingToolkit.guesses(sys2)[p3] == 2.0
13331333
end
1334+
1335+
@testset "Substituting with nested systems" begin
1336+
@parameters p1 p2
1337+
@variables x(t) y(t)
1338+
@named innersys = ODESystem([D(x) ~ y + p2], t; parameter_dependencies = [p2 ~ 2p1],
1339+
defaults = [p1 => 1.0, p2 => 2.0], guesses = [p1 => 2.0, p2 => 3.0])
1340+
@parameters p3 p4
1341+
@named outersys = ODESystem(
1342+
[D(innersys.y) ~ innersys.y + p4], t; parameter_dependencies = [p4 ~ 3p3],
1343+
defaults = [p3 => 3.0, p4 => 9.0], guesses = [p4 => 10.0], systems = [innersys])
1344+
@test_nowarn structural_simplify(outersys)
1345+
@parameters p5
1346+
sys2 = substitute(outersys, [p4 => p5])
1347+
@test_nowarn structural_simplify(sys2)
1348+
@test length(equations(sys2)) == 2
1349+
@test length(parameters(sys2)) == 2
1350+
@test length(full_parameters(sys2)) == 4
1351+
@test all(!isequal(p4), full_parameters(sys2))
1352+
@test any(isequal(p5), full_parameters(sys2))
1353+
@test length(ModelingToolkit.defaults(sys2)) == 4
1354+
@test ModelingToolkit.defaults(sys2)[p5] == 9.0
1355+
@test length(ModelingToolkit.guesses(sys2)) == 3
1356+
@test ModelingToolkit.guesses(sys2)[p5] == 10.0
1357+
end

0 commit comments

Comments
 (0)