Skip to content

Commit 0165afd

Browse files
Merge pull request #790 from SciML/mtkitize
make modelingtoolkitize more robust to weird arrays
2 parents d31ee38 + 51a8ec8 commit 0165afd

File tree

2 files changed

+37
-5
lines changed

2 files changed

+37
-5
lines changed

src/systems/diffeqs/modelingtoolkitize.jl

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ function modelingtoolkitize(prob::DiffEqBase.ODEProblem)
1515
end
1616

1717
var(x, i) = Num(Sym{FnType{Tuple{symtype(t)}, Real}}(nameof(Variable(x, i))))
18-
vars = reshape([var(:x, i)(value(t)) for i in eachindex(prob.u0)],size(prob.u0))
18+
vars = ArrayInterface.restructure(prob.u0,[var(:x, i)(ModelingToolkit.value(t)) for i in eachindex(prob.u0)])
1919
params = p isa DiffEqBase.NullParameters ? [] :
2020
reshape([Num(Sym{Real}(nameof(Variable(, i)))) for i in eachindex(p)],size(p))
2121

@@ -31,7 +31,14 @@ function modelingtoolkitize(prob::DiffEqBase.ODEProblem)
3131
end
3232

3333
eqs = vcat([rhs[i] ~ lhs[i] for i in eachindex(prob.u0)]...)
34-
de = ODESystem(eqs,t,vec(vars),vec(params))
34+
35+
params = if ndims(params) == 0
36+
[params[1]]
37+
else
38+
Vector(vec(params))
39+
end
40+
41+
de = ODESystem(eqs,t,Vector(vec(vars)),params)
3542

3643
de
3744
end
@@ -53,7 +60,7 @@ function modelingtoolkitize(prob::DiffEqBase.SDEProblem)
5360
p = prob.p
5461
end
5562
var(x, i) = Num(Sym{FnType{Tuple{symtype(t)}, Real}}(nameof(Variable(x, i))))
56-
vars = reshape([var(:x, i)(value(t)) for i in eachindex(prob.u0)],size(prob.u0))
63+
vars = ArrayInterface.restructure(prob.u0,[var(:x, i)(ModelingToolkit.value(t)) for i in eachindex(prob.u0)])
5764
params = p isa DiffEqBase.NullParameters ? [] :
5865
reshape([Num(Sym{Real}(nameof(Variable(, i)))) for i in eachindex(p)],size(p))
5966

@@ -83,7 +90,13 @@ function modelingtoolkitize(prob::DiffEqBase.SDEProblem)
8390
end
8491
deqs = vcat([rhs[i] ~ lhs[i] for i in eachindex(prob.u0)]...)
8592

86-
de = SDESystem(deqs,neqs,t,vec(vars),vec(params))
93+
params = if ndims(params) == 0
94+
[params[1]]
95+
else
96+
Vector(vec(params))
97+
end
98+
99+
de = SDESystem(deqs,neqs,t,Vector(vec(vars)),params)
87100

88101
de
89102
end

test/modelingtoolkitize.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using OrdinaryDiffEq, ModelingToolkit, Test
2-
using GalacticOptim, Optim
2+
using GalacticOptim, Optim, RecursiveArrayTools
33

44
N = 32
55
const xyd_brusselator = range(0,stop=1,length=N)
@@ -141,3 +141,22 @@ problem = ODEProblem(SIRD_ac!, ℬ, 𝒯, 𝒫)
141141
sys = modelingtoolkitize(problem)
142142
fast_problem = ODEProblem(sys,ℬ, 𝒯, 𝒫 )
143143
@time solution = solve(fast_problem, Tsit5(), saveat = 1:final_time)
144+
145+
## Issue #778
146+
147+
r0 = [1131.340, -2282.343, 6672.423]
148+
v0 = [-5.64305, 4.30333, 2.42879]
149+
Δt = 86400.0*365
150+
μ = 398600.4418
151+
rv0 = ArrayPartition(r0,v0)
152+
153+
function f(dy, y, μ, t)
154+
r = sqrt(sum(y[1,:].^2))
155+
dy[1,:] = y[2,:]
156+
dy[2,:] = -μ .* y[1,:] / r^3
157+
end
158+
159+
prob = ODEProblem(f, rv0, (0.0, Δt), μ)
160+
sol = solve(prob, Vern8())
161+
162+
modelingtoolkitize(prob)

0 commit comments

Comments
 (0)