Skip to content

Commit fdb4770

Browse files
Merge pull request #1017 from pepijndevos/syminit
support symbolic parameters at problem level
2 parents 0a0a7cb + 1a7f1ba commit fdb4770

File tree

2 files changed

+41
-2
lines changed

2 files changed

+41
-2
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -397,16 +397,34 @@ function process_DEProblem(constructor, sys::AbstractODESystem,u0map,parammap;
397397
ps = parameters(sys)
398398
defs = defaults(sys)
399399
iv = independent_variable(sys)
400+
if parammap isa Dict
401+
u0defs = merge(parammap, defs)
402+
elseif eltype(parammap) <: Pair
403+
u0defs = merge(Dict(parammap), defs)
404+
elseif eltype(parammap) <: Number
405+
u0defs = merge(Dict(zip(ps, parammap)), defs)
406+
else
407+
u0defs = defs
408+
end
409+
if u0map isa Dict
410+
pdefs = merge(u0map, defs)
411+
elseif eltype(u0map) <: Pair
412+
pdefs = merge(Dict(u0map), defs)
413+
elseif eltype(u0map) <: Number
414+
pdefs = merge(Dict(zip(dvs, u0map)), defs)
415+
else
416+
pdefs = defs
417+
end
400418

401-
u0 = varmap_to_vars(u0map,dvs; defaults=defs)
419+
u0 = varmap_to_vars(u0map,dvs; defaults=u0defs)
402420
if implicit_dae && du0map !== nothing
403421
ddvs = map(Differential(iv), dvs)
404422
du0 = varmap_to_vars(du0map, ddvs; defaults=defaults, toterm=identity)
405423
else
406424
du0 = nothing
407425
ddvs = nothing
408426
end
409-
p = varmap_to_vars(parammap,ps; defaults=defs)
427+
p = varmap_to_vars(parammap,ps; defaults=pdefs)
410428

411429
check_eqs_u0(eqs, dvs, u0)
412430

test/symbolic_parameters.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,24 @@ prob = NonlinearProblem(top, [states(ns, u)=>1.0, a=>1.0], Pair[])
4545
prob = NonlinearProblem(top, [states(ns, u)=>1.0, a=>1.0])
4646
@test prob.u0 == [1.0, 0.5, 1.1, 0.9]
4747
@show sol = solve(prob,NewtonRaphson())
48+
49+
# test initial conditions and parameters at the problem level
50+
pars = @parameters(begin
51+
x0
52+
t
53+
end)
54+
vars = @variables(begin
55+
x(t)
56+
end)
57+
der = Differential(t)
58+
eqs = [der(x) ~ x]
59+
sys = ODESystem(eqs, t, vars, [x0])
60+
pars = [
61+
x0 => 10.0,
62+
]
63+
initialValues = [
64+
x => x0
65+
]
66+
tspan = (0.0, 1.0)
67+
problem = ODEProblem(sys, initialValues, tspan, pars)
68+
@test problem.u0 isa Vector{Float64}

0 commit comments

Comments
 (0)