Skip to content

Commit 7424631

Browse files
Refactor differentials
Remove `order`, fixes higher-order derivatives.
1 parent 25d1ce6 commit 7424631

File tree

4 files changed

+37
-30
lines changed

4 files changed

+37
-30
lines changed

src/differentials.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
struct Differential <: Function
22
x::Expression
3-
order::Int
43
end
5-
Differential(x) = Differential(x,1)
64

7-
Base.show(io::IO, D::Differential) = print(io,"($(D.x),$(D.order))")
85
Base.convert(::Type{Expr}, D::Differential) = D
96

107
(D::Differential)(x::Operation) = Operation(D, Expression[x])
@@ -13,8 +10,8 @@ function (D::Differential)(x::Variable)
1310
has_dependent(x, D.x) || return Constant(0)
1411
return Operation(D, Expression[x])
1512
end
16-
(::Differential)(::Constant) = Constant(0)
17-
Base.:(==)(D1::Differential, D2::Differential) = D1.order == D2.order && D1.x == D2.x
13+
(::Differential)(::Any) = Constant(0)
14+
Base.:(==)(D1::Differential, D2::Differential) = D1.x == D2.x
1815

1916
function expand_derivatives(O::Operation)
2017
@. O.args = expand_derivatives(O.args)
@@ -56,6 +53,7 @@ function count_order(x)
5653
n, x.args[1]
5754
end
5855

56+
_repeat_apply(f, n) = n == 1 ? f : f _repeat_apply(f, n-1)
5957
function _differential_macro(x)
6058
ex = Expr(:block)
6159
lhss = Symbol[]
@@ -66,7 +64,7 @@ function _differential_macro(x)
6664
rhs = di.args[3]
6765
order, lhs = count_order(lhs)
6866
push!(lhss, lhs)
69-
expr = :($lhs = Differential($rhs, $order))
67+
expr = :($lhs = $_repeat_apply(Differential($rhs), $order))
7068
push!(ex.args, expr)
7169
end
7270
push!(ex.args, Expr(:tuple, lhss...))

src/systems/diffeqs/diffeqsystem.jl

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,29 @@ using Base: RefValue
66

77
isintermediate(eq::Equation) = !(isa(eq.lhs, Operation) && isa(eq.lhs.op, Differential))
88

9-
struct DiffEq # D(x) = t
10-
D::Differential # D
11-
var::Variable # x
12-
rhs::Expression # t
9+
function _unwrap_differenital(O)
10+
isa(O, Operation) || return (O, nothing, 0)
11+
isa(O.op, Differential) || return (O, nothing, 0)
12+
(x, t, order) = _unwrap_differenital(O.args[1])
13+
t === nothing && (t = O.op.x)
14+
t == O.op.x || throw(ArgumentError("non-matching differentials on lhs"))
15+
return (x, t, order + 1)
16+
end
17+
18+
19+
struct DiffEq # dⁿx/dtⁿ = rhs
20+
x::Expression
21+
t::Variable
22+
n::Int
23+
rhs::Expression
1324
end
1425
function Base.convert(::Type{DiffEq}, eq::Equation)
1526
isintermediate(eq) && throw(ArgumentError("intermediate equation received"))
16-
return DiffEq(eq.lhs.op, eq.lhs.args[1], eq.rhs)
27+
(x, t, n) = _unwrap_differenital(eq.lhs)
28+
return DiffEq(x, t, n, eq.rhs)
1729
end
18-
Base.:(==)(a::DiffEq, b::DiffEq) = (a.D, a.var, a.rhs) == (b.D, b.var, b.rhs)
19-
get_args(eq::DiffEq) = Expression[eq.var, eq.rhs]
30+
Base.:(==)(a::DiffEq, b::DiffEq) = (a.x, a.t, a.n, a.rhs) == (b.x, b.t, b.n, b.rhs)
31+
get_args(eq::DiffEq) = Expression[eq.x, eq.t, eq.rhs]
2032

2133
struct DiffEqSystem <: AbstractSystem
2234
eqs::Vector{DiffEq}

src/systems/diffeqs/first_order_transform.jl

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,35 @@
1-
function lower_varname(D::Differential, x; lower=false)
2-
order = lower ? D.order-1 : D.order
3-
return lower_varname(x, D.x, order)
4-
end
5-
function lower_varname(var::Variable, idv, order::Int)
6-
sym = var.name
7-
name = order == 0 ? sym : Symbol(sym, :_, string(idv.name)^order)
1+
function lower_varname(var::Variable, idv, order = 0)
2+
order == 0 && return var
3+
name = Symbol(var.name, :_, string(idv.name)^order)
84
return Variable(name, var.known, var.dependents)
95
end
106

117
function ode_order_lowering(sys::DiffEqSystem)
128
eqs_lowered = ode_order_lowering(sys.eqs, sys.iv)
13-
DiffEqSystem(eqs_lowered, sys.iv)
9+
DiffEqSystem(eqs_lowered, sys.iv, sys.dvs, sys.ps)
1410
end
1511
function ode_order_lowering(eqs, iv)
16-
D = Differential(iv, 1)
1712
var_order = Dict{Variable,Int}()
1813
vars = Variable[]
1914
new_eqs = similar(eqs, DiffEq)
2015

2116
for (i, eq) enumerate(eqs)
22-
var, maxorder = eq.var, eq.D.order
23-
maxorder == 1 && continue # fast pass
17+
var, maxorder = eq.x, eq.n
2418
if maxorder > get(var_order, var, 0)
2519
var_order[var] = maxorder
2620
var vars || push!(vars, var)
2721
end
28-
var′ = lower_varname(eq.D, eq.var, lower = true)
22+
var′ = lower_varname(eq.x, eq.t, eq.n - 1)
2923
rhs′ = rename(eq.rhs)
30-
new_eqs[i] = DiffEq(D, var′, rhs′)
24+
new_eqs[i] = DiffEq(var′, iv, 1, rhs′)
3125
end
3226

3327
for var vars
3428
order = var_order[var]
3529
for o in (order-1):-1:1
3630
lvar = lower_varname(var, iv, o-1)
3731
rhs = lower_varname(var, iv, o)
38-
eq = DiffEq(D, lvar, rhs)
32+
eq = DiffEq(lvar, iv, 1, rhs)
3933
push!(new_eqs, eq)
4034
end
4135
end
@@ -45,7 +39,10 @@ end
4539

4640
function rename(O::Expression)
4741
isa(O, Operation) || return O
48-
isa(O.op, Differential) && return lower_varname(O.op, O.args[1])
42+
if isa(O.op, Differential)
43+
(x, t, order) = _unwrap_differenital(O)
44+
return lower_varname(x, t, order)
45+
end
4946
return Operation(O.op, rename.(O.args))
5047
end
5148

test/variable_parsing.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ s1 = Parameter(:s)
2727
@test convert(Expr, s) == :s
2828
@test convert(Expr, cos(t + sin(s))) == :(cos(t + sin(s)))
2929

30-
@Deriv D''~t
31-
D1 = Differential(t, 2)
30+
@Deriv D'~t
31+
D1 = Differential(t)
3232
@test D1 == D
3333
@test convert(Expr, D) == D

0 commit comments

Comments
 (0)