Skip to content

Commit 325539e

Browse files
committed
Add test
1 parent eb6357c commit 325539e

File tree

2 files changed

+24
-3
lines changed

2 files changed

+24
-3
lines changed

src/systems/diffeqs/modelingtoolkitize.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,14 +54,23 @@ function modelingtoolkitize(prob::DiffEqBase.ODEProblem; kwargs...)
5454

5555
sts = vec(collect(vars))
5656

57+
_params = params
5758
params = values(params)
5859
params = if params isa Number || (params isa Array && ndims(params) == 0)
5960
[params[1]]
6061
else
6162
vec(collect(params))
6263
end
6364
default_u0 = Dict(sts .=> vec(collect(prob.u0)))
64-
default_p = has_p ? Dict(params .=> vec(collect(prob.p))) : Dict()
65+
default_p = if has_p
66+
if prob.p isa AbstractDict
67+
Dict(v => prob.p[k] for (k, v) in pairs(_params))
68+
else
69+
Dict(params .=> vec(collect(prob.p)))
70+
end
71+
else
72+
Dict()
73+
end
6574

6675
de = ODESystem(eqs, t, sts, params,
6776
defaults = merge(default_u0, default_p);
@@ -95,7 +104,7 @@ function define_params(p)
95104
end
96105

97106
function define_params(p::AbstractDict)
98-
Dict(k => toparam(variable(, i)) for (i, k) in zip(1:length(p), keys(p)))
107+
OrderedDict(k => toparam(variable(, i)) for (i, k) in zip(1:length(p), keys(p)))
99108
end
100109

101110
function define_params(p::Union{SLArray, LArray})

test/modelingtoolkitize.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using OrdinaryDiffEq, ModelingToolkit, Test
1+
using OrdinaryDiffEq, ModelingToolkit, DataStructures, Test
22
using Optimization, RecursiveArrayTools, OptimizationOptimJL
33

44
N = 32
@@ -277,3 +277,15 @@ params = (1, 1)
277277
prob = ODEProblem(ode_prob, [1 1], (0, 1), params)
278278
sys = modelingtoolkitize(prob)
279279
@test nameof.(parameters(sys)) == [:α₁, :α₂]
280+
281+
function ode_prob_dict(du, u, p, t)
282+
du[1] = u[1] + p[:a]
283+
du[2] = u[2] + p[:b]
284+
nothing
285+
end
286+
params = OrderedDict(:a => 10, :b => 20)
287+
u0 = [1, 2.0]
288+
prob = ODEProblem(ode_prob_dict, u0, (0.0, 1.0), params)
289+
sys = modelingtoolkitize(prob)
290+
@test [ModelingToolkit.defaults(sys)[s] for s in states(sys)] == u0
291+
@test [ModelingToolkit.defaults(sys)[s] for s in parameters(sys)] == [10, 20]

0 commit comments

Comments
 (0)