Skip to content

Commit 91c3b47

Browse files
Reshape data to mirror implicit invariants
1 parent d9bef69 commit 91c3b47

File tree

4 files changed

+31
-47
lines changed

4 files changed

+31
-47
lines changed

src/equations.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,10 @@ function extract_elements(eqs, predicates)
2929
return result
3030
end
3131

32+
get_args(O::Operation) = O.args
33+
get_args(eq::Equation) = Expression[eq.lhs, eq.rhs]
3234
function vars!(vars, op)
33-
args = isa(op, Equation) ? Expression[op.lhs, op.rhs] : op.args
34-
35-
for arg args
35+
for arg get_args(op)
3636
if isa(arg, Operation)
3737
vars!(vars, arg)
3838
elseif isa(arg, Variable)

src/systems/diffeqs/diffeqsystem.jl

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,26 @@
11
using Base: RefValue
22

33

4+
isintermediate(eq::Equation) = !(isa(eq.lhs, Operation) && isa(eq.lhs.op, Differential))
5+
6+
mutable struct DiffEq # D(x) = t
7+
D::Differential # D
8+
var::Variable # x
9+
rhs::Expression # t
10+
end
11+
function Base.convert(::Type{DiffEq}, eq::Equation)
12+
isintermediate(eq) && throw(ArgumentError("intermediate equation received"))
13+
return DiffEq(eq.lhs.op, eq.lhs.args[1], eq.rhs)
14+
end
15+
get_args(eq::DiffEq) = Expression[eq.var, eq.rhs]
16+
417
struct DiffEqSystem <: AbstractSystem
5-
eqs::Vector{Equation}
18+
eqs::Vector{DiffEq}
619
ivs::Vector{Variable}
720
dvs::Vector{Variable}
821
ps::Vector{Variable}
922
jac::RefValue{Matrix{Expression}}
1023
function DiffEqSystem(eqs, ivs, dvs, ps)
11-
all(!isintermediate, eqs) ||
12-
throw(ArgumentError("no intermediate equations permitted in DiffEqSystem"))
13-
1424
jac = RefValue(Matrix{Expression}(undef, 0, 0))
1525
new(eqs, ivs, dvs, ps, jac)
1626
end
@@ -28,8 +38,6 @@ function DiffEqSystem(eqs, ivs)
2838
DiffEqSystem(eqs, ivs, dvs, ps)
2939
end
3040

31-
isintermediate(eq::Equation) = !(isa(eq.lhs, Operation) && isa(eq.lhs.op, Differential))
32-
3341

3442
function generate_ode_function(sys::DiffEqSystem;version = ArrayFunction)
3543
var_exprs = [:($(sys.dvs[i].name) = u[$i]) for i in eachindex(sys.dvs)]
@@ -53,10 +61,8 @@ function generate_ode_function(sys::DiffEqSystem;version = ArrayFunction)
5361
end
5462
end
5563

56-
function build_equals_expr(eq::Equation)
57-
@assert !isintermediate(eq)
58-
59-
lhs = Symbol(eq.lhs.args[1].name, :_, eq.lhs.op.x.name)
64+
function build_equals_expr(eq::DiffEq)
65+
lhs = Symbol(eq.var.name, :_, eq.D.x.name)
6066
return :($lhs = $(convert(Expr, eq.rhs)))
6167
end
6268

src/systems/diffeqs/first_order_transform.jl

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,8 @@
1-
extract_idv(eq::Equation) = eq.lhs.op.x
1+
extract_idv(eq::DiffEq) = eq.D.x
22

3-
function lower_varname(O::Operation, naming_scheme; lower=false)
4-
@assert isa(O.op, Differential)
5-
6-
D, x = O.op, O.args[1]
3+
function lower_varname(D::Differential, x, naming_scheme; lower=false)
74
order = lower ? D.order-1 : D.order
8-
9-
lower_varname(x, D.x, order, naming_scheme)
5+
return lower_varname(x, D.x, order, naming_scheme)
106
end
117
function lower_varname(var::Variable, idv, order::Int, naming_scheme)
128
sym = var.name
@@ -26,7 +22,7 @@ function ode_order_lowering!(eqs, naming_scheme)
2622
D = Differential(idv, 1)
2723
var_order = Dict{Variable,Int}()
2824
vars = Variable[]
29-
dv_name = eqs[1].lhs.args[1].subtype
25+
dv_name = eqs[1].var.subtype
3026

3127
for eq in eqs
3228
var, maxorder = extract_var_order(eq)
@@ -35,7 +31,7 @@ function ode_order_lowering!(eqs, naming_scheme)
3531
var_order[var] = maxorder
3632
var vars || push!(vars, var)
3733
end
38-
lhs_renaming!(eq, D, naming_scheme)
34+
lhs_renaming!(eq, naming_scheme)
3935
rhs_renaming!(eq, naming_scheme)
4036
end
4137

@@ -52,14 +48,15 @@ function ode_order_lowering!(eqs, naming_scheme)
5248
return eqs
5349
end
5450

55-
function lhs_renaming!(eq, D, naming_scheme)
56-
eq.lhs = D(lower_varname(eq.lhs, naming_scheme, lower=true))
51+
function lhs_renaming!(eq::DiffEq, naming_scheme)
52+
eq.var = lower_varname(eq.D, eq.var, naming_scheme, lower=true)
5753
return eq
5854
end
59-
rhs_renaming!(eq, naming_scheme) = _rec_renaming!(eq.rhs, naming_scheme)
55+
rhs_renaming!(eq::DiffEq, naming_scheme) = _rec_renaming!(eq.rhs, naming_scheme)
6056

6157
function _rec_renaming!(rhs, naming_scheme)
62-
isa(rhs, Operation) && isa(rhs.op, Differential) && return lower_varname(rhs, naming_scheme)
58+
isa(rhs, Operation) && isa(rhs.op, Differential) &&
59+
return lower_varname(rhs.op, rhs.args[1], naming_scheme)
6360
if rhs isa Operation
6461
args = rhs.args
6562
for i in eachindex(args)
@@ -69,12 +66,6 @@ function _rec_renaming!(rhs, naming_scheme)
6966
rhs
7067
end
7168

72-
function extract_var_order(eq)
73-
# We assume that the differential with the highest order is always going to be in the LHS
74-
dv = eq.lhs
75-
var = dv.args[1]
76-
order = dv.op.order
77-
return (var, order)
78-
end
69+
extract_var_order(eq::DiffEq) = (eq.var, eq.D.order)
7970

8071
export ode_order_lowering

test/system_construction.jl

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -44,20 +44,7 @@ lowered_eqs = [D(u_tt) ~ 2u_tt + u_t + x_t + 1
4444
D(u_t) ~ u_tt
4545
D(u) ~ u_t
4646
D(x) ~ x_t]
47-
function test_eqs(eqs1, eqs2)
48-
length(eqs1) == length(eqs2) || return false
49-
eq = true
50-
for (eq1, eq2) in zip(eqs1, eqs2)
51-
lhs1, lhs2 = eq1.lhs, eq2.lhs
52-
typeof(lhs1) === typeof(lhs2) || return false
53-
for f in fieldnames(typeof(lhs1))
54-
eq = eq & isequal(getfield(lhs1, f), getfield(lhs2, f))
55-
end
56-
eq = eq & isequal(eq1.rhs, eq2.rhs)
57-
end
58-
eq
59-
end
60-
@test test_eqs(de1.eqs, lowered_eqs)
47+
@test_broken de1.eqs == lowered_eqs
6148

6249
# Internal calculations
6350
a = y - x

0 commit comments

Comments
 (0)