Skip to content

Commit 7e1f9d7

Browse files
committed
Minor fixes
1 parent bc22ed4 commit 7e1f9d7

File tree

2 files changed

+20
-15
lines changed

2 files changed

+20
-15
lines changed

src/systems/diffeqs/modelingtoolkitize.jl

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,17 @@ function modelingtoolkitize(prob::DiffEqBase.ODEProblem)
2323
D = Differential(t)
2424
mm = prob.f.mass_matrix
2525

26-
lhs = map(mm * vars) do v
27-
if iszero(v)
28-
0
29-
elseif v in var_set
30-
D(v)
31-
else
32-
error("Non-permuation mass matrix is not supported.")
26+
if mm === I
27+
lhs = map(v->D(v), vars)
28+
else
29+
lhs = map(mm * vars) do v
30+
if iszero(v)
31+
0
32+
elseif v in var_set
33+
D(v)
34+
else
35+
error("Non-permuation mass matrix is not supported.")
36+
end
3337
end
3438
end
3539

@@ -42,17 +46,17 @@ function modelingtoolkitize(prob::DiffEqBase.ODEProblem)
4246

4347
eqs = vcat([lhs[i] ~ rhs[i] for i in eachindex(prob.u0)]...)
4448

45-
sts = Vector(vec(vars))
49+
sts = vec(collect(vars))
4650
params = if ndims(params) == 0
4751
[params[1]]
4852
else
49-
Vector(vec(params))
53+
vec(collect(params))
5054
end
5155

5256
de = ODESystem(
5357
eqs, t, sts, params,
54-
default_u0=Dict(sts .=> prob.u0),
55-
default_p=Dict(params .=> prob.p)
58+
default_u0=Dict(sts .=> vec(collect(prob.u0))),
59+
default_p=Dict(params .=> vec(collect(prob.p))),
5660
)
5761

5862
de

test/modelingtoolkitize.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ sol = solve(prob, Vern8())
162162
modelingtoolkitize(prob)
163163

164164
# Index reduction and mass matrix handling
165+
using LinearAlgebra
165166
function pendulum!(du, u, p, t)
166167
x, dx, y, dy, T = u
167168
g, L = p
@@ -177,10 +178,10 @@ u0 = [1.0, 0, 0, 0, 0]
177178
p = [9.8, 1]
178179
tspan = (0, 10.0)
179180
pendulum_prob = ODEProblem(pendulum_fun!, u0, tspan, p)
180-
pendulum_sys = dae_index_lowering(modelingtoolkitize(pendulum_prob))
181+
pendulum_sys_org = modelingtoolkitize(pendulum_prob)
182+
sts = states(pendulum_sys_org)
183+
pendulum_sys = dae_index_lowering(pendulum_sys_org)
181184
prob = ODEProblem(pendulum_sys, Pair[], tspan)
182185
sol = solve(prob, Rodas4())
183-
@parameters t
184-
@variables x[1:5](t)
185-
l2 = sol[x[1]].^2 + sol[x[3]].^2
186+
l2 = sol[sts[1]].^2 + sol[sts[3]].^2
186187
@test all(l->abs(sqrt(l) - 1) < 0.05, l2)

0 commit comments

Comments
 (0)