Skip to content

Commit bdf252e

Browse files
Disallow intermediate equations in DiffEqSystem
1 parent 90a5379 commit bdf252e

File tree

2 files changed

+11
-14
lines changed

2 files changed

+11
-14
lines changed

src/systems/diffeqs/diffeqsystem.jl

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,12 @@ mutable struct DiffEqSystem <: AbstractSystem
44
dvs::Vector{Variable}
55
ps::Vector{Variable}
66
jac::Matrix{Expression}
7+
function DiffEqSystem(eqs, ivs, dvs, ps, jac)
8+
all(!isintermediate, eqs) ||
9+
throw(ArgumentError("no intermediate equations permitted in DiffEqSystem"))
10+
11+
new(eqs, ivs, dvs, ps, jac)
12+
end
713
end
814

915
DiffEqSystem(eqs, ivs, dvs, ps) = DiffEqSystem(eqs, ivs, dvs, ps, Matrix{Expression}(undef,0,0))
@@ -20,6 +26,9 @@ function DiffEqSystem(eqs, ivs)
2026
DiffEqSystem(eqs, ivs, dvs, ps, Matrix{Expression}(undef,0,0))
2127
end
2228

29+
isintermediate(eq::Equation) = !(isa(eq.lhs, Operation) && isa(eq.lhs.op, Differential))
30+
31+
2332
function generate_ode_function(sys::DiffEqSystem;version = ArrayFunction)
2433
var_exprs = [:($(sys.dvs[i].name) = u[$i]) for i in eachindex(sys.dvs)]
2534
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in eachindex(sys.ps)]
@@ -42,8 +51,6 @@ function generate_ode_function(sys::DiffEqSystem;version = ArrayFunction)
4251
end
4352
end
4453

45-
isintermediate(eq::Equation) = !(isa(eq.lhs, Operation) && isa(eq.lhs.op, Differential))
46-
4754
function build_equals_expr(eq::Equation)
4855
@assert !isintermediate(eq)
4956

@@ -52,13 +59,7 @@ function build_equals_expr(eq::Equation)
5259
end
5360

5461
function calculate_jacobian(sys::DiffEqSystem, simplify=true)
55-
calcs, diff_exprs = partition(isintermediate, sys.eqs)
56-
rhs = [eq.rhs for eq in diff_exprs]
57-
58-
# Handle intermediate calculations by substitution
59-
for calc calcs
60-
find_replace!.(rhs, calc.lhs, calc.rhs)
61-
end
62+
rhs = [eq.rhs for eq in sys.eqs]
6263

6364
sys_exprs = calculate_jacobian(rhs, sys.dvs)
6465
sys_exprs = Expression[expand_derivatives(expr) for expr in sys_exprs]
@@ -68,7 +69,6 @@ end
6869
function generate_ode_jacobian(sys::DiffEqSystem, simplify=true)
6970
var_exprs = [:($(sys.dvs[i].name) = u[$i]) for i in eachindex(sys.dvs)]
7071
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in eachindex(sys.ps)]
71-
diff_exprs = filter(!isintermediate, sys.eqs)
7272
jac = calculate_jacobian(sys, simplify)
7373
sys.jac = jac
7474
jac_exprs = [:(J[$i,$j] = $(convert(Expr, jac[i,j]))) for i in 1:size(jac,1), j in 1:size(jac,2)]
@@ -80,7 +80,6 @@ end
8080
function generate_ode_iW(sys::DiffEqSystem, simplify=true)
8181
var_exprs = [:($(sys.dvs[i].name) = u[$i]) for i in eachindex(sys.dvs)]
8282
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in eachindex(sys.ps)]
83-
diff_exprs = filter(!isintermediate, sys.eqs)
8483
jac = sys.jac
8584

8685
gam = Parameter(:gam)

src/systems/diffeqs/first_order_transform.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,11 @@ function ode_order_lowering(sys::DiffEqSystem; kwargs...)
1919
end
2020
ode_order_lowering(eqs; naming_scheme = "_") = ode_order_lowering!(deepcopy(eqs), naming_scheme)
2121
function ode_order_lowering!(eqs, naming_scheme)
22-
ind = findfirst(x->!(isintermediate(x)), eqs)
23-
idv = extract_idv(eqs[ind])
22+
idv = extract_idv(eqs[1])
2423
D = Differential(idv, 1)
2524
sym_order = Dict{Symbol, Int}()
2625
dv_name = eqs[1].lhs.subtype
2726
for eq in eqs
28-
isintermediate(eq) && continue
2927
sym, maxorder = extract_symbol_order(eq)
3028
maxorder == 1 && continue # fast pass
3129
if maxorder > get(sym_order, sym, 0)

0 commit comments

Comments
 (0)