Skip to content

Commit ea46142

Browse files
committed
Collect default values of variables into defaults
1 parent 0fdba85 commit ea46142

File tree

2 files changed

+20
-3
lines changed

2 files changed

+20
-3
lines changed

src/systems/diffeqs/odesystem.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,6 @@ function ODESystem(
101101
defaults=_merge(Dict(default_u0), Dict(default_p)),
102102
connection_type=nothing,
103103
)
104-
105104
@assert all(control -> any(isequal.(control, ps)), controls) "All controls must also be parameters."
106105

107106
iv′ = value(scalarize(iv))
@@ -115,6 +114,13 @@ function ODESystem(
115114
defaults = todict(defaults)
116115
defaults = Dict(value(k) => value(v) for (k, v) in pairs(defaults))
117116

117+
iv′ = value(scalarize(iv))
118+
dvs′ = value.(scalarize(dvs))
119+
ps′ = value.(scalarize(ps))
120+
121+
collect_defaults!(defaults, dvs′)
122+
collect_defaults!(defaults, ps′)
123+
118124
tgrad = RefValue(Vector{Num}(undef, 0))
119125
jac = RefValue{Any}(Matrix{Num}(undef, 0, 0))
120126
ctrl_jac = RefValue{Any}(Matrix{Num}(undef, 0, 0))
@@ -363,4 +369,4 @@ function convert_system(::Type{<:ODESystem}, sys, t; name=nameof(sys))
363369
neweqs = map(sub, equations(sys))
364370
defs = Dict(sub(k) => sub(v) for (k, v) in defaults(sys))
365371
return ODESystem(neweqs, t, newsts, parameters(sys); defaults=defs, name=name)
366-
end
372+
end

src/utils.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ function check_variables(dvs, iv)
116116
isequal(iv, dv) && throw(ArgumentError("Independent variable $iv not allowed in dependent variables."))
117117
isequal(iv, iv_from_nested_derivative(dv)) || throw(ArgumentError("Variable $dv is not a function of independent variable $iv."))
118118
end
119-
end
119+
end
120120

121121
"Get all the independent variables with respect to which differentials are taken."
122122
function collect_differentials(eqs)
@@ -155,3 +155,14 @@ end
155155
iv_from_nested_derivative(x::Term) = operation(x) isa Differential ? iv_from_nested_derivative(arguments(x)[1]) : arguments(x)[1]
156156
iv_from_nested_derivative(x::Sym) = x
157157
iv_from_nested_derivative(x) = missing
158+
159+
hasdefault(v) = hasmetadata(v, Symbolics.VariableDefaultValue)
160+
getdefault(v) = value(getmetadata(v, Symbolics.VariableDefaultValue))
161+
setdefault(v, val) = val === nothing ? v : setmetadata(v, Symbolics.VariableDefaultValue, value(val))
162+
163+
function collect_defaults!(defs, vars)
164+
for v in vars; (haskey(defs, v) || !hasdefault(v)) && continue
165+
defs[v] = getdefault(v)
166+
end
167+
return defs
168+
end

0 commit comments

Comments
 (0)