Skip to content

Commit dc60fe8

Browse files
Merge pull request #1487 from ValentinKaisermayer/patch-process_DEProblem
Patch process_DEproblem
2 parents 029adda + c474956 commit dc60fe8

File tree

9 files changed

+114
-67
lines changed

9 files changed

+114
-67
lines changed

src/structural_transformation/codegen.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,8 @@ function ODAEProblem{iip}(
495495
ps = parameters(sys)
496496
defs = defaults(sys)
497497

498+
defs = ModelingToolkit.mergedefaults(defs,parammap,ps)
499+
defs = ModelingToolkit.mergedefaults(defs,u0map,dvs)
498500
u0 = ModelingToolkit.varmap_to_vars(u0map, dvs; defaults=defs)
499501
p = ModelingToolkit.varmap_to_vars(parammap, ps; defaults=defs)
500502

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -557,36 +557,22 @@ function process_DEProblem(constructor, sys::AbstractODESystem,u0map,parammap;
557557
eqs = equations(sys)
558558
dvs = states(sys)
559559
ps = parameters(sys)
560-
defs = defaults(sys)
561560
iv = get_iv(sys)
562-
if parammap isa Dict
563-
u0defs = merge(parammap, defs)
564-
elseif eltype(parammap) <: Pair
565-
u0defs = merge(Dict(parammap), defs)
566-
elseif eltype(parammap) <: Number
567-
u0defs = merge(Dict(zip(ps, parammap)), defs)
568-
else
569-
u0defs = defs
570-
end
571-
if u0map isa Dict
572-
pdefs = merge(u0map, defs)
573-
elseif eltype(u0map) <: Pair
574-
pdefs = merge(Dict(u0map), defs)
575-
elseif eltype(u0map) <: Number
576-
pdefs = merge(Dict(zip(dvs, u0map)), defs)
577-
else
578-
pdefs = defs
579-
end
580-
581-
u0 = varmap_to_vars(u0map,dvs; defaults=u0defs)
561+
562+
defs = defaults(sys)
563+
defs = mergedefaults(defs,parammap,ps)
564+
defs = mergedefaults(defs,u0map,dvs)
565+
566+
u0 = varmap_to_vars(u0map,dvs; defaults=defs)
567+
p = varmap_to_vars(parammap,ps; defaults=defs)
582568
if implicit_dae && du0map !== nothing
583569
ddvs = map(Differential(iv), dvs)
584-
du0 = varmap_to_vars(du0map, ddvs; defaults=defaults, toterm=identity)
570+
defs = mergedefaults(defs,du0map, ddvs)
571+
du0 = varmap_to_vars(du0map,ddvs; defaults=defs, toterm=identity)
585572
else
586573
du0 = nothing
587574
ddvs = nothing
588575
end
589-
p = varmap_to_vars(parammap,ps; defaults=pdefs)
590576

591577
check_eqs_u0(eqs, dvs, u0; kwargs...)
592578

@@ -691,7 +677,7 @@ merge_cb(x, y) = CallbackSet(x, y)
691677

692678
"""
693679
```julia
694-
function DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem,u0map,tspan,
680+
function DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem,du0map,u0map,tspan,
695681
parammap=DiffEqBase.NullParameters();
696682
version = nothing, tgrad=false,
697683
jac = false,

src/systems/discrete_system/discrete_system.jl

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -174,33 +174,17 @@ function DiffEqBase.DiscreteProblem(sys::DiscreteSystem,u0map,tspan,
174174
ps = parameters(sys)
175175
eqs = equations(sys)
176176
eqs = linearize_eqs(sys, eqs)
177-
defs = defaults(sys)
178177
iv = get_iv(sys)
179-
180-
if parammap isa Dict
181-
u0defs = merge(parammap, defs)
182-
elseif eltype(parammap) <: Pair
183-
u0defs = merge(Dict(parammap), defs)
184-
elseif eltype(parammap) <: Number
185-
u0defs = merge(Dict(zip(ps, parammap)), defs)
186-
else
187-
u0defs = defs
188-
end
189-
if u0map isa Dict
190-
pdefs = merge(u0map, defs)
191-
elseif eltype(u0map) <: Pair
192-
pdefs = merge(Dict(u0map), defs)
193-
elseif eltype(u0map) <: Number
194-
pdefs = merge(Dict(zip(dvs, u0map)), defs)
195-
else
196-
pdefs = defs
197-
end
198-
199-
u0 = varmap_to_vars(u0map,dvs; defaults=u0defs)
178+
179+
defs = defaults(sys)
180+
defs = mergedefaults(defs,parammap,ps)
181+
defs = mergedefaults(defs,u0map,dvs)
182+
183+
u0 = varmap_to_vars(u0map,dvs; defaults=defs)
184+
p = varmap_to_vars(parammap,ps; defaults=defs)
200185

201186
rhss = [eq.rhs for eq in eqs]
202187
u = dvs
203-
p = varmap_to_vars(parammap,ps; defaults=pdefs)
204188

205189
f_gen = generate_function(sys; expression=Val{eval_expression}, expression_module=eval_module)
206190
f_oop, _ = (@RuntimeGeneratedFunction(eval_module, ex) for ex in f_gen)

src/systems/jumps/jumpsystem.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,9 +220,17 @@ dprob = DiscreteProblem(js, u₀map, tspan, parammap)
220220
"""
221221
function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple,Nothing},
222222
parammap=DiffEqBase.NullParameters(); checkbounds=false, kwargs...)
223+
224+
dvs = states(sys)
225+
ps = parameters(sys)
226+
223227
defs = defaults(sys)
224-
u0 = varmap_to_vars(u0map, states(sys); defaults=defs)
225-
p = varmap_to_vars(parammap, parameters(sys); defaults=defs)
228+
defs = mergedefaults(defs,parammap,ps)
229+
defs = mergedefaults(defs,u0map,dvs)
230+
231+
u0 = varmap_to_vars(u0map,dvs; defaults=defs)
232+
p = varmap_to_vars(parammap,ps; defaults=defs)
233+
226234
f = DiffEqBase.DISCRETE_INPLACE_DEFAULT
227235

228236
# just taken from abstractodesystem.jl for ODEFunction def

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,11 @@ function process_NonlinearProblem(constructor, sys::NonlinearSystem,u0map,paramm
264264
eqs = equations(sys)
265265
dvs = states(sys)
266266
ps = parameters(sys)
267+
267268
defs = defaults(sys)
269+
defs = mergedefaults(defs,parammap,ps)
270+
defs = mergedefaults(defs,u0map,dvs)
271+
268272
u0 = varmap_to_vars(u0map,dvs; defaults=defs)
269273
p = varmap_to_vars(parammap,ps; defaults=defs)
270274

src/systems/optimization/optimizationsystem.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,9 @@ DiffEqBase.OptimizationProblem(sys::OptimizationSystem,args...;kwargs...) =
129129

130130
"""
131131
```julia
132-
function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem,
132+
function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem,u0map,
133133
parammap=DiffEqBase.NullParameters();
134-
u0=nothing, lb=nothing, ub=nothing,
134+
lb=nothing, ub=nothing,
135135
grad = false,
136136
hess = false, sparse = false,
137137
checkbounds = false,
@@ -142,7 +142,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem,
142142
Generates an OptimizationProblem from an OptimizationSystem and allows for automatically
143143
symbolically calculating numerical enhancements.
144144
"""
145-
function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0,
145+
function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
146146
parammap=DiffEqBase.NullParameters();
147147
lb=nothing, ub=nothing,
148148
grad = false,
@@ -177,7 +177,10 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0,
177177
_f = DiffEqBase.OptimizationFunction{iip,AutoModelingToolkit,typeof(f),typeof(_grad),typeof(_hess),Nothing,Nothing,Nothing,Nothing}(f,AutoModelingToolkit(),_grad,_hess,nothing,nothing,nothing,nothing)
178178

179179
defs = defaults(sys)
180-
u0 = varmap_to_vars(u0,dvs; defaults=defs)
180+
defs = mergedefaults(defs,parammap,ps)
181+
defs = mergedefaults(defs,u0map,dvs)
182+
183+
u0 = varmap_to_vars(u0map,dvs; defaults=defs)
181184
p = varmap_to_vars(parammap,ps; defaults=defs)
182185
lb = varmap_to_vars(lb,dvs; check=false)
183186
ub = varmap_to_vars(ub,dvs; check=false)
@@ -233,7 +236,10 @@ function OptimizationProblemExpr{iip}(sys::OptimizationSystem, u0,
233236
end
234237

235238
defs = defaults(sys)
236-
u0 = varmap_to_vars(u0,dvs; defaults=defs)
239+
defs = mergedefaults(defs,parammap,ps)
240+
defs = mergedefaults(defs,u0map,dvs)
241+
242+
u0 = varmap_to_vars(u0map,dvs; defaults=defs)
237243
p = varmap_to_vars(parammap,ps; defaults=defs)
238244
lb = varmap_to_vars(lb,dvs)
239245
ub = varmap_to_vars(ub,dvs)

src/utils.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,3 +459,15 @@ function get_substitutions_and_solved_states(sys; no_postprocess=false)
459459
end
460460
return pre, sol_states
461461
end
462+
463+
function mergedefaults(defaults, varmap, vars)
464+
defs = if varmap isa Dict
465+
merge(defaults, varmap)
466+
elseif eltype(varmap) <: Pair
467+
merge(defaults, Dict(varmap))
468+
elseif eltype(varmap) <: Number
469+
merge(defaults, Dict(zip(vars, varmap)))
470+
else
471+
defaults
472+
end
473+
end

src/variables.jl

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -73,23 +73,13 @@ function _varmap_to_vars(varmap::Dict, varlist; defaults=Dict(), check=false, to
7373
varmap = Dict(toterm(value(k))=>value(varmap[k]) for k in keys(varmap))
7474
# resolve symbolic parameter expressions
7575
for (p, v) in pairs(varmap)
76-
val = varmap[p] = fixpoint_sub(v, varmap)
76+
varmap[p] = fixpoint_sub(v, varmap)
7777
end
78-
vs = values(varmap)
79-
T′ = eltype(vs)
80-
if Base.isconcretetype(T′)
81-
T = T′
82-
else
83-
T = foldl((t, elem)->promote_type(t, eltype(elem)), vs; init=typeof(first(vs)))
84-
end
85-
out = Vector{T}(undef, length(varlist))
78+
8679
missingvars = setdiff(varlist, keys(varmap))
8780
check && (isempty(missingvars) || throw_missingvars(missingvars))
8881

89-
for (i, var) in enumerate(varlist)
90-
out[i] = varmap[var]
91-
end
92-
out
82+
out = [varmap[var] for var in varlist]
9383
end
9484

9585
@noinline throw_missingvars(vars) = throw(ArgumentError("$vars are missing from the variable map."))

test/odesystem.jl

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -591,4 +591,59 @@ eqs[end] = D(D(z)) ~ α*x - β*y
591591
sol = solve(prob, Euler(); dt=0.1)
592592

593593
@test c[1] == length(sol)
594+
end
595+
596+
let
597+
@parameters t
598+
D = Differential(t)
599+
@variables x[1:2](t) = zeros(2)
600+
@variables y(t) = 0
601+
@parameters k = 1
602+
eqs= [
603+
D(x[1]) ~ x[2]
604+
D(x[2]) ~ -x[1] - 0.5 * x[2] + k
605+
y ~ 0.9 * x[1] + x[2]
606+
]
607+
@named sys = ODESystem(eqs, t, vcat(x, [y]), [k])
608+
sys = structural_simplify(sys)
609+
610+
u0 = [0.5, 0]
611+
du0 = 0 .* copy(u0)
612+
prob = DAEProblem(sys, du0, u0, (0, 50))
613+
@test prob.u0 u0
614+
@test prob.du0 du0
615+
@test prob.p [1]
616+
sol = solve(prob, IDA())
617+
@test isapprox(sol[x[1]][end], 1, atol=1e-3)
618+
619+
prob = DAEProblem(sys, [D(y) => 0, D(x[1]) => 0, D(x[2]) => 0], Pair[x[1] => 0.5], (0, 50))
620+
@test prob.u0 [0.5, 0]
621+
@test prob.du0 [0, 0]
622+
@test prob.p [1]
623+
sol = solve(prob, IDA())
624+
@test isapprox(sol[x[1]][end], 1, atol=1e-3)
625+
626+
prob = DAEProblem(sys, [D(y) => 0, D(x[1]) => 0, D(x[2]) => 0], Pair[x[1] => 0.5], (0, 50), [k => 2])
627+
@test prob.u0 [0.5, 0]
628+
@test prob.du0 [0, 0]
629+
@test prob.p [2]
630+
sol = solve(prob, IDA())
631+
@test isapprox(sol[x[1]][end], 2, atol=1e-3)
632+
633+
# no initial conditions for D(x[1]) and D(x[2]) provided
634+
@test_throws ArgumentError prob = DAEProblem(sys, Pair[], Pair[], (0, 50))
635+
end
636+
637+
#issue 1475 (mixed numeric type)
638+
let
639+
@parameters k1 k2
640+
@variables t, A(t)
641+
D = Differential(t)
642+
eqs = [D(A) ~ -k1*k2*A]
643+
@named sys = ODESystem(eqs,t)
644+
u0map = [A => 1.0]
645+
pmap = (k1 => 1.0, k2 => 1)
646+
tspan = (0.0,1.0)
647+
prob = ODEProblem(sys, u0map, tspan, pmap)
648+
@test prob.p === Tuple([(Dict(pmap))[k] for k in values(parameters(sys))])
594649
end

0 commit comments

Comments
 (0)