Skip to content

Commit 1fae6e6

Browse files
Merge pull request #87 from JuliaDiffEq/hg/refactor/variable
Remove diff field from Variable
2 parents c7c2d88 + c711a8e commit 1fae6e6

File tree

7 files changed

+55
-72
lines changed

7 files changed

+55
-72
lines changed

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,6 @@ context-aware single variable of the IR. Its fields are described as follows:
152152
the core identifier of the `Variable` in the sense of equality.
153153
- `subtype`: the main denotation of context. Variables within systems
154154
are grouped according to their `subtype`.
155-
- `diff`: the `Differential` object representing the quantity the variable is differentiated with respect to, or `nothing`
156155
- `dependents`: the vector of variables on which the current variable
157156
is dependent. For example, `u(t,x)` has dependents `[t,x]`. Derivatives thus
158157
require this information in order to simplify down.

src/differentials.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,7 @@ Base.convert(::Type{Expr}, D::Differential) = D
1111
function (D::Differential)(x::Variable)
1212
D.x === x && return Constant(1)
1313
has_dependent(x, D.x) || return Constant(0)
14-
15-
x′ = copy(x)
16-
x′.diff = D
17-
return x′
14+
return Operation(D, Expression[x])
1815
end
1916
Base.:(==)(D1::Differential, D2::Differential) = D1.order == D2.order && D1.x == D2.x
2017

src/equations.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ Base.:~(lhs::Expression, rhs::Number ) = Equation(lhs, rhs)
1212
Base.:~(lhs::Number , rhs::Expression) = Equation(lhs, rhs)
1313

1414

15-
_is_derivative(x::Variable) = x.diff !== nothing
1615
_is_dependent(x::Variable) = x.subtype === :Unknown && !isempty(x.dependents)
1716
_is_parameter(ivs) = x -> x.subtype === :Parameter && x ivs
1817
_subtype(subtype::Symbol) = x -> x.subtype === subtype

src/systems/diffeqs/diffeqsystem.jl

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,24 +4,31 @@ mutable struct DiffEqSystem <: AbstractSystem
44
dvs::Vector{Variable}
55
ps::Vector{Variable}
66
jac::Matrix{Expression}
7+
function DiffEqSystem(eqs, ivs, dvs, ps, jac)
8+
all(!isintermediate, eqs) ||
9+
throw(ArgumentError("no intermediate equations permitted in DiffEqSystem"))
10+
11+
new(eqs, ivs, dvs, ps, jac)
12+
end
713
end
814

915
DiffEqSystem(eqs, ivs, dvs, ps) = DiffEqSystem(eqs, ivs, dvs, ps, Matrix{Expression}(undef,0,0))
1016

1117
function DiffEqSystem(eqs)
12-
predicates = [_is_derivative, _is_dependent]
13-
_, dvs = extract_elements(eqs, predicates)
18+
dvs, = extract_elements(eqs, [_is_dependent])
1419
ivs = unique(vcat((dv.dependents for dv dvs)...))
1520
ps, = extract_elements(eqs, [_is_parameter(ivs)])
1621
DiffEqSystem(eqs, ivs, dvs, ps, Matrix{Expression}(undef,0,0))
1722
end
1823

1924
function DiffEqSystem(eqs, ivs)
20-
predicates = [_is_derivative, _is_dependent, _is_parameter(ivs)]
21-
_, dvs, ps = extract_elements(eqs, predicates)
25+
dvs, ps = extract_elements(eqs, [_is_dependent, _is_parameter(ivs)])
2226
DiffEqSystem(eqs, ivs, dvs, ps, Matrix{Expression}(undef,0,0))
2327
end
2428

29+
isintermediate(eq::Equation) = !(isa(eq.lhs, Operation) && isa(eq.lhs.op, Differential))
30+
31+
2532
function generate_ode_function(sys::DiffEqSystem;version = ArrayFunction)
2633
var_exprs = [:($(sys.dvs[i].name) = u[$i]) for i in eachindex(sys.dvs)]
2734
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in eachindex(sys.ps)]
@@ -44,25 +51,15 @@ function generate_ode_function(sys::DiffEqSystem;version = ArrayFunction)
4451
end
4552
end
4653

47-
isintermediate(eq::Equation) = eq.lhs.diff === nothing
48-
4954
function build_equals_expr(eq::Equation)
50-
@assert typeof(eq.lhs) <: Variable
51-
52-
lhs = eq.lhs.name
53-
isintermediate(eq) || (lhs = Symbol(lhs, :_, "$(eq.lhs.diff.x.name)"))
55+
@assert !isintermediate(eq)
5456

57+
lhs = Symbol(eq.lhs.args[1].name, :_, eq.lhs.op.x.name)
5558
return :($lhs = $(convert(Expr, eq.rhs)))
5659
end
5760

5861
function calculate_jacobian(sys::DiffEqSystem, simplify=true)
59-
calcs, diff_exprs = partition(isintermediate, sys.eqs)
60-
rhs = [eq.rhs for eq in diff_exprs]
61-
62-
# Handle intermediate calculations by substitution
63-
for calc calcs
64-
find_replace!.(rhs, calc.lhs, calc.rhs)
65-
end
62+
rhs = [eq.rhs for eq in sys.eqs]
6663

6764
sys_exprs = calculate_jacobian(rhs, sys.dvs)
6865
sys_exprs = Expression[expand_derivatives(expr) for expr in sys_exprs]
@@ -72,7 +69,6 @@ end
7269
function generate_ode_jacobian(sys::DiffEqSystem, simplify=true)
7370
var_exprs = [:($(sys.dvs[i].name) = u[$i]) for i in eachindex(sys.dvs)]
7471
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in eachindex(sys.ps)]
75-
diff_exprs = filter(!isintermediate, sys.eqs)
7672
jac = calculate_jacobian(sys, simplify)
7773
sys.jac = jac
7874
jac_exprs = [:(J[$i,$j] = $(convert(Expr, jac[i,j]))) for i in 1:size(jac,1), j in 1:size(jac,2)]
@@ -84,7 +80,6 @@ end
8480
function generate_ode_iW(sys::DiffEqSystem, simplify=true)
8581
var_exprs = [:($(sys.dvs[i].name) = u[$i]) for i in eachindex(sys.dvs)]
8682
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in eachindex(sys.ps)]
87-
diff_exprs = filter(!isintermediate, sys.eqs)
8883
jac = sys.jac
8984

9085
gam = Parameter(:gam)
Lines changed: 33 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
1-
extract_idv(eq::Equation) = eq.lhs.diff.x
1+
extract_idv(eq::Equation) = eq.lhs.op.x
22

3-
function lower_varname(var::Variable, naming_scheme; lower=false)
4-
D = var.diff
5-
D === nothing && return var
3+
function lower_varname(O::Operation, naming_scheme; lower=false)
4+
@assert isa(O.op, Differential)
5+
6+
D, x = O.op, O.args[1]
67
order = lower ? D.order-1 : D.order
7-
lower_varname(var.name, D.x, order, var.subtype, naming_scheme)
8+
9+
lower_varname(x, D.x, order, naming_scheme)
810
end
9-
function lower_varname(sym::Symbol, idv, order::Int, subtype::Symbol, naming_scheme)
11+
function lower_varname(var::Variable, idv, order::Int, naming_scheme)
12+
sym = var.name
1013
name = order == 0 ? sym : Symbol(sym, naming_scheme, string(idv.name)^order)
11-
return Variable(name, subtype=subtype)
14+
return Variable(name, var.subtype, var.dependents)
1215
end
1316

1417
function ode_order_lowering(sys::DiffEqSystem; kwargs...)
@@ -19,31 +22,34 @@ function ode_order_lowering(sys::DiffEqSystem; kwargs...)
1922
end
2023
ode_order_lowering(eqs; naming_scheme = "_") = ode_order_lowering!(deepcopy(eqs), naming_scheme)
2124
function ode_order_lowering!(eqs, naming_scheme)
22-
ind = findfirst(x->!(isintermediate(x)), eqs)
23-
idv = extract_idv(eqs[ind])
25+
idv = extract_idv(eqs[1])
2426
D = Differential(idv, 1)
25-
sym_order = Dict{Symbol, Int}()
26-
dv_name = eqs[1].lhs.subtype
27+
var_order = Dict{Variable,Int}()
28+
vars = Variable[]
29+
dv_name = eqs[1].lhs.args[1].subtype
30+
2731
for eq in eqs
28-
isintermediate(eq) && continue
29-
sym, maxorder = extract_symbol_order(eq)
32+
var, maxorder = extract_var_order(eq)
3033
maxorder == 1 && continue # fast pass
31-
if maxorder > get(sym_order, sym, 0)
32-
sym_order[sym] = maxorder
34+
if maxorder > get(var_order, var, 0)
35+
var_order[var] = maxorder
36+
var vars || push!(vars, var)
3337
end
34-
eq = lhs_renaming!(eq, D, naming_scheme)
35-
eq = rhs_renaming!(eq, naming_scheme)
38+
lhs_renaming!(eq, D, naming_scheme)
39+
rhs_renaming!(eq, naming_scheme)
3640
end
37-
for sym in keys(sym_order)
38-
order = sym_order[sym]
41+
42+
for var vars
43+
order = var_order[var]
3944
for o in (order-1):-1:1
40-
lhs = D(lower_varname(sym, idv, o-1, dv_name, naming_scheme))
41-
rhs = lower_varname(sym, idv, o, dv_name, naming_scheme)
45+
lhs = D(lower_varname(var, idv, o-1, naming_scheme))
46+
rhs = lower_varname(var, idv, o, naming_scheme)
4247
eq = Equation(lhs, rhs)
4348
push!(eqs, eq)
4449
end
4550
end
46-
eqs
51+
52+
return eqs
4753
end
4854

4955
function lhs_renaming!(eq, D, naming_scheme)
@@ -53,7 +59,7 @@ end
5359
rhs_renaming!(eq, naming_scheme) = _rec_renaming!(eq.rhs, naming_scheme)
5460

5561
function _rec_renaming!(rhs, naming_scheme)
56-
rhs isa Variable && rhs.diff != nothing && return lower_varname(rhs, naming_scheme)
62+
isa(rhs, Operation) && isa(rhs.op, Differential) && return lower_varname(rhs, naming_scheme)
5763
if rhs isa Operation
5864
args = rhs.args
5965
for i in eachindex(args)
@@ -63,12 +69,12 @@ function _rec_renaming!(rhs, naming_scheme)
6369
rhs
6470
end
6571

66-
function extract_symbol_order(eq)
72+
function extract_var_order(eq)
6773
# We assume that the differential with the highest order is always going to be in the LHS
6874
dv = eq.lhs
69-
sym = dv.name
70-
order = dv.diff.order
71-
sym, order
75+
var = dv.args[1]
76+
order = dv.op.order
77+
return (var, order)
7278
end
7379

7480
export ode_order_lowering

src/variables.jl

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,15 @@
1-
mutable struct Variable <: Expression
1+
struct Variable <: Expression
22
name::Symbol
33
subtype::Symbol
4-
diff::Union{Function,Nothing} # FIXME
54
dependents::Vector{Variable}
65
end
76

8-
Variable(name; subtype::Symbol, dependents::Vector{Variable} = Variable[]) =
9-
Variable(name, subtype, nothing, dependents)
10-
11-
Parameter(name; kwargs...) = Variable(name; subtype=:Parameter, kwargs...)
12-
Unknown(name, ;kwargs...) = Variable(name; subtype=:Unknown, kwargs...)
7+
Parameter(name; dependents = Variable[]) = Variable(name, :Parameter, dependents)
8+
Unknown(name; dependents = Variable[]) = Variable(name, :Unknown, dependents)
139

1410
export Variable, Unknown, Parameter, Constant, @Unknown, @Param, @Const
1511

1612

17-
Base.copy(x::Variable) = Variable(x.name, x.subtype, x.diff, x.dependents)
18-
19-
2013
struct Constant <: Expression
2114
value::Number
2215
end
@@ -28,7 +21,7 @@ Base.isone(ex::Expression) = isa(ex, Constant) && isone(ex.value)
2821

2922

3023
# Variables use isequal for equality since == is an Operation
31-
Base.:(==)(x::Variable, y::Variable) = (x.name, x.subtype, x.diff) == (y.name, y.subtype, y.diff)
24+
Base.:(==)(x::Variable, y::Variable) = (x.name, x.subtype) == (y.name, y.subtype)
3225
Base.:(==)(::Variable, ::Number) = false
3326
Base.:(==)(::Number, ::Variable) = false
3427
Base.:(==)(::Variable, ::Constant) = false
@@ -37,16 +30,10 @@ Base.:(==)(c::Constant, n::Number) = c.value == n
3730
Base.:(==)(n::Number, c::Constant) = c.value == n
3831
Base.:(==)(a::Constant, b::Constant) = a.value == b.value
3932

40-
function Base.convert(::Type{Expr}, x::Variable)
41-
x.diff === nothing && return x.name
42-
return Symbol("$(x.name)_$(x.diff.x.name)")
43-
end
33+
Base.convert(::Type{Expr}, x::Variable) = x.name
4434
Base.convert(::Type{Expr}, c::Constant) = c.value
4535

46-
function Base.show(io::IO, x::Variable)
47-
print(io, x.subtype, '(', x.name, ')')
48-
x.diff === nothing || print(io, ", diff = ", x.diff)
49-
end
36+
Base.show(io::IO, x::Variable) = print(io, x.subtype, '(', x.name, ')')
5037

5138
# Build variables more easily
5239
function _parse_vars(macroname, fun, x)

test/system_construction.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ function test_eqs(eqs1, eqs2)
5757
end
5858
eq
5959
end
60-
@test_broken test_eqs(de1.eqs, lowered_eqs)
60+
@test test_eqs(de1.eqs, lowered_eqs)
6161

6262
# Internal calculations
6363
a = y - x

0 commit comments

Comments
 (0)