Skip to content

Commit aef01dc

Browse files
feat: track parameter dependency defaults and guesses
1 parent 67c1c5f commit aef01dc

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

src/systems/diffeqs/odesystem.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,8 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
232232
ctrl′ = value.(controls)
233233
dvs′ = value.(dvs)
234234
dvs′ = filter(x -> !isdelay(x, iv), dvs′)
235+
parameter_dependencies, ps′ = process_parameter_dependencies(
236+
parameter_dependencies, ps′)
235237
if !(isempty(default_u0) && isempty(default_p))
236238
Base.depwarn(
237239
"`default_u0` and `default_p` are deprecated. Use `defaults` instead.",
@@ -241,6 +243,8 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
241243
var_to_name = Dict()
242244
process_variables!(var_to_name, defaults, dvs′)
243245
process_variables!(var_to_name, defaults, ps′)
246+
process_variables!(var_to_name, defaults, [eq.lhs for eq in parameter_dependencies])
247+
process_variables!(var_to_name, defaults, [eq.rhs for eq in parameter_dependencies])
244248
defaults = Dict{Any, Any}(value(k) => value(v)
245249
for (k, v) in pairs(defaults) if v !== nothing)
246250

@@ -252,9 +256,13 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
252256
hasaguess = findall(!isnothing, syspsguesses)
253257
ps_guesses = ps′[hasaguess] .=> syspsguesses[hasaguess]
254258
syspsguesses = isempty(ps_guesses) ? Dict() : todict(ps_guesses)
255-
256-
guesses = merge(sysdvsguesses, syspsguesses, todict(guesses))
257-
guesses = Dict{Any, Any}(value(k) => value(v) for (k, v) in pairs(guesses))
259+
syspdepguesses = [ModelingToolkit.getguess(eq.lhs) for eq in parameter_dependencies]
260+
hasaguess = findall(!isnothing, syspdepguesses)
261+
pdep_guesses = [eq.lhs for eq in parameter_dependencies][hasaguess] .=> syspdepguesses[hasaguess]
262+
syspdepguesses = isempty(pdep_guesses) ? Dict() : todict(pdep_guesses)
263+
264+
guesses = merge(sysdvsguesses, syspsguesses, syspdepguesses, todict(guesses))
265+
guesses = Dict{Any, Any}(value(k) => value(v) for (k, v) in pairs(guesses) if v !== nothing)
258266

259267
isempty(observed) || collect_var_to_name!(var_to_name, (eq.lhs for eq in observed))
260268

@@ -269,8 +277,6 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
269277
end
270278
cont_callbacks = SymbolicContinuousCallbacks(continuous_events)
271279
disc_callbacks = SymbolicDiscreteCallbacks(discrete_events)
272-
parameter_dependencies, ps′ = process_parameter_dependencies(
273-
parameter_dependencies, ps′)
274280
ODESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
275281
deqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac,
276282
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, guesses, nothing, initializesystem,

0 commit comments

Comments
 (0)