Skip to content

Commit 02edf04

Browse files
Merge pull request #98 from JuliaDiffEq/hg/refactor/diff
Remove order from Differential
2 parents 3a1b5a1 + 6c75a23 commit 02edf04

File tree

7 files changed

+48
-36
lines changed

7 files changed

+48
-36
lines changed

src/differentials.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
struct Differential <: Function
22
x::Expression
3-
order::Int
43
end
5-
Differential(x) = Differential(x,1)
64

7-
Base.show(io::IO, D::Differential) = print(io,"($(D.x),$(D.order))")
5+
Base.show(io::IO, D::Differential) = print(io, "(D'~", D.x, ")")
86
Base.convert(::Type{Expr}, D::Differential) = D
97

108
(D::Differential)(x::Operation) = Operation(D, Expression[x])
@@ -13,8 +11,8 @@ function (D::Differential)(x::Variable)
1311
has_dependent(x, D.x) || return Constant(0)
1412
return Operation(D, Expression[x])
1513
end
16-
(::Differential)(::Constant) = Constant(0)
17-
Base.:(==)(D1::Differential, D2::Differential) = D1.order == D2.order && D1.x == D2.x
14+
(::Differential)(::Any) = Constant(0)
15+
Base.:(==)(D1::Differential, D2::Differential) = D1.x == D2.x
1816

1917
function expand_derivatives(O::Operation)
2018
@. O.args = expand_derivatives(O.args)
@@ -56,6 +54,7 @@ function count_order(x)
5654
n, x.args[1]
5755
end
5856

57+
_repeat_apply(f, n) = n == 1 ? f : f _repeat_apply(f, n-1)
5958
function _differential_macro(x)
6059
ex = Expr(:block)
6160
lhss = Symbol[]
@@ -66,7 +65,7 @@ function _differential_macro(x)
6665
rhs = di.args[3]
6766
order, lhs = count_order(lhs)
6867
push!(lhss, lhs)
69-
expr = :($lhs = Differential($rhs, $order))
68+
expr = :($lhs = $_repeat_apply(Differential($rhs), $order))
7069
push!(ex.args, expr)
7170
end
7271
push!(ex.args, Expr(:tuple, lhss...))

src/systems/diffeqs/diffeqsystem.jl

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,28 @@ using Base: RefValue
66

77
isintermediate(eq::Equation) = !(isa(eq.lhs, Operation) && isa(eq.lhs.op, Differential))
88

9-
struct DiffEq # D(x) = t
10-
D::Differential # D
11-
var::Variable # x
12-
rhs::Expression # t
9+
function flatten_differential(O::Operation)
10+
@assert is_derivative(O) "invalid differential: $O"
11+
is_derivative(O.args[1]) || return (O.args[1], O.op.x, 1)
12+
(x, t, order) = flatten_differential(O.args[1])
13+
t == O.op.x || throw(ArgumentError("non-matching differentials on lhs: $t, $(O.op.x)"))
14+
return (x, t, order + 1)
15+
end
16+
17+
18+
struct DiffEq # dⁿx/dtⁿ = rhs
19+
x::Expression
20+
t::Variable
21+
n::Int
22+
rhs::Expression
1323
end
1424
function Base.convert(::Type{DiffEq}, eq::Equation)
1525
isintermediate(eq) && throw(ArgumentError("intermediate equation received"))
16-
return DiffEq(eq.lhs.op, eq.lhs.args[1], eq.rhs)
26+
(x, t, n) = flatten_differential(eq.lhs)
27+
return DiffEq(x, t, n, eq.rhs)
1728
end
18-
Base.:(==)(a::DiffEq, b::DiffEq) = (a.D, a.var, a.rhs) == (b.D, b.var, b.rhs)
19-
get_args(eq::DiffEq) = Expression[eq.var, eq.rhs]
29+
Base.:(==)(a::DiffEq, b::DiffEq) = (a.x, a.t, a.n, a.rhs) == (b.x, b.t, b.n, b.rhs)
30+
get_args(eq::DiffEq) = Expression[eq.x, eq.t, eq.rhs]
2031

2132
struct DiffEqSystem <: AbstractSystem
2233
eqs::Vector{DiffEq}

src/systems/diffeqs/first_order_transform.jl

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,35 @@
1-
function lower_varname(D::Differential, x; lower=false)
2-
order = lower ? D.order-1 : D.order
3-
return lower_varname(x, D.x, order)
4-
end
5-
function lower_varname(var::Variable, idv, order::Int)
6-
sym = var.name
7-
name = order == 0 ? sym : Symbol(sym, :_, string(idv.name)^order)
1+
function lower_varname(var::Variable, idv, order)
2+
order == 0 && return var
3+
name = Symbol(var.name, :_, string(idv.name)^order)
84
return Variable(name, var.known, var.dependents)
95
end
106

117
function ode_order_lowering(sys::DiffEqSystem)
128
eqs_lowered = ode_order_lowering(sys.eqs, sys.iv)
13-
DiffEqSystem(eqs_lowered, sys.iv)
9+
DiffEqSystem(eqs_lowered, sys.iv, sys.dvs, sys.ps)
1410
end
1511
function ode_order_lowering(eqs, iv)
16-
D = Differential(iv, 1)
1712
var_order = Dict{Variable,Int}()
1813
vars = Variable[]
1914
new_eqs = similar(eqs, DiffEq)
2015

2116
for (i, eq) enumerate(eqs)
22-
var, maxorder = eq.var, eq.D.order
23-
maxorder == 1 && continue # fast pass
17+
var, maxorder = eq.x, eq.n
2418
if maxorder > get(var_order, var, 0)
2519
var_order[var] = maxorder
2620
var vars || push!(vars, var)
2721
end
28-
var′ = lower_varname(eq.D, eq.var, lower = true)
22+
var′ = lower_varname(eq.x, eq.t, eq.n - 1)
2923
rhs′ = rename(eq.rhs)
30-
new_eqs[i] = DiffEq(D, var′, rhs′)
24+
new_eqs[i] = DiffEq(var′, iv, 1, rhs′)
3125
end
3226

3327
for var vars
3428
order = var_order[var]
3529
for o in (order-1):-1:1
3630
lvar = lower_varname(var, iv, o-1)
3731
rhs = lower_varname(var, iv, o)
38-
eq = DiffEq(D, lvar, rhs)
32+
eq = DiffEq(lvar, iv, 1, rhs)
3933
push!(new_eqs, eq)
4034
end
4135
end
@@ -45,7 +39,10 @@ end
4539

4640
function rename(O::Expression)
4741
isa(O, Operation) || return O
48-
isa(O.op, Differential) && return lower_varname(O.op, O.args[1])
42+
if is_derivative(O)
43+
(x, t, order) = flatten_differential(O)
44+
return lower_varname(x, t, order)
45+
end
4946
return Operation(O.op, rename.(O.args))
5047
end
5148

src/utils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ is_constant(::Any) = false
6060
is_operation(::Operation) = true
6161
is_operation(::Any) = false
6262

63+
is_derivative(O::Operation) = isa(O.op, Differential)
64+
is_derivative(::Any) = false
65+
6366
has_dependent(t::Variable) = Base.Fix2(has_dependent, t)
6467
has_dependent(x::Variable, t::Variable) =
6568
t x.dependents || any(has_dependent(t), x.dependents)

src/variables.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,7 @@ function Base.convert(::Type{Expr}, x::Variable)
3838
end
3939
Base.convert(::Type{Expr}, c::Constant) = c.value
4040

41-
function Base.show(io::IO, x::Variable)
42-
subtype = x.known ? :Parameter : :Unknown
43-
print(io, subtype, '(', repr(x.name))
44-
isempty(x.dependents) || print(io, ", ", x.dependents)
45-
print(io, ')')
46-
end
41+
Base.show(io::IO, x::Variable) = print(io, x.name)
4742

4843
# Build variables more easily
4944
function _parse_vars(macroname, fun, x)

test/derivatives.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,13 @@ dsin = D(sin(t))
1515
dcsch = D(csch(t))
1616
@test expand_derivatives(dcsch) == simplify_constants(coth(t) * csch(t) * -1)
1717

18+
@test expand_derivatives(D(-7)) == 0
19+
@test expand_derivatives(D(sin(2t))) == simplify_constants(cos(2t) * 2)
20+
@test expand_derivatives(D2(sin(t))) == simplify_constants(-sin(t))
21+
@test expand_derivatives(D2(sin(2t))) == simplify_constants(sin(2t) * -4)
22+
@test expand_derivatives(D2(t)) == 0
23+
@test expand_derivatives(D2(5)) == 0
24+
1825
# Chain rule
1926
dsinsin = D(sin(sin(t)))
2027
@test expand_derivatives(dsinsin) == cos(sin(t))*cos(t)

test/variable_parsing.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ s1 = Parameter(:s)
2727
@test convert(Expr, s) == :s
2828
@test convert(Expr, cos(t + sin(s))) == :(cos(t + sin(s)))
2929

30-
@Deriv D''~t
31-
D1 = Differential(t, 2)
30+
@Deriv D'~t
31+
D1 = Differential(t)
3232
@test D1 == D
3333
@test convert(Expr, D) == D

0 commit comments

Comments
 (0)