@@ -17,7 +17,7 @@ function modelingtoolkitize(prob::DiffEqBase.ODEProblem; kwargs...)
17
17
params = if has_p
18
18
_params = define_params (p)
19
19
p isa Number ? _params[1 ] :
20
- (p isa Tuple || p isa NamedTuple ? _params :
20
+ (p isa Tuple || p isa NamedTuple || p isa AbstractDict ? _params :
21
21
ArrayInterfaceCore. restructure (p, _params))
22
22
else
23
23
[]
@@ -44,6 +44,7 @@ function modelingtoolkitize(prob::DiffEqBase.ODEProblem; kwargs...)
44
44
45
45
if DiffEqBase. isinplace (prob)
46
46
rhs = ArrayInterfaceCore. restructure (prob. u0, similar (vars, Num))
47
+ fill! (rhs, 0 )
47
48
prob. f (rhs, vars, params, t)
48
49
else
49
50
rhs = prob. f (vars, params, t)
@@ -53,13 +54,23 @@ function modelingtoolkitize(prob::DiffEqBase.ODEProblem; kwargs...)
53
54
54
55
sts = vec (collect (vars))
55
56
57
+ _params = params
58
+ params = values (params)
56
59
params = if params isa Number || (params isa Array && ndims (params) == 0 )
57
60
[params[1 ]]
58
61
else
59
62
vec (collect (params))
60
63
end
61
64
default_u0 = Dict (sts .=> vec (collect (prob. u0)))
62
- default_p = has_p ? Dict (params .=> vec (collect (prob. p))) : Dict ()
65
+ default_p = if has_p
66
+ if prob. p isa AbstractDict
67
+ Dict (v => prob. p[k] for (k, v) in pairs (_params))
68
+ else
69
+ Dict (params .=> vec (collect (prob. p)))
70
+ end
71
+ else
72
+ Dict ()
73
+ end
63
74
64
75
de = ODESystem (eqs, t, sts, params,
65
76
defaults = merge (default_u0, default_p);
@@ -69,29 +80,33 @@ function modelingtoolkitize(prob::DiffEqBase.ODEProblem; kwargs...)
69
80
de
70
81
end
71
82
72
- _defvaridx (x, i, t ) = variable (x, i, T = SymbolicUtils. FnType{Tuple, Real})
73
- _defvar (x, t ) = variable (x, T = SymbolicUtils. FnType{Tuple, Real})
83
+ _defvaridx (x, i) = variable (x, i, T = SymbolicUtils. FnType{Tuple, Real})
84
+ _defvar (x) = variable (x, T = SymbolicUtils. FnType{Tuple, Real})
74
85
75
86
function define_vars (u, t)
76
- _vars = [_defvaridx (:x , i, t )(t) for i in eachindex (u)]
87
+ [_defvaridx (:x , i)(t) for i in eachindex (u)]
77
88
end
78
89
79
90
function define_vars (u:: Union{SLArray, LArray} , t)
80
- _vars = [_defvar (x, t )(t) for x in LabelledArrays. symnames (typeof (u))]
91
+ [_defvar (x)(t) for x in LabelledArrays. symnames (typeof (u))]
81
92
end
82
93
83
94
function define_vars (u:: Tuple , t)
84
- _vars = tuple ((_defvaridx (:x , i, t )(ModelingToolkit. value (t)) for i in eachindex (u)). .. )
95
+ tuple ((_defvaridx (:x , i)(ModelingToolkit. value (t)) for i in eachindex (u)). .. )
85
96
end
86
97
87
98
function define_vars (u:: NamedTuple , t)
88
- _vars = NamedTuple (x => _defvar (x, t )(ModelingToolkit. value (t)) for x in keys (u))
99
+ NamedTuple (x => _defvar (x)(ModelingToolkit. value (t)) for x in keys (u))
89
100
end
90
101
91
102
function define_params (p)
92
103
[toparam (variable (:α , i)) for i in eachindex (p)]
93
104
end
94
105
106
+ function define_params (p:: AbstractDict )
107
+ OrderedDict (k => toparam (variable (:α , i)) for (i, k) in zip (1 : length (p), keys (p)))
108
+ end
109
+
95
110
function define_params (p:: Union{SLArray, LArray} )
96
111
[toparam (variable (x)) for x in LabelledArrays. symnames (typeof (p))]
97
112
end
0 commit comments