Skip to content

Commit 69390a9

Browse files
Merge pull request #80 from JuliaDiffEq/hg/fix/equation
Update IR for equations
2 parents d1a6adc + 2abd6b7 commit 69390a9

12 files changed

+143
-128
lines changed

README.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,13 @@ context-aware single variable of the IR. Its fields are described as follows:
167167
### Operations
168168

169169
Operations are the basic composition of variables and puts together the pieces
170-
with a function. The `~` function denotes equality between the arguments.
170+
with a function.
171+
172+
### Equations
173+
174+
Equations are stored using the `Equation` datatype. Given expressions for the
175+
left-hand and right-hand sides, an equation is constructed as `Equation(lhs, rhs)`,
176+
or equivalently `lhs ~ rhs`.
171177

172178
### Differentials
173179

src/ModelingToolkit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ function caclulate_jacobian end
2424

2525
include("operations.jl")
2626
include("differentials.jl")
27+
include("equations.jl")
2728
include("systems/diffeqs/diffeqsystem.jl")
2829
include("systems/diffeqs/first_order_transform.jl")
2930
include("systems/nonlinear/nonlinear_system.jl")

src/equations.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
export Equation
2+
3+
4+
mutable struct Equation
5+
lhs::Expression
6+
rhs::Expression
7+
end
8+
Base.broadcastable(eq::Equation) = Ref(eq)
9+
10+
Base.:~(lhs::Expression, rhs::Expression) = Equation(lhs, rhs)
11+
Base.:~(lhs::Expression, rhs::Number ) = Equation(lhs, rhs)
12+
Base.:~(lhs::Number , rhs::Expression) = Equation(lhs, rhs)
13+
14+
15+
function extract_elements(eqs, targetmap, default = nothing)
16+
elems = Dict{Symbol, Vector{Variable}}()
17+
names = Dict{Symbol, Set{Symbol}}()
18+
if default == nothing
19+
targets = unique(collect(values(targetmap)))
20+
else
21+
targets = [unique(collect(values(targetmap))), default]
22+
end
23+
for target in targets
24+
elems[target] = Vector{Variable}()
25+
names[target] = Set{Symbol}()
26+
end
27+
for eq in eqs
28+
extract_elements!(eq, elems, names, targetmap, default)
29+
end
30+
Tuple(elems[target] for target in targets)
31+
end
32+
# Walk the tree recursively and push variables into the right set
33+
function extract_elements!(op, elems, names, targetmap, default)
34+
args = isa(op, Equation) ? Expression[op.lhs, op.rhs] : op.args
35+
36+
for arg in args
37+
if arg isa Operation
38+
extract_elements!(arg, elems, names, targetmap, default)
39+
elseif arg isa Variable
40+
if default == nothing
41+
target = haskey(targetmap, arg.subtype) ? targetmap[arg.subtype] : continue
42+
else
43+
target = haskey(targetmap, arg.subtype) ? targetmap[arg.subtype] : default
44+
end
45+
if !in(arg.name, names[target])
46+
push!(names[target], arg.name)
47+
push!(elems[target], arg)
48+
end
49+
end
50+
end
51+
end

src/function_registration.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ for (M, f, arity) in DiffRules.diffrules()
3232
@eval @register $sig
3333
end
3434

35-
for fun = (:<, :>, :(==), :~, :!, :&, :|, :div)
35+
for fun = (:<, :>, :(==), :!, :&, :|, :div)
3636
basefun = Expr(:., Base, QuoteNode(fun))
3737
sig = :($basefun(x,y))
3838
@eval @register $sig

src/operations.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,13 @@ Base.show(io::IO, O::Operation) = print(io, convert(Expr, O))
2121

2222

2323
"""
24-
find_replace(O::Operation,x::Variable,y::Expression)
24+
find_replace(O::Operation, x::Expression, y::Expression)
2525
26-
Finds the variable `x` in Operation `O` and replaces it with the Expression `y`
26+
Finds the expression `x` in Operation `O` and replaces it with the Expression `y`
2727
"""
28-
function find_replace!(O::Operation,x::Variable,y::Expression)
28+
function find_replace!(O::Operation, x::Expression, y::Expression)
2929
for i in eachindex(O.args)
30-
if isequal(O.args[i],x)
30+
if isequal(O.args[i], x)
3131
O.args[i] = y
3232
elseif typeof(O.args[i]) <: Operation
3333
find_replace!(O.args[i],x,y)

src/systems/diffeqs/diffeqsystem.jl

Lines changed: 36 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
mutable struct DiffEqSystem <: AbstractSystem
2-
eqs::Vector{Operation}
2+
eqs::Vector{Equation}
33
ivs::Vector{Variable}
44
dvs::Vector{Variable}
55
vs::Vector{Variable}
@@ -41,71 +41,64 @@ function generate_ode_function(sys::DiffEqSystem;version = ArrayFunction)
4141
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in 1:length(sys.ps)]
4242
sys_exprs = build_equals_expr.(sys.eqs)
4343
if version == ArrayFunction
44-
dvar_exprs = [:(du[$i] = $(Symbol("$(sys.dvs[i].name)_$(sys.ivs[1].name)"))) for i in 1:length(sys.dvs)]
45-
exprs = vcat(var_exprs,param_exprs,sys_exprs,dvar_exprs)
46-
block = expr_arr_to_block(exprs)
47-
:((du,u,p,t)->$(toexpr(block)))
44+
dvar_exprs = [:(du[$i] = $(Symbol("$(sys.dvs[i].name)_$(sys.ivs[1].name)"))) for i in 1:length(sys.dvs)]
45+
exprs = vcat(var_exprs,param_exprs,sys_exprs,dvar_exprs)
46+
block = expr_arr_to_block(exprs)
47+
:((du,u,p,t)->$(toexpr(block)))
4848
elseif version == SArrayFunction
49-
dvar_exprs = [:($(Symbol("$(sys.dvs[i].name)_$(sys.ivs[1].name)"))) for i in 1:length(sys.dvs)]
50-
svector_expr = quote
51-
E = eltype(tuple($(dvar_exprs...)))
52-
T = StaticArrays.similar_type(typeof(u), E)
53-
T($(dvar_exprs...))
54-
end
55-
exprs = vcat(var_exprs,param_exprs,sys_exprs,svector_expr)
56-
block = expr_arr_to_block(exprs)
57-
:((u,p,t)->$(toexpr(block)))
49+
dvar_exprs = [:($(Symbol("$(sys.dvs[i].name)_$(sys.ivs[1].name)"))) for i in 1:length(sys.dvs)]
50+
svector_expr = quote
51+
E = eltype(tuple($(dvar_exprs...)))
52+
T = StaticArrays.similar_type(typeof(u), E)
53+
T($(dvar_exprs...))
54+
end
55+
exprs = vcat(var_exprs,param_exprs,sys_exprs,svector_expr)
56+
block = expr_arr_to_block(exprs)
57+
:((u,p,t)->$(toexpr(block)))
5858
end
5959
end
6060

61-
isintermediate(eq) = eq.args[1].diff == nothing
61+
isintermediate(eq::Equation) = eq.lhs.diff === nothing
6262

63-
function build_equals_expr(eq)
64-
@assert typeof(eq.args[1]) <: Variable
65-
if !(isintermediate(eq))
66-
# Differential statement
67-
:($(Symbol("$(eq.args[1].name)_$(eq.args[1].diff.x.name)")) = $(eq.args[2]))
68-
else
69-
# Intermediate calculation
70-
:($(Symbol("$(eq.args[1].name)")) = $(eq.args[2]))
71-
end
63+
function build_equals_expr(eq::Equation)
64+
@assert typeof(eq.lhs) <: Variable
65+
66+
lhs = eq.lhs.name
67+
isintermediate(eq) || (lhs = Symbol(lhs, :_, "$(eq.lhs.diff.x.name)"))
68+
69+
return :($lhs = $(convert(Expr, eq.rhs)))
7270
end
7371

74-
function calculate_jacobian(sys::DiffEqSystem,simplify=true)
75-
diff_idxs = map(eq->eq.args[1].diff !=nothing,sys.eqs)
76-
diff_exprs = sys.eqs[diff_idxs]
77-
rhs = [eq.args[2] for eq in diff_exprs]
72+
function calculate_jacobian(sys::DiffEqSystem, simplify=true)
73+
calcs, diff_exprs = partition(isintermediate, sys.eqs)
74+
rhs = [eq.rhs for eq in diff_exprs]
75+
7876
# Handle intermediate calculations by substitution
79-
calcs = sys.eqs[.!(diff_idxs)]
80-
for i in 1:length(calcs)
81-
find_replace!.(rhs,calcs[i].args[1],calcs[i].args[2])
77+
for calc calcs
78+
find_replace!.(rhs, calc.lhs, calc.rhs)
8279
end
83-
sys_exprs = calculate_jacobian(rhs,sys.dvs)
80+
81+
sys_exprs = calculate_jacobian(rhs, sys.dvs)
8482
sys_exprs = Expression[expand_derivatives(expr) for expr in sys_exprs]
85-
if simplify
86-
sys_exprs = Expression[simplify_constants(expr) for expr in sys_exprs]
87-
end
8883
sys_exprs
8984
end
9085

91-
function generate_ode_jacobian(sys::DiffEqSystem,simplify=true)
86+
function generate_ode_jacobian(sys::DiffEqSystem, simplify=true)
9287
var_exprs = [:($(sys.dvs[i].name) = u[$i]) for i in 1:length(sys.dvs)]
9388
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in 1:length(sys.ps)]
94-
diff_idxs = map(eq->eq.args[1].diff !=nothing,sys.eqs)
95-
diff_exprs = sys.eqs[diff_idxs]
96-
jac = calculate_jacobian(sys,simplify)
89+
diff_exprs = filter(!isintermediate, sys.eqs)
90+
jac = calculate_jacobian(sys, simplify)
9791
sys.jac = jac
9892
jac_exprs = [:(J[$i,$j] = $(convert(Expr, jac[i,j]))) for i in 1:size(jac,1), j in 1:size(jac,2)]
9993
exprs = vcat(var_exprs,param_exprs,vec(jac_exprs))
10094
block = expr_arr_to_block(exprs)
10195
:((J,u,p,t)->$(block))
10296
end
10397

104-
function generate_ode_iW(sys::DiffEqSystem,simplify=true)
98+
function generate_ode_iW(sys::DiffEqSystem, simplify=true)
10599
var_exprs = [:($(sys.dvs[i].name) = u[$i]) for i in 1:length(sys.dvs)]
106100
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in 1:length(sys.ps)]
107-
diff_idxs = map(eq->eq.args[1].diff !=nothing,sys.eqs)
108-
diff_exprs = sys.eqs[diff_idxs]
101+
diff_exprs = filter(!isintermediate, sys.eqs)
109102
jac = sys.jac
110103

111104
gam = Variable(:gam)
@@ -144,5 +137,6 @@ function DiffEqBase.ODEFunction(sys::DiffEqSystem;version = ArrayFunction,kwargs
144137
end
145138
end
146139

140+
147141
export DiffEqSystem, ODEFunction
148142
export generate_ode_function

src/systems/diffeqs/first_order_transform.jl

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
extract_idv(eq::Equation) = eq.lhs.diff.x
2+
13
function lower_varname(var::Variable, naming_scheme; lower=false)
24
D = var.diff
35
D === nothing && return var
@@ -21,7 +23,7 @@ function ode_order_lowering!(eqs, naming_scheme)
2123
idv = extract_idv(eqs[ind])
2224
D = Differential(idv, 1)
2325
sym_order = Dict{Symbol, Int}()
24-
dv_name = eqs[1].args[1].subtype
26+
dv_name = eqs[1].lhs.subtype
2527
for eq in eqs
2628
isintermediate(eq) && continue
2729
sym, maxorder = extract_symbol_order(eq)
@@ -37,21 +39,18 @@ function ode_order_lowering!(eqs, naming_scheme)
3739
for o in (order-1):-1:1
3840
lhs = D(lower_varname(sym, idv, o-1, dv_name, naming_scheme))
3941
rhs = lower_varname(sym, idv, o, dv_name, naming_scheme)
40-
eq = Operation(==, [lhs, rhs])
42+
eq = Equation(lhs, rhs)
4143
push!(eqs, eq)
4244
end
4345
end
4446
eqs
4547
end
4648

4749
function lhs_renaming!(eq, D, naming_scheme)
48-
eq.args[1] = D(lower_varname(eq.args[1], naming_scheme, lower=true))
50+
eq.lhs = D(lower_varname(eq.lhs, naming_scheme, lower=true))
4951
return eq
5052
end
51-
function rhs_renaming!(eq, naming_scheme)
52-
rhs = eq.args[2]
53-
_rec_renaming!(rhs, naming_scheme)
54-
end
53+
rhs_renaming!(eq, naming_scheme) = _rec_renaming!(eq.rhs, naming_scheme)
5554

5655
function _rec_renaming!(rhs, naming_scheme)
5756
rhs isa Variable && rhs.diff != nothing && return lower_varname(rhs, naming_scheme)
@@ -66,7 +65,7 @@ end
6665

6766
function extract_symbol_order(eq)
6867
# We assume that the differential with the highest order is always going to be in the LHS
69-
dv = eq.args[1]
68+
dv = eq.lhs
7069
sym = dv.name
7170
order = dv.diff.order
7271
sym, order

src/systems/nonlinear/nonlinear_system.jl

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
struct NonlinearSystem <: AbstractSystem
2-
eqs::Vector{Operation}
2+
eqs::Vector{Equation}
33
vs::Vector{Variable}
44
ps::Vector{Variable}
55
v_name::Vector{Symbol}
@@ -26,32 +26,25 @@ end
2626
function generate_nlsys_function(sys::NonlinearSystem)
2727
var_exprs = [:($(sys.vs[i].name) = u[$i]) for i in 1:length(sys.vs)]
2828
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in 1:length(sys.ps)]
29-
sys_idxs = map(eq->isequal(eq.args[1],Constant(0)),sys.eqs)
30-
sys_eqs = sys.eqs[sys_idxs]
31-
calc_eqs = sys.eqs[.!(sys_idxs)]
32-
calc_exprs = [:($(Symbol("$(eq.args[1].name)")) = $(eq.args[2])) for eq in calc_eqs]
33-
sys_exprs = [:($(Symbol("resid[$i]")) = $(sys_eqs[i].args[2])) for i in eachindex(sys_eqs)]
29+
sys_eqs, calc_eqs = partition(eq -> isequal(eq.lhs, Constant(0)), sys.eqs)
30+
calc_exprs = [:($(eq.lhs.name) = $(eq.rhs)) for eq in calc_eqs if isa(eq.lhs, Variable)]
31+
sys_exprs = [:($(Symbol("resid[$i]")) = $(sys_eqs[i].rhs)) for i in eachindex(sys_eqs)]
3432

3533
exprs = vcat(var_exprs,param_exprs,calc_exprs,sys_exprs)
3634
block = expr_arr_to_block(exprs)
3735
:((du,u,p)->$(block))
3836
end
3937

4038
function calculate_jacobian(sys::NonlinearSystem,simplify=true)
41-
sys_idxs = map(eq->isequal(eq.args[1],Constant(0)),sys.eqs)
42-
sys_eqs = sys.eqs[sys_idxs]
43-
calc_eqs = sys.eqs[.!(sys_idxs)]
44-
rhs = [eq.args[2] for eq in sys_eqs]
39+
sys_eqs, calc_eqs = partition(eq -> isequal(eq.lhs, Constant(0)), sys.eqs)
40+
rhs = [eq.rhs for eq in sys_eqs]
4541

46-
for i in 1:length(calc_eqs)
47-
find_replace!.(rhs,calc_eqs[i].args[1],calc_eqs[i].args[2])
42+
for calc_eq calc_eqs
43+
find_replace!.(rhs, calc_eq.lhs, calc_eq.rhs)
4844
end
4945

5046
sys_exprs = calculate_jacobian(rhs,sys.vs)
5147
sys_exprs = Expression[expand_derivatives(expr) for expr in sys_exprs]
52-
if simplify
53-
sys_exprs = Expression[simplify_constants(expr) for expr in sys_exprs]
54-
end
5548
sys_exprs
5649
end
5750

src/utils.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,12 @@ end
3333

3434
toexpr(ex) = MacroTools.postwalk(x -> isa(x, Expression) ? convert(Expr, x) : x, ex)
3535

36+
function partition(f, xs)
37+
idxs = map(f, xs)
38+
not_idxs = eachindex(xs) .∉ (idxs,)
39+
return (xs[idxs], xs[not_idxs])
40+
end
41+
3642
is_constant(::Constant) = true
3743
is_constant(::Any) = false
3844

src/variables.jl

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -55,44 +55,6 @@ function Base.show(io::IO, x::Variable)
5555
x.diff === nothing || print(io, ", diff = ", x.diff)
5656
end
5757

58-
extract_idv(eq) = eq.args[1].diff.x
59-
60-
function extract_elements(ops, targetmap, default = nothing)
61-
elems = Dict{Symbol, Vector{Variable}}()
62-
names = Dict{Symbol, Set{Symbol}}()
63-
if default == nothing
64-
targets = unique(collect(values(targetmap)))
65-
else
66-
targets = [unique(collect(values(targetmap))), default]
67-
end
68-
for target in targets
69-
elems[target] = Vector{Variable}()
70-
names[target] = Set{Symbol}()
71-
end
72-
for op in ops
73-
extract_elements!(op, elems, names, targetmap, default)
74-
end
75-
Tuple(elems[target] for target in targets)
76-
end
77-
# Walk the tree recursively and push variables into the right set
78-
function extract_elements!(op::AbstractOperation, elems, names, targetmap, default)
79-
for arg in op.args
80-
if arg isa Operation
81-
extract_elements!(arg, elems, names, targetmap, default)
82-
elseif arg isa Variable
83-
if default == nothing
84-
target = haskey(targetmap, arg.subtype) ? targetmap[arg.subtype] : continue
85-
else
86-
target = haskey(targetmap, arg.subtype) ? targetmap[arg.subtype] : default
87-
end
88-
if !in(arg.name, names[target])
89-
push!(names[target], arg.name)
90-
push!(elems[target], arg)
91-
end
92-
end
93-
end
94-
end
95-
9658
# Build variables more easily
9759
function _parse_vars(macroname, fun, x)
9860
ex = Expr(:block)

0 commit comments

Comments
 (0)