Skip to content

Commit 7720991

Browse files
committed
Optimize and simplify ode_order_lowering implementation
1 parent a9d4b6d commit 7720991

File tree

2 files changed

+7
-8
lines changed

2 files changed

+7
-8
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "3.6.1"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
8+
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
89
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
910
DiffEqJump = "c894b116-72e5-5b58-be3c-e6d8d4ac2b12"
1011
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
@@ -28,6 +29,7 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
2829

2930
[compat]
3031
ArrayInterface = "2.8"
32+
DataStructures = "0.17"
3133
DiffEqBase = "6.28"
3234
DiffEqJump = "6.7.5"
3335
DiffRules = "0.1, 1.0"

src/systems/diffeqs/first_order_transform.jl

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using DataStructures: OrderedDict
12
function lower_varname(var::Variable, idv, order)
23
order == 0 && return var
34
name = Symbol(var.name, , string(idv.name)^order)
@@ -24,8 +25,7 @@ function ode_order_lowering(sys::ODESystem)
2425
end
2526

2627
function ode_order_lowering(eqs, iv, states)
27-
var_order = Dict{Variable,Int}()
28-
vars = Variable[]
28+
var_order = OrderedDict{Variable,Int}()
2929
D = Differential(iv())
3030
diff_eqs = Equation[]
3131
diff_vars = Variable[]
@@ -38,19 +38,16 @@ function ode_order_lowering(eqs, iv, states)
3838
push!(alge_eqs, eq)
3939
else
4040
var, maxorder = var_from_nested_derivative(eq.lhs)
41-
if maxorder > get(var_order, var, 0)
42-
var_order[var] = maxorder
43-
any(isequal(var), vars) || push!(vars, var)
44-
end
41+
# only save to the dict when we need to lower the order to save memory
42+
maxorder > get(var_order, var, 1) && (var_order[var] = maxorder)
4543
var′ = lower_varname(var, iv, maxorder - 1)
4644
rhs′ = rename_lower_order(eq.rhs)
4745
push!(diff_vars, var′)
4846
push!(diff_eqs, D(var′(iv())) ~ rhs′)
4947
end
5048
end
5149

52-
for var vars
53-
order = var_order[var]
50+
for (var, order) var_order
5451
for o in (order-1):-1:1
5552
lvar = lower_varname(var, iv, o-1)
5653
rvar = lower_varname(var, iv, o)

0 commit comments

Comments
 (0)