Skip to content

Commit 8a6b391

Browse files
Fix first-order transform
1 parent a11d28b commit 8a6b391

File tree

3 files changed

+47
-25
lines changed

3 files changed

+47
-25
lines changed

src/systems/diffeqs/diffeqsystem.jl

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -29,33 +29,53 @@ function to_diffeq(eq::Equation)
2929
throw(ArgumentError("invalid dependent variable $x"))
3030
return t.op, DiffEq(x.op, n, eq.rhs)
3131
end
32-
Base.:(==)(a::DiffEq, b::DiffEq) = isequal((a.x, a.t, a.n, a.rhs), (b.x, b.t, b.n, b.rhs))
32+
Base.:(==)(a::DiffEq, b::DiffEq) = isequal((a.x, a.n, a.rhs), (b.x, b.n, b.rhs))
3333

3434
struct ODESystem <: AbstractSystem
3535
eqs::Vector{DiffEq}
3636
iv::Variable
3737
dvs::Vector{Variable}
3838
ps::Vector{Variable}
3939
jac::RefValue{Matrix{Expression}}
40-
function ODESystem(eqs)
41-
reformatted = to_diffeq.(eqs)
40+
end
41+
42+
function ODESystem(eqs)
43+
reformatted = to_diffeq.(eqs)
4244

43-
ivs = unique(r[1] for r reformatted)
44-
length(ivs) == 1 || throw(ArgumentError("one independent variable currently supported"))
45-
iv = first(ivs)
45+
ivs = unique(r[1] for r reformatted)
46+
length(ivs) == 1 || throw(ArgumentError("one independent variable currently supported"))
47+
iv = first(ivs)
4648

47-
deqs = [r[2] for r reformatted]
49+
deqs = [r[2] for r reformatted]
4850

49-
dvs = [deq.x for deq deqs]
50-
ps = filter(vars(deq.rhs for deq deqs)) do x
51-
x.known & !isequal(x, iv)
52-
end |> collect
51+
dvs = [deq.x for deq deqs]
52+
ps = filter(vars(deq.rhs for deq deqs)) do x
53+
x.known & !isequal(x, iv)
54+
end |> collect
5355

54-
jac = RefValue(Matrix{Expression}(undef, 0, 0))
56+
ODESystem(deqs, iv, dvs, ps)
57+
end
58+
function ODESystem(deqs, iv, dvs, ps)
59+
jac = RefValue(Matrix{Expression}(undef, 0, 0))
60+
ODESystem(deqs, iv, dvs, ps, jac)
61+
end
5562

56-
new(deqs, iv, dvs, ps, jac)
63+
function _eq_unordered(a, b)
64+
length(a) === length(b) || return false
65+
n = length(a)
66+
idxs = Set(1:n)
67+
for x a
68+
idx = findfirst(isequal(x), b)
69+
idx === nothing && return false
70+
idx idxs || return false
71+
delete!(idxs, idx)
5772
end
73+
return true
5874
end
75+
Base.:(==)(sys1::ODESystem, sys2::ODESystem) =
76+
_eq_unordered(sys1.eqs, sys2.eqs) && isequal(sys1.iv, sys2.iv) &&
77+
_eq_unordered(sys1.dvs, sys2.dvs) && _eq_unordered(sys1.ps, sys2.ps)
78+
# NOTE: equality does not check cached Jacobian
5979

6080

6181
function calculate_jacobian(sys::ODESystem)

src/systems/diffeqs/first_order_transform.jl

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,43 +8,47 @@ function lower_varname(var::Variable, idv, order)
88
end
99

1010
function ode_order_lowering(sys::ODESystem)
11-
eqs_lowered = ode_order_lowering(sys.eqs, sys.iv)
12-
ODESystem(eqs_lowered, sys.iv, sys.dvs, sys.ps)
11+
(eqs_lowered, new_vars) = ode_order_lowering(sys.eqs, sys.iv)
12+
ODESystem(eqs_lowered, sys.iv, [sys.dvs; new_vars], sys.ps)
1313
end
1414
function ode_order_lowering(eqs, iv)
1515
var_order = Dict{Variable,Int}()
1616
vars = Variable[]
1717
new_eqs = similar(eqs, DiffEq)
18+
new_vars = Variable[]
1819

1920
for (i, eq) enumerate(eqs)
2021
var, maxorder = eq.x, eq.n
2122
if maxorder > get(var_order, var, 0)
2223
var_order[var] = maxorder
2324
any(isequal(var), vars) || push!(vars, var)
2425
end
25-
var′ = lower_varname(eq.x, eq.t, eq.n - 1)
26+
var′ = lower_varname(eq.x, iv, eq.n - 1)
2627
rhs′ = rename(eq.rhs)
27-
new_eqs[i] = DiffEq(var′, iv, 1, rhs′)
28+
new_eqs[i] = DiffEq(var′, 1, rhs′)
2829
end
2930

3031
for var vars
3132
order = var_order[var]
3233
for o in (order-1):-1:1
3334
lvar = lower_varname(var, iv, o-1)
34-
rhs = lower_varname(var, iv, o)
35-
eq = DiffEq(lvar, iv, 1, rhs)
35+
rvar = lower_varname(var, iv, o)
36+
push!(new_vars, rvar)
37+
38+
rhs = rvar(iv())
39+
eq = DiffEq(lvar, 1, rhs)
3640
push!(new_eqs, eq)
3741
end
3842
end
3943

40-
return new_eqs
44+
return (new_eqs, new_vars)
4145
end
4246

4347
function rename(O::Expression)
4448
isa(O, Operation) || return O
4549
if is_derivative(O)
4650
(x, t, order) = flatten_differential(O)
47-
return lower_varname(x, t, order)
51+
return lower_varname(x.op, t.op, order)(x.args...)
4852
end
4953
return Operation(O.op, rename.(O.args))
5054
end

test/system_construction.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,22 +73,20 @@ ModelingToolkit.generate_ode_iW(de)
7373
end
7474
end
7575

76-
@test_broken begin
7776
# Conversion to first-order ODEs #17
7877
@derivatives D3'''~t
7978
@derivatives D2''~t
8079
@variables u(t) u_tt(t) u_t(t) x_t(t)
8180
eqs = [D3(u) ~ 2(D2(u)) + D(u) + D(x) + 1
8281
D2(x) ~ D(x) + 2]
83-
de = ODESystem(eqs, t)
82+
de = ODESystem(eqs)
8483
de1 = ode_order_lowering(de)
8584
lowered_eqs = [D(u_tt) ~ 2u_tt + u_t + x_t + 1
8685
D(x_t) ~ x_t + 2
8786
D(u_t) ~ u_tt
8887
D(u) ~ u_t
8988
D(x) ~ x_t]
90-
@test de1.eqs == convert.(ModelingToolkit.DiffEq, lowered_eqs)
91-
end
89+
@test de1 == ODESystem(lowered_eqs)
9290

9391
# Internal calculations
9492
a = y - x

0 commit comments

Comments
 (0)