Skip to content

Commit 3fe0d7e

Browse files
Merge pull request #2984 from AayushSabharwal/as/discover-pdeps
fix: discover parameters from parameter dependencies, fix substitute
2 parents 268667f + 9b53b9f commit 3fe0d7e

File tree

7 files changed

+65
-2
lines changed

7 files changed

+65
-2
lines changed

src/systems/abstractsystem.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2940,7 +2940,14 @@ function Symbolics.substitute(sys::AbstractSystem, rules::Union{Vector{<:Pair},
29402940
rules = todict(map(r -> Symbolics.unwrap(r[1]) => Symbolics.unwrap(r[2]),
29412941
collect(rules)))
29422942
eqs = fast_substitute(equations(sys), rules)
2943-
ODESystem(eqs, get_iv(sys); name = nameof(sys))
2943+
pdeps = fast_substitute(parameter_dependencies(sys), rules)
2944+
defs = Dict(fast_substitute(k, rules) => fast_substitute(v, rules)
2945+
for (k, v) in defaults(sys))
2946+
guess = Dict(fast_substitute(k, rules) => fast_substitute(v, rules)
2947+
for (k, v) in guesses(sys))
2948+
subsys = map(s -> substitute(s, rules), get_systems(sys))
2949+
ODESystem(eqs, get_iv(sys); name = nameof(sys), defaults = defs,
2950+
guesses = guess, parameter_dependencies = pdeps, systems = subsys)
29442951
else
29452952
error("substituting symbols is not supported for $(typeof(sys))")
29462953
end

src/systems/diffeqs/odesystem.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,15 @@ function ODESystem(eqs, iv; kwargs...)
311311
push!(algeeq, eq)
312312
end
313313
end
314+
for eq in get(kwargs, :parameter_dependencies, Equation[])
315+
if eq isa Pair
316+
collect_vars!(allunknowns, ps, eq[1], iv)
317+
collect_vars!(allunknowns, ps, eq[2], iv)
318+
else
319+
collect_vars!(allunknowns, ps, eq.lhs, iv)
320+
collect_vars!(allunknowns, ps, eq.rhs, iv)
321+
end
322+
end
314323
for v in allunknowns
315324
isdelay(v, iv) || continue
316325
collect_vars!(allunknowns, ps, arguments(v)[1], iv)

src/systems/discrete_system/discrete_system.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,15 @@ function DiscreteSystem(eqs, iv; kwargs...)
183183
push!(diffvars, eq.lhs)
184184
end
185185
end
186+
for eq in get(kwargs, :parameter_dependencies, Equation[])
187+
if eq isa Pair
188+
collect_vars!(allunknowns, ps, eq[1], iv)
189+
collect_vars!(allunknowns, ps, eq[2], iv)
190+
else
191+
collect_vars!(allunknowns, ps, eq.lhs, iv)
192+
collect_vars!(allunknowns, ps, eq.rhs, iv)
193+
end
194+
end
186195
new_ps = OrderedSet()
187196
for p in ps
188197
if iscall(p) && operation(p) === getindex

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,15 @@ function NonlinearSystem(eqs; kwargs...)
170170
collect_vars!(allunknowns, ps, eq.lhs, nothing)
171171
collect_vars!(allunknowns, ps, eq.rhs, nothing)
172172
end
173+
for eq in get(kwargs, :parameter_dependencies, Equation[])
174+
if eq isa Pair
175+
collect_vars!(allunknowns, ps, eq[1], nothing)
176+
collect_vars!(allunknowns, ps, eq[2], nothing)
177+
else
178+
collect_vars!(allunknowns, ps, eq.lhs, nothing)
179+
collect_vars!(allunknowns, ps, eq.rhs, nothing)
180+
end
181+
end
173182
new_ps = OrderedSet()
174183
for p in ps
175184
if iscall(p) && operation(p) === getindex

test/odesystem.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1315,3 +1315,19 @@ end
13151315
@named sys = compose(sys, sys) # nest into a hierarchical system
13161316
@test t === sys.t === sys.sys.t
13171317
end
1318+
1319+
@testset "Substituting preserves parameter dependencies, defaults, guesses" begin
1320+
@parameters p1 p2
1321+
@variables x(t) y(t)
1322+
@named sys = ODESystem([D(x) ~ y + p2], t; parameter_dependencies = [p2 ~ 2p1],
1323+
defaults = [p1 => 1.0, p2 => 2.0], guesses = [p1 => 2.0, p2 => 3.0])
1324+
@parameters p3
1325+
sys2 = substitute(sys, [p1 => p3])
1326+
@test length(parameters(sys2)) == 1
1327+
@test is_parameter(sys2, p3)
1328+
@test !is_parameter(sys2, p1)
1329+
@test length(ModelingToolkit.defaults(sys2)) == 2
1330+
@test ModelingToolkit.defaults(sys2)[p3] == 1.0
1331+
@test length(ModelingToolkit.guesses(sys2)) == 2
1332+
@test ModelingToolkit.guesses(sys2)[p3] == 2.0
1333+
end

test/parameter_dependencies.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,3 +325,16 @@ end
325325
@test getp(sys, p1)(ps2) == 2.0
326326
@test getp(sys, p2)(ps2) == 4.0
327327
end
328+
329+
@testset "Discovery of parameters from dependencies" begin
330+
@parameters p1 p2
331+
@variables x(t) y(t)
332+
@named sys = ODESystem([D(x) ~ y + p2], t; parameter_dependencies = [p2 ~ 2p1])
333+
@test is_parameter(sys, p1)
334+
@named sys = NonlinearSystem([x * y^2 ~ y + p2]; parameter_dependencies = [p2 ~ 2p1])
335+
@test is_parameter(sys, p1)
336+
k = ShiftIndex(t)
337+
@named sys = DiscreteSystem(
338+
[x(k - 1) ~ x(k) + y(k) + p2], t; parameter_dependencies = [p2 ~ 2p1])
339+
@test is_parameter(sys, p1)
340+
end

test/split_parameters.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ sol = solve(prob, ImplicitEuler());
117117

118118
# ------------------------ Mixed Type Conserved
119119

120-
prob = ODEProblem(sys, [], tspan, []; tofloat = false)
120+
prob = ODEProblem(sys, [], tspan, []; tofloat = false, use_union = true)
121121

122122
@test prob.p isa Tuple{Vector{Float64}, Vector{Int64}}
123123
sol = solve(prob, ImplicitEuler());

0 commit comments

Comments
 (0)