Skip to content

Commit 6260098

Browse files
Merge pull request #861 from SciML/myb/ex
Handle mass matrix in modelingtoolkitize
2 parents f42e3f8 + f4ecf12 commit 6260098

File tree

4 files changed

+60
-19
lines changed

4 files changed

+60
-19
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ModelingToolkit"
22
uuid = "961ee093-0014-501f-94e3-6117800e7a78"
33
authors = ["Chris Rackauckas <[email protected]>"]
4-
version = "5.12.1"
4+
version = "5.13.0"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/structural_transformation/pantelides.jl

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -63,16 +63,11 @@ function pantelides_reassemble(sys, eqassoc, assign)
6363
lhsarg1 = arguments(eq.lhs)[1]
6464
@assert !(lhsarg1 isa Differential) "The equation $eq is not first order"
6565
i = get(d_dict, lhsarg1, nothing)
66-
if i !== nothing
67-
lhs = D(out_vars[varassoc[i]])
68-
if lhs in lhss
69-
# check only trivial equations are removed
70-
@assert isequal(diff2term(D(eq.rhs)), diff2term(lhs)) "The duplicate equation is not trivial: $eq"
71-
lhs = Num(nothing)
72-
end
73-
lhs
74-
else
66+
if i === nothing
7567
D(eq.lhs)
68+
else
69+
# remove clashing equations
70+
lhs = Num(nothing)
7671
end
7772
else
7873
D(eq.lhs)
@@ -86,8 +81,10 @@ function pantelides_reassemble(sys, eqassoc, assign)
8681
final_vars = unique(filter(x->!(operation(x) isa Differential), fullvars))
8782
final_eqs = map(identity, filter(x->value(x.lhs) !== nothing, out_eqs[sort(filter(x->x != UNASSIGNED, assign))]))
8883

89-
# remove clashing equations (from order lowering vs index reduction)
90-
return ODESystem(final_eqs, independent_variable(sys), final_vars, parameters(sys))
84+
@set! sys.eqs = final_eqs
85+
@set! sys.states = final_vars
86+
@set! sys.structure = nothing
87+
return sys
9188
end
9289

9390
"""

src/systems/diffeqs/modelingtoolkitize.jl

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,27 +18,46 @@ function modelingtoolkitize(prob::DiffEqBase.ODEProblem)
1818
vars = ArrayInterface.restructure(prob.u0,[var(:x, i)(ModelingToolkit.value(t)) for i in eachindex(prob.u0)])
1919
params = p isa DiffEqBase.NullParameters ? [] :
2020
reshape([Num(Sym{Real}(nameof(Variable(, i)))) for i in eachindex(p)],size(p))
21+
var_set = Set(vars)
2122

2223
D = Differential(t)
24+
mm = prob.f.mass_matrix
2325

24-
rhs = [D(var) for var in vars]
26+
if mm === I
27+
lhs = map(v->D(v), vars)
28+
else
29+
lhs = map(mm * vars) do v
30+
if iszero(v)
31+
0
32+
elseif v in var_set
33+
D(v)
34+
else
35+
error("Non-permuation mass matrix is not supported.")
36+
end
37+
end
38+
end
2539

2640
if DiffEqBase.isinplace(prob)
27-
lhs = similar(vars, Num)
28-
prob.f(lhs, vars, params, t)
41+
rhs = similar(vars, Num)
42+
prob.f(rhs, vars, params, t)
2943
else
30-
lhs = prob.f(vars, params, t)
44+
rhs = prob.f(vars, params, t)
3145
end
3246

33-
eqs = vcat([rhs[i] ~ lhs[i] for i in eachindex(prob.u0)]...)
47+
eqs = vcat([lhs[i] ~ rhs[i] for i in eachindex(prob.u0)]...)
3448

49+
sts = vec(collect(vars))
3550
params = if ndims(params) == 0
3651
[params[1]]
3752
else
38-
Vector(vec(params))
53+
vec(collect(params))
3954
end
4055

41-
de = ODESystem(eqs,t,Vector(vec(vars)),params)
56+
de = ODESystem(
57+
eqs, t, sts, params,
58+
default_u0=Dict(sts .=> vec(collect(prob.u0))),
59+
default_p=Dict(params .=> vec(collect(prob.p))),
60+
)
4261

4362
de
4463
end

test/modelingtoolkitize.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,3 +160,28 @@ prob = ODEProblem(f, rv0, (0.0, Δt), μ)
160160
sol = solve(prob, Vern8())
161161

162162
modelingtoolkitize(prob)
163+
164+
# Index reduction and mass matrix handling
165+
using LinearAlgebra
166+
function pendulum!(du, u, p, t)
167+
x, dx, y, dy, T = u
168+
g, L = p
169+
du[1] = dx
170+
du[2] = T*x
171+
du[3] = dy
172+
du[4] = T*y - g
173+
du[5] = x^2 + y^2 - L^2
174+
return nothing
175+
end
176+
pendulum_fun! = ODEFunction(pendulum!, mass_matrix=Diagonal([1,1,1,1,0]))
177+
u0 = [1.0, 0, 0, 0, 0]
178+
p = [9.8, 1]
179+
tspan = (0, 10.0)
180+
pendulum_prob = ODEProblem(pendulum_fun!, u0, tspan, p)
181+
pendulum_sys_org = modelingtoolkitize(pendulum_prob)
182+
sts = states(pendulum_sys_org)
183+
pendulum_sys = dae_index_lowering(pendulum_sys_org)
184+
prob = ODEProblem(pendulum_sys, Pair[], tspan)
185+
sol = solve(prob, Rodas4())
186+
l2 = sol[sts[1]].^2 + sol[sts[3]].^2
187+
@test all(l->abs(sqrt(l) - 1) < 0.05, l2)

0 commit comments

Comments
 (0)