Skip to content

Commit 51c8dac

Browse files
Define type for storing equation
1 parent bc7e230 commit 51c8dac

File tree

10 files changed

+131
-123
lines changed

10 files changed

+131
-123
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ function caclulate_jacobian end
2626

2727
include("operations.jl")
2828
include("differentials.jl")
29+
include("equations.jl")
2930
include("systems/diffeqs/diffeqsystem.jl")
3031
include("systems/diffeqs/first_order_transform.jl")
3132
include("systems/nonlinear/nonlinear_system.jl")

src/equations.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
export Equation
2+
3+
4+
mutable struct Equation
5+
lhs::Expression
6+
rhs::Expression
7+
end
8+
Base.broadcastable(eq::Equation) = Ref(eq)
9+
10+
Base.:~(lhs::Expression, rhs::Expression) = Equation(lhs, rhs)
11+
Base.:~(lhs::Expression, rhs::Number ) = Equation(lhs, rhs)
12+
Base.:~(lhs::Number , rhs::Expression) = Equation(lhs, rhs)

src/function_registration.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ for (M, f, arity) in DiffRules.diffrules()
3232
@eval @register $sig
3333
end
3434

35-
for fun = (:<, :>, :(==), :~, :!, :&, :|, :div)
35+
for fun = (:<, :>, :(==), :!, :&, :|, :div)
3636
basefun = Expr(:., Base, QuoteNode(fun))
3737
sig = :($basefun(x,y))
3838
@eval @register $sig

src/systems/diffeqs/diffeqsystem.jl

Lines changed: 74 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
mutable struct DiffEqSystem <: AbstractSystem
2-
eqs::Vector{Operation}
2+
eqs::Vector{Equation}
33
ivs::Vector{Variable}
44
dvs::Vector{Variable}
55
vs::Vector{Variable}
@@ -41,71 +41,64 @@ function generate_ode_function(sys::DiffEqSystem;version = ArrayFunction)
4141
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in 1:length(sys.ps)]
4242
sys_exprs = build_equals_expr.(sys.eqs)
4343
if version == ArrayFunction
44-
dvar_exprs = [:(du[$i] = $(Symbol("$(sys.dvs[i].name)_$(sys.ivs[1].name)"))) for i in 1:length(sys.dvs)]
45-
exprs = vcat(var_exprs,param_exprs,sys_exprs,dvar_exprs)
46-
block = expr_arr_to_block(exprs)
47-
:((du,u,p,t)->$(toexpr(block)))
44+
dvar_exprs = [:(du[$i] = $(Symbol("$(sys.dvs[i].name)_$(sys.ivs[1].name)"))) for i in 1:length(sys.dvs)]
45+
exprs = vcat(var_exprs,param_exprs,sys_exprs,dvar_exprs)
46+
block = expr_arr_to_block(exprs)
47+
:((du,u,p,t)->$(toexpr(block)))
4848
elseif version == SArrayFunction
49-
dvar_exprs = [:($(Symbol("$(sys.dvs[i].name)_$(sys.ivs[1].name)"))) for i in 1:length(sys.dvs)]
50-
svector_expr = quote
51-
E = eltype(tuple($(dvar_exprs...)))
52-
T = StaticArrays.similar_type(typeof(u), E)
53-
T($(dvar_exprs...))
54-
end
55-
exprs = vcat(var_exprs,param_exprs,sys_exprs,svector_expr)
56-
block = expr_arr_to_block(exprs)
57-
:((u,p,t)->$(toexpr(block)))
49+
dvar_exprs = [:($(Symbol("$(sys.dvs[i].name)_$(sys.ivs[1].name)"))) for i in 1:length(sys.dvs)]
50+
svector_expr = quote
51+
E = eltype(tuple($(dvar_exprs...)))
52+
T = StaticArrays.similar_type(typeof(u), E)
53+
T($(dvar_exprs...))
54+
end
55+
exprs = vcat(var_exprs,param_exprs,sys_exprs,svector_expr)
56+
block = expr_arr_to_block(exprs)
57+
:((u,p,t)->$(toexpr(block)))
5858
end
5959
end
6060

61-
isintermediate(eq) = eq.args[1].diff == nothing
61+
isintermediate(eq::Equation) = eq.lhs.diff === nothing
6262

63-
function build_equals_expr(eq)
64-
@assert typeof(eq.args[1]) <: Variable
65-
if !(isintermediate(eq))
66-
# Differential statement
67-
:($(Symbol("$(eq.args[1].name)_$(eq.args[1].diff.x.name)")) = $(eq.args[2]))
68-
else
69-
# Intermediate calculation
70-
:($(Symbol("$(eq.args[1].name)")) = $(eq.args[2]))
71-
end
63+
function build_equals_expr(eq::Equation)
64+
@assert typeof(eq.lhs) <: Variable
65+
66+
lhs = Symbol("$(eq.lhs.name)")
67+
isintermediate(eq) || (lhs = Symbol(lhs, :_, "$(eq.lhs.diff.x.name)"))
68+
69+
return :($lhs = $(convert(Expr, eq.rhs)))
7270
end
7371

74-
function calculate_jacobian(sys::DiffEqSystem,simplify=true)
75-
diff_idxs = map(eq->eq.args[1].diff !=nothing,sys.eqs)
76-
diff_exprs = sys.eqs[diff_idxs]
77-
rhs = [eq.args[2] for eq in diff_exprs]
72+
function calculate_jacobian(sys::DiffEqSystem, simplify=true)
73+
calcs, diff_exprs = partition(isintermediate, sys.eqs)
74+
rhs = [eq.rhs for eq in diff_exprs]
75+
7876
# Handle intermediate calculations by substitution
79-
calcs = sys.eqs[.!(diff_idxs)]
80-
for i in 1:length(calcs)
81-
find_replace!.(rhs,calcs[i].args[1],calcs[i].args[2])
77+
for calc calcs
78+
find_replace!.(rhs, calc.lhs, calc.rhs)
8279
end
83-
sys_exprs = calculate_jacobian(rhs,sys.dvs)
80+
81+
sys_exprs = calculate_jacobian(rhs, sys.dvs)
8482
sys_exprs = Expression[expand_derivatives(expr) for expr in sys_exprs]
85-
if simplify
86-
sys_exprs = Expression[simplify_constants(expr) for expr in sys_exprs]
87-
end
8883
sys_exprs
8984
end
9085

91-
function generate_ode_jacobian(sys::DiffEqSystem,simplify=true)
86+
function generate_ode_jacobian(sys::DiffEqSystem, simplify=true)
9287
var_exprs = [:($(sys.dvs[i].name) = u[$i]) for i in 1:length(sys.dvs)]
9388
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in 1:length(sys.ps)]
94-
diff_idxs = map(eq->eq.args[1].diff !=nothing,sys.eqs)
95-
diff_exprs = sys.eqs[diff_idxs]
96-
jac = calculate_jacobian(sys,simplify)
89+
diff_exprs = filter(!isintermediate, sys.eqs)
90+
jac = calculate_jacobian(sys, simplify)
9791
sys.jac = jac
9892
jac_exprs = [:(J[$i,$j] = $(convert(Expr, jac[i,j]))) for i in 1:size(jac,1), j in 1:size(jac,2)]
9993
exprs = vcat(var_exprs,param_exprs,vec(jac_exprs))
10094
block = expr_arr_to_block(exprs)
10195
:((J,u,p,t)->$(block))
10296
end
10397

104-
function generate_ode_iW(sys::DiffEqSystem,simplify=true)
98+
function generate_ode_iW(sys::DiffEqSystem, simplify=true)
10599
var_exprs = [:($(sys.dvs[i].name) = u[$i]) for i in 1:length(sys.dvs)]
106100
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in 1:length(sys.ps)]
107-
diff_idxs = map(eq->eq.args[1].diff !=nothing,sys.eqs)
108-
diff_exprs = sys.eqs[diff_idxs]
101+
diff_exprs = filter(!isintermediate, sys.eqs)
109102
jac = sys.jac
110103

111104
gam = Variable(:gam)
@@ -144,5 +137,44 @@ function DiffEqBase.ODEFunction(sys::DiffEqSystem;version = ArrayFunction,kwargs
144137
end
145138
end
146139

140+
function extract_elements(eqs, targetmap, default = nothing)
141+
elems = Dict{Symbol, Vector{Variable}}()
142+
names = Dict{Symbol, Set{Symbol}}()
143+
if default == nothing
144+
targets = unique(collect(values(targetmap)))
145+
else
146+
targets = [unique(collect(values(targetmap))), default]
147+
end
148+
for target in targets
149+
elems[target] = Vector{Variable}()
150+
names[target] = Set{Symbol}()
151+
end
152+
for eq in eqs
153+
extract_elements!(eq, elems, names, targetmap, default)
154+
end
155+
Tuple(elems[target] for target in targets)
156+
end
157+
# Walk the tree recursively and push variables into the right set
158+
function extract_elements!(op, elems, names, targetmap, default)
159+
args = isa(op, Equation) ? Expression[op.lhs, op.rhs] : op.args
160+
161+
for arg in args
162+
if arg isa Operation
163+
extract_elements!(arg, elems, names, targetmap, default)
164+
elseif arg isa Variable
165+
if default == nothing
166+
target = haskey(targetmap, arg.subtype) ? targetmap[arg.subtype] : continue
167+
else
168+
target = haskey(targetmap, arg.subtype) ? targetmap[arg.subtype] : default
169+
end
170+
if !in(arg.name, names[target])
171+
push!(names[target], arg.name)
172+
push!(elems[target], arg)
173+
end
174+
end
175+
end
176+
end
177+
178+
147179
export DiffEqSystem, ODEFunction
148180
export generate_ode_function

src/systems/diffeqs/first_order_transform.jl

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
extract_idv(eq::Equation) = eq.lhs.diff.x
2+
13
function lower_varname(var::Variable, naming_scheme; lower=false)
24
D = var.diff
35
D == nothing && return var
@@ -22,7 +24,7 @@ function ode_order_lowering!(eqs, naming_scheme)
2224
idv = extract_idv(eqs[ind])
2325
D = Differential(idv, 1)
2426
sym_order = Dict{Symbol, Int}()
25-
dv_name = eqs[1].args[1].subtype
27+
dv_name = eqs[1].lhs.subtype
2628
for eq in eqs
2729
isintermediate(eq) && continue
2830
sym, maxorder = extract_symbol_order(eq)
@@ -38,21 +40,18 @@ function ode_order_lowering!(eqs, naming_scheme)
3840
for o in (order-1):-1:1
3941
lhs = D(lower_varname(sym, idv, o-1, dv_name, naming_scheme))
4042
rhs = lower_varname(sym, idv, o, dv_name, naming_scheme)
41-
eq = Operation(==, [lhs, rhs])
43+
eq = Equation(lhs, rhs)
4244
push!(eqs, eq)
4345
end
4446
end
4547
eqs
4648
end
4749

4850
function lhs_renaming!(eq, D, naming_scheme)
49-
eq.args[1] = D(lower_varname(eq.args[1], naming_scheme, lower=true))
51+
eq.lhs = D(lower_varname(eq.lhs, naming_scheme, lower=true))
5052
return eq
5153
end
52-
function rhs_renaming!(eq, naming_scheme)
53-
rhs = eq.args[2]
54-
_rec_renaming!(rhs, naming_scheme)
55-
end
54+
rhs_renaming!(eq, naming_scheme) = _rec_renaming!(eq.rhs, naming_scheme)
5655

5756
function _rec_renaming!(rhs, naming_scheme)
5857
rhs isa Variable && rhs.diff != nothing && return lower_varname(rhs, naming_scheme)
@@ -67,7 +66,7 @@ end
6766

6867
function extract_symbol_order(eq)
6968
# We assume that the differential with the highest order is always going to be in the LHS
70-
dv = eq.args[1]
69+
dv = eq.lhs
7170
sym = dv.name
7271
order = dv.diff.order
7372
sym, order

src/systems/nonlinear/nonlinear_system.jl

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
struct NonlinearSystem <: AbstractSystem
2-
eqs::Vector{Operation}
2+
eqs::Vector{Equation}
33
vs::Vector{Variable}
44
ps::Vector{Variable}
55
v_name::Vector{Symbol}
@@ -26,32 +26,25 @@ end
2626
function generate_nlsys_function(sys::NonlinearSystem)
2727
var_exprs = [:($(sys.vs[i].name) = u[$i]) for i in 1:length(sys.vs)]
2828
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in 1:length(sys.ps)]
29-
sys_idxs = map(eq->isequal(eq.args[1],Constant(0)),sys.eqs)
30-
sys_eqs = sys.eqs[sys_idxs]
31-
calc_eqs = sys.eqs[.!(sys_idxs)]
32-
calc_exprs = [:($(Symbol("$(eq.args[1].name)")) = $(eq.args[2])) for eq in calc_eqs]
33-
sys_exprs = [:($(Symbol("resid[$i]")) = $(sys_eqs[i].args[2])) for i in eachindex(sys_eqs)]
29+
sys_eqs, calc_eqs = partition(eq -> isequal(eq.lhs, Constant(0)), sys.eqs)
30+
calc_exprs = [:($(Symbol("$(eq.lhs.name)")) = $(eq.rhs)) for eq in calc_eqs]
31+
sys_exprs = [:($(Symbol("resid[$i]")) = $(sys_eqs[i].rhs)) for i in eachindex(sys_eqs)]
3432

3533
exprs = vcat(var_exprs,param_exprs,calc_exprs,sys_exprs)
3634
block = expr_arr_to_block(exprs)
3735
:((du,u,p)->$(block))
3836
end
3937

4038
function calculate_jacobian(sys::NonlinearSystem,simplify=true)
41-
sys_idxs = map(eq->isequal(eq.args[1],Constant(0)),sys.eqs)
42-
sys_eqs = sys.eqs[sys_idxs]
43-
calc_eqs = sys.eqs[.!(sys_idxs)]
44-
rhs = [eq.args[2] for eq in sys_eqs]
39+
sys_eqs, calc_eqs = partition(eq -> isequal(eq.lhs, Constant(0)), sys.eqs)
40+
rhs = [eq.rhs for eq in sys_eqs]
4541

46-
for i in 1:length(calc_eqs)
47-
find_replace!.(rhs,calc_eqs[i].args[1],calc_eqs[i].args[2])
42+
for calc_eq calc_eqs
43+
find_replace!.(rhs, calc_eq.lhs, calc_eq.rhs)
4844
end
4945

5046
sys_exprs = calculate_jacobian(rhs,sys.vs)
5147
sys_exprs = Expression[expand_derivatives(expr) for expr in sys_exprs]
52-
if simplify
53-
sys_exprs = Expression[simplify_constants(expr) for expr in sys_exprs]
54-
end
5548
sys_exprs
5649
end
5750

src/utils.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,12 @@ end
3333

3434
toexpr(ex) = MacroTools.postwalk(x -> isa(x, Expression) ? convert(Expr, x) : x, ex)
3535

36+
function partition(f, xs)
37+
idxs = map(f, xs)
38+
not_idxs = eachindex(xs) .∉ (idxs,)
39+
return (xs[idxs], xs[not_idxs])
40+
end
41+
3642
is_constant(x::Variable) = x.subtype === :Constant
3743
is_constant(::Any) = false
3844

src/variables.jl

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -113,44 +113,6 @@ function Base.show(io::IO, A::Variable)
113113
end
114114
end
115115

116-
extract_idv(eq) = eq.args[1].diff.x
117-
118-
function extract_elements(ops, targetmap, default = nothing)
119-
elems = Dict{Symbol, Vector{Variable}}()
120-
names = Dict{Symbol, Set{Symbol}}()
121-
if default == nothing
122-
targets = unique(collect(values(targetmap)))
123-
else
124-
targets = [unique(collect(values(targetmap))), default]
125-
end
126-
for target in targets
127-
elems[target] = Vector{Variable}()
128-
names[target] = Set{Symbol}()
129-
end
130-
for op in ops
131-
extract_elements!(op, elems, names, targetmap, default)
132-
end
133-
Tuple(elems[target] for target in targets)
134-
end
135-
# Walk the tree recursively and push variables into the right set
136-
function extract_elements!(op::AbstractOperation, elems, names, targetmap, default)
137-
for arg in op.args
138-
if arg isa Operation
139-
extract_elements!(arg, elems, names, targetmap, default)
140-
elseif arg isa Variable
141-
if default == nothing
142-
target = haskey(targetmap, arg.subtype) ? targetmap[arg.subtype] : continue
143-
else
144-
target = haskey(targetmap, arg.subtype) ? targetmap[arg.subtype] : default
145-
end
146-
if !in(arg.name, names[target])
147-
push!(names[target], arg.name)
148-
push!(elems[target], arg)
149-
end
150-
end
151-
end
152-
end
153-
154116
# Build variables more easily
155117
function _parse_vars(macroname, fun, x)
156118
ex = Expr(:block)

test/components.jl

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,27 @@ using ModelingToolkit
22
using Test
33

44
struct Lorenz <: AbstractComponent
5-
x::Variable
6-
y::Variable
7-
z::Variable
8-
σ::Variable
9-
ρ::Variable
10-
β::Variable
11-
eqs::Vector{Expression}
5+
x::Variable
6+
y::Variable
7+
z::Variable
8+
σ::Variable
9+
ρ::Variable
10+
β::Variable
11+
eqs::Vector{Equation}
1212
end
1313
function generate_lorenz_eqs(t,x,y,z,σ,ρ,β)
14-
D = Differential(t)
15-
[D(x) ~ σ*(y-x)
16-
D(y) ~ x*-z)-y
17-
D(z) ~ x*y - β*z]
14+
D = Differential(t)
15+
[D(x) ~ σ*(y-x)
16+
D(y) ~ x*-z)-y
17+
D(z) ~ x*y - β*z]
18+
end
19+
function Lorenz(t)
20+
@DVar x(t) y(t) z(t)
21+
@Param σ ρ β
22+
Lorenz(x, y, z, σ, ρ, β, generate_lorenz_eqs(t, x, y, z, σ, ρ, β))
1823
end
19-
Lorenz(t) = Lorenz(first(@DVar(x(t))),first(@DVar(y(t))),first(@DVar(z(t))),first(@Param(σ)),first(@Param(ρ)),first(@Param(β)),generate_lorenz_eqs(t,x,y,z,σ,ρ,β))
2024

2125
@IVar t
2226
lz1 = Lorenz(t)
2327
lz2 = Lorenz(t)
24-
Expression[lz1.x ~ lz2.x
25-
lz1
26-
lz2]
28+
lz1.x ~ lz2.x

0 commit comments

Comments
 (0)