Skip to content

Commit f52d533

Browse files
Merge pull request #840 from pepijndevos/cross_defaults
Allow u0 and p defaults to refer to each other
2 parents 52a031b + 569ad7a commit f52d533

File tree

6 files changed

+27
-23
lines changed

6 files changed

+27
-23
lines changed

src/structural_transformation/codegen.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -338,12 +338,12 @@ function ODAEProblem{iip}(
338338
s = structure(sys)
339339
@unpack fullvars = s
340340
dvs = fullvars[diffvars_range(s)]
341+
defaults = merge(default_p(sys), default_u0(sys))
341342
u0map′ = ModelingToolkit.lower_mapnames(u0map, independent_variable(sys))
342-
u0 = ModelingToolkit.varmap_to_vars(u0map′, dvs; defaults=default_u0(sys))
343+
u0 = ModelingToolkit.varmap_to_vars(u0map′, dvs; defaults=defaults)
343344

344345
ps = parameters(sys)
345-
d_p = default_p(sys)
346-
if parammap isa DiffEqBase.NullParameters && isempty(d_p)
346+
if parammap isa DiffEqBase.NullParameters && isempty(default_p(sys))
347347
isempty(ps) || throw(ArgumentError("The model has non-empty parameters but no parameters are specified in the problem."))
348348
p = parammap
349349
else
@@ -352,7 +352,7 @@ function ODAEProblem{iip}(
352352
else
353353
pp = ModelingToolkit.lower_mapnames(parammap)
354354
end
355-
p = ModelingToolkit.varmap_to_vars(pp, ps; defaults=d_p)
355+
p = ModelingToolkit.varmap_to_vars(pp, ps; defaults=defaults)
356356
end
357357

358358
ODEProblem{iip}(build_torn_function(sys; kw...), u0, tspan, p; kw...)

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -273,20 +273,20 @@ function process_DEProblem(constructor, sys::AbstractODESystem,u0map,parammap;
273273
kwargs...)
274274
dvs = states(sys)
275275
ps = parameters(sys)
276+
defaults = merge(default_p(sys), default_u0(sys))
276277

277278
if u0map !== nothing
278279
u0map′ = lower_mapnames(u0map,get_iv(sys))
279-
u0 = varmap_to_vars(u0map′,dvs; defaults=default_u0(sys))
280+
u0 = varmap_to_vars(u0map′,dvs; defaults=defaults)
280281
else
281282
u0 = nothing
282283
end
283284

284-
defp = default_p(sys)
285285
if !(parammap isa DiffEqBase.NullParameters)
286286
parammap′ = lower_mapnames(parammap)
287-
p = varmap_to_vars(parammap′,ps; defaults=defp)
288-
elseif !isempty(defp)
289-
p = varmap_to_vars(Dict(),ps; defaults=defp)
287+
p = varmap_to_vars(parammap′,ps; defaults=defaults)
288+
elseif !isempty(defaults)
289+
p = varmap_to_vars(Dict(),ps; defaults=defaults)
290290
else
291291
p = ps
292292
end

src/systems/jumps/jumpsystem.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -204,8 +204,9 @@ dprob = DiscreteProblem(js, u₀map, tspan, parammap)
204204
"""
205205
function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple,Nothing},
206206
parammap=DiffEqBase.NullParameters(); kwargs...)
207-
u0 = varmap_to_vars(u0map, states(sys); defaults=default_u0(sys))
208-
p = varmap_to_vars(parammap, parameters(sys); defaults=default_p(sys))
207+
defaults = merge(default_p(sys), default_u0(sys))
208+
u0 = varmap_to_vars(u0map, states(sys); defaults=defaults)
209+
p = varmap_to_vars(parammap, parameters(sys); defaults=defaults)
209210
f = DiffEqBase.DISCRETE_INPLACE_DEFAULT
210211
df = DiscreteFunction{true,true}(f, syms=Symbol.(states(sys)))
211212
DiscreteProblem(df, u0, tspan, p; kwargs...)
@@ -232,8 +233,9 @@ dprob = DiscreteProblem(js, u₀map, tspan, parammap)
232233
"""
233234
function DiscreteProblemExpr(sys::JumpSystem, u0map, tspan::Union{Tuple,Nothing},
234235
parammap=DiffEqBase.NullParameters(); kwargs...)
235-
u0 = varmap_to_vars(u0map, states(sys); defaults=default_u0(sys))
236-
p = varmap_to_vars(parammap, parameters(sys); defaults=default_p(sys))
236+
defaults = merge(default_p(sys), default_u0(sys))
237+
u0 = varmap_to_vars(u0map, states(sys); defaults=defaults)
238+
p = varmap_to_vars(parammap, parameters(sys); defaults=defaults)
237239
# identity function to make syms works
238240
quote
239241
f = DiffEqBase.DISCRETE_INPLACE_DEFAULT

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -211,14 +211,14 @@ function process_NonlinearProblem(constructor, sys::NonlinearSystem,u0map,paramm
211211
dvs = states(sys)
212212
ps = parameters(sys)
213213
u0map′ = lower_mapnames(u0map)
214-
u0 = varmap_to_vars(u0map′,dvs; defaults=default_u0(sys))
215-
defp = default_p(sys)
214+
defaults = merge(default_p(sys), default_u0(sys))
215+
u0 = varmap_to_vars(u0map′,dvs; defaults=defaults)
216216

217217
if !(parammap isa DiffEqBase.NullParameters)
218218
parammap′ = lower_mapnames(parammap)
219-
p = varmap_to_vars(parammap′,ps; defaults=defp)
220-
elseif !isempty(defp)
221-
p = varmap_to_vars(Dict(),ps; defaults=defp)
219+
p = varmap_to_vars(parammap′,ps; defaults=defaults)
220+
elseif !isempty(default_p(sys))
221+
p = varmap_to_vars(Dict(),ps; defaults=defaults)
222222
else
223223
p = ps
224224
end

src/systems/optimization/optimizationsystem.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,9 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0,
157157

158158
_f = DiffEqBase.OptimizationFunction{iip,AutoModelingToolkit,typeof(f),typeof(_grad),typeof(_hess),Nothing,Nothing,Nothing,Nothing}(f,AutoModelingToolkit(),_grad,_hess,nothing,nothing,nothing,nothing)
159159

160-
u0 = varmap_to_vars(u0,dvs; defaults=default_u0(sys))
161-
p = varmap_to_vars(parammap,ps; defaults=default_p(sys))
160+
defaults = merge(default_p(sys), default_u0(sys))
161+
u0 = varmap_to_vars(u0,dvs; defaults=defaults)
162+
p = varmap_to_vars(parammap,ps; defaults=defaults)
162163
lb = varmap_to_vars(lb,dvs)
163164
ub = varmap_to_vars(ub,dvs)
164165
OptimizationProblem{iip}(_f,u0,p;lb=lb,ub=ub,kwargs...)
@@ -212,8 +213,9 @@ function OptimizationProblemExpr{iip}(sys::OptimizationSystem, u0,
212213
_hess = :nothing
213214
end
214215

215-
u0 = varmap_to_vars(u0,dvs; defaults=default_u0(sys))
216-
p = varmap_to_vars(parammap,ps; defaults=default_p(sys))
216+
defaults = merge(default_p(sys), default_u0(sys))
217+
u0 = varmap_to_vars(u0,dvs; defaults=defaults)
218+
p = varmap_to_vars(parammap,ps; defaults=defaults)
217219
lb = varmap_to_vars(lb,dvs)
218220
ub = varmap_to_vars(ub,dvs)
219221
quote

test/symbolic_parameters.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ par = [
1616
]
1717
u0 = Pair{Num, Any}[
1818
x => u,
19-
y => u,
19+
y => σ, # default u0 from default p
2020
z => u-0.1,
2121
]
2222
ns = NonlinearSystem(eqs, [x,y,z],[σ,ρ,β], name=:ns, default_p=par, default_u0=u0)

0 commit comments

Comments
 (0)