Skip to content

Commit 7312a72

Browse files
Fix first-order transformation
1 parent bdf252e commit 7312a72

File tree

2 files changed

+34
-25
lines changed

2 files changed

+34
-25
lines changed
Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,20 @@
11
extract_idv(eq::Equation) = eq.lhs.op.x
22

3-
function lower_varname(var::Variable, naming_scheme; lower=false)
4-
D = var.diff
5-
D === nothing && return var
3+
function lower_varname(O::Operation, naming_scheme; lower=false)
4+
@assert isa(O.op, Differential)
5+
6+
D, x = O.op, O.args[1]
67
order = lower ? D.order-1 : D.order
7-
lower_varname(var.name, D.x, order, var.subtype, naming_scheme)
8+
9+
sym = x.name
10+
name = order == 0 ? sym : Symbol(sym, naming_scheme, string(D.x.name)^order)
11+
12+
Variable(name, x.subtype, x.dependents)
813
end
9-
function lower_varname(sym::Symbol, idv, order::Int, subtype::Symbol, naming_scheme)
14+
function lower_varname(var::Variable, idv, order::Int, naming_scheme)
15+
sym = var.name
1016
name = order == 0 ? sym : Symbol(sym, naming_scheme, string(idv.name)^order)
11-
return Variable(name, subtype=subtype)
17+
return Variable(name, var.subtype, var.dependents)
1218
end
1319

1420
function ode_order_lowering(sys::DiffEqSystem; kwargs...)
@@ -21,27 +27,30 @@ ode_order_lowering(eqs; naming_scheme = "_") = ode_order_lowering!(deepcopy(eqs)
2127
function ode_order_lowering!(eqs, naming_scheme)
2228
idv = extract_idv(eqs[1])
2329
D = Differential(idv, 1)
24-
sym_order = Dict{Symbol, Int}()
25-
dv_name = eqs[1].lhs.subtype
30+
var_order = Dict{Variable,Int}()
31+
dv_name = eqs[1].lhs.args[1].subtype
32+
2633
for eq in eqs
27-
sym, maxorder = extract_symbol_order(eq)
34+
var, maxorder = extract_var_order(eq)
2835
maxorder == 1 && continue # fast pass
29-
if maxorder > get(sym_order, sym, 0)
30-
sym_order[sym] = maxorder
36+
if maxorder > get(var_order, var, 0)
37+
var_order[var] = maxorder
3138
end
32-
eq = lhs_renaming!(eq, D, naming_scheme)
33-
eq = rhs_renaming!(eq, naming_scheme)
39+
lhs_renaming!(eq, D, naming_scheme)
40+
rhs_renaming!(eq, naming_scheme)
3441
end
35-
for sym in keys(sym_order)
36-
order = sym_order[sym]
42+
43+
for var keys(var_order)
44+
order = var_order[var]
3745
for o in (order-1):-1:1
38-
lhs = D(lower_varname(sym, idv, o-1, dv_name, naming_scheme))
39-
rhs = lower_varname(sym, idv, o, dv_name, naming_scheme)
46+
lhs = D(lower_varname(var, idv, o-1, naming_scheme))
47+
rhs = lower_varname(var, idv, o, naming_scheme)
4048
eq = Equation(lhs, rhs)
4149
push!(eqs, eq)
4250
end
4351
end
44-
eqs
52+
53+
return eqs
4554
end
4655

4756
function lhs_renaming!(eq, D, naming_scheme)
@@ -51,7 +60,7 @@ end
5160
rhs_renaming!(eq, naming_scheme) = _rec_renaming!(eq.rhs, naming_scheme)
5261

5362
function _rec_renaming!(rhs, naming_scheme)
54-
rhs isa Variable && rhs.diff != nothing && return lower_varname(rhs, naming_scheme)
63+
isa(rhs, Operation) && isa(rhs.op, Differential) && return lower_varname(rhs, naming_scheme)
5564
if rhs isa Operation
5665
args = rhs.args
5766
for i in eachindex(args)
@@ -61,12 +70,12 @@ function _rec_renaming!(rhs, naming_scheme)
6170
rhs
6271
end
6372

64-
function extract_symbol_order(eq)
73+
function extract_var_order(eq)
6574
# We assume that the differential with the highest order is always going to be in the LHS
6675
dv = eq.lhs
67-
sym = dv.name
68-
order = dv.diff.order
69-
sym, order
76+
var = dv.args[1]
77+
order = dv.op.order
78+
return (var, order)
7079
end
7180

7281
export ode_order_lowering

test/system_construction.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ test_vars_extraction(de, de2)
3838
eqs = [D3(u) ~ 2(D2(u)) + D(u) + D(x) + 1
3939
D2(x) ~ D(x) + 2]
4040
de = DiffEqSystem(eqs, [t])
41-
@test_broken de1 = ode_order_lowering(de)
41+
de1 = ode_order_lowering(de)
4242
lowered_eqs = [D(u_tt) ~ 2u_tt + u_t + x_t + 1
4343
D(x_t) ~ x_t + 2
4444
D(u_t) ~ u_tt
@@ -57,7 +57,7 @@ function test_eqs(eqs1, eqs2)
5757
end
5858
eq
5959
end
60-
@test_broken test_eqs(de1.eqs, lowered_eqs)
60+
@test test_eqs(de1.eqs, lowered_eqs)
6161

6262
# Internal calculations
6363
a = y - x

0 commit comments

Comments
 (0)