Skip to content

Commit f7f45e2

Browse files
Turn Differential into a higher-order function
1 parent 7e5da57 commit f7f45e2

File tree

8 files changed

+52
-55
lines changed

8 files changed

+52
-55
lines changed

README.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@ using ModelingToolkit
3737
Then we build the system:
3838

3939
```julia
40-
eqs = [D*x ~ σ*(y-x),
41-
D*y ~ x*-z)-y,
42-
D*z ~ x*y - β*z]
40+
eqs = [D(x) ~ σ*(y-x),
41+
D(y) ~ x*-z)-y,
42+
D(z) ~ x*y - β*z]
4343
```
4444

4545
Each operation builds an `Operation` type, and thus `eqs` is an array of
@@ -255,7 +255,7 @@ to better scale to larger systems. You can define derivatives for your own
255255
function via the dispatch:
256256

257257
```julia
258-
ModelingToolkit.Derivative(::typeof(my_function),args,::Type{Val{i}})
258+
ModelingToolkit.Derivative(::typeof(my_function),args,::Val{i})
259259
```
260260

261261
where `i` means that it's the derivative of the `i`th argument. `args` is the
@@ -265,7 +265,7 @@ You should return an `Operation` for the derivative of your function.
265265
For example, `sin(t)`'s derivative (by `t`) is given by the following:
266266

267267
```julia
268-
ModelingToolkit.Derivative(::typeof(sin),args,::Type{Val{1}}) = cos(args[1])
268+
ModelingToolkit.Derivative(::typeof(sin),args,::Val{1}) = cos(args[1])
269269
```
270270

271271
### Macro-free Usage

src/operators.jl

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,11 @@ Base.show(io::IO, D::Differential) = print(io,"($(D.x),$(D.order))")
88
Base.Expr(D::Differential) = :($(Symbol("D_$(D.x.name)_$(D.order)")))
99

1010
function Derivative end
11-
Base.:*(D::Differential,x::Operation) = Operation(Derivative,Expression[x,D])
12-
function Base.:*(D::Differential,x::Variable)
13-
if D.x === x
14-
return Constant(1)
15-
elseif D.x x.dependents
16-
return Constant(0)
17-
else
18-
return Variable(x,D)
19-
end
11+
(D::Differential)(x::Operation) = Operation(Derivative,Expression[x,D])
12+
function (D::Differential)(x::Variable)
13+
D.x === x && return Constant(1)
14+
D.x x.dependents && return Constant(0) # FIXME
15+
return Variable(x,D)
2016
end
2117
Base.:(==)(D1::Differential, D2::Differential) = D1.order == D2.order && D1.x == D2.x
2218

@@ -25,22 +21,23 @@ Variable(x::Variable,D::Differential) = Variable(x.name,x.value,x.value_type,
2521
x.size,x.context)
2622

2723
function expand_derivatives(O::Operation)
24+
@. O.args = expand_derivatives(O.args)
25+
2826
if O.op == Derivative
2927
D = O.args[2]
3028
o = O.args[1]
31-
simplify_constants(sum(i->Derivative(o,i)*expand_derivatives(D*o.args[i]),1:length(o.args)))
32-
else
33-
for i in 1:length(O.args)
34-
O.args[i] = expand_derivatives(O.args[i])
35-
end
29+
return simplify_constants(sum(i->Derivative(o,i)*expand_derivatives(D(o.args[i])),1:length(o.args)))
3630
end
31+
32+
return O
3733
end
3834
expand_derivatives(x::Variable) = x
35+
expand_derivatives(D::Differential) = D
3936

4037
# Don't specialize on the function here
4138
function Derivative(O::Operation,idx)
4239
# This calls the Derivative dispatch from the user or pre-defined code
43-
Derivative(O.op,O.args,Val{idx})
40+
Derivative(O.op, O.args, Val(idx))
4441
end
4542

4643
# Pre-defined derivatives
@@ -83,7 +80,7 @@ function count_order(x)
8380
n, x.args[1]
8481
end
8582

86-
function _differetial_macro(x)
83+
function _differential_macro(x)
8784
ex = Expr(:block)
8885
lhss = Symbol[]
8986
x = flatten_expr!(x)
@@ -101,11 +98,11 @@ function _differetial_macro(x)
10198
end
10299

103100
macro Deriv(x...)
104-
esc(_differetial_macro(x))
101+
esc(_differential_macro(x))
105102
end
106103

107104
function calculate_jacobian(eqs,vars)
108-
Expression[Differential(vars[j])*eqs[i] for i in 1:length(eqs), j in 1:length(vars)]
105+
Expression[Differential(vars[j])(eqs[i]) for i in 1:length(eqs), j in 1:length(vars)]
109106
end
110107

111108
export Differential, Derivative, expand_derivatives, @Deriv, calculate_jacobian

src/systems/diffeqs/first_order_transform.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ function ode_order_lowering!(eqs, naming_scheme)
3636
for sym in keys(sym_order)
3737
order = sym_order[sym]
3838
for o in (order-1):-1:1
39-
lhs = D*lower_varname(sym, idv, o-1, dv_name, naming_scheme)
39+
lhs = D(lower_varname(sym, idv, o-1, dv_name, naming_scheme))
4040
rhs = lower_varname(sym, idv, o, dv_name, naming_scheme)
4141
eq = Operation(==, [lhs, rhs])
4242
push!(eqs, eq)
@@ -46,7 +46,7 @@ function ode_order_lowering!(eqs, naming_scheme)
4646
end
4747

4848
function lhs_renaming!(eq, D, naming_scheme)
49-
eq.args[1] = D*lower_varname(eq.args[1], naming_scheme, lower=true)
49+
eq.args[1] = D(lower_varname(eq.args[1], naming_scheme, lower=true))
5050
return eq
5151
end
5252
function rhs_renaming!(eq, naming_scheme)

test/basic_variables_and_operations.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ s = JumpVariable(:s,3,dependents=[t])
1717
n = NoiseVariable(:n,dependents=[t])
1818

1919
σ*(y-x)
20-
D*x
21-
D*x ~ -σ*(y-x)
22-
D*y ~ x*-z)-sin(y)
20+
D(x)
21+
D(x) ~ -σ*(y-x)
22+
D(y) ~ x*-z)-sin(y)
2323

24-
@test D*t == Constant(1)
24+
@test D(t) == Constant(1)

test/components.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ struct Lorenz <: AbstractComponent
1212
end
1313
function generate_lorenz_eqs(t,x,y,z,σ,ρ,β)
1414
D = Differential(t)
15-
[D*x ~ σ*(y-x)
16-
D*y ~ x*-z)-y
17-
D*z ~ x*y - β*z]
15+
[D(x) ~ σ*(y-x)
16+
D(y) ~ x*-z)-y
17+
D(z) ~ x*y - β*z]
1818
end
1919
Lorenz(t) = Lorenz(first(@DVar(x(t))),first(@DVar(y(t))),first(@DVar(z(t))),first(@Param(σ)),first(@Param(ρ)),first(@Param(β)),generate_lorenz_eqs(t,x,y,z,σ,ρ,β))
2020

test/derivatives.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,24 +6,24 @@ using Test
66
@Var x(t) y(t) z(t)
77
@Param σ ρ β
88
@Deriv D'~t
9-
dsin = D*sin(t)
9+
dsin = D(sin(t))
1010
expand_derivatives(dsin)
1111

1212
@test expand_derivatives(dsin) == cos(t)
13-
dcsch = D*csch(t)
13+
dcsch = D(csch(t))
1414
@test expand_derivatives(dcsch) == simplify_constants(Operation(coth(t)*csch(t)*-1))
1515

1616
# Chain rule
17-
dsinsin = D*sin(sin(t))
17+
dsinsin = D(sin(sin(t)))
1818
@test expand_derivatives(dsinsin) == cos(sin(t))*cos(t)
1919
# Binary
20-
dpow1 = Derivative(^,[x, y],Val{1})
21-
dpow2 = Derivative(^,[x, y],Val{2})
20+
dpow1 = Derivative(^,[x, y],Val(1))
21+
dpow2 = Derivative(^,[x, y],Val(2))
2222
@test dpow1 == y*x^(y-1)
2323
@test dpow2 == x^y*log(x)
2424

25-
d1 = D*(sin(t)*t)
26-
d2 = D*(sin(t)*cos(t))
25+
d1 = D(sin(t)*t)
26+
d2 = D(sin(t)*cos(t))
2727
@test expand_derivatives(d1) == t*cos(t)+sin(t)
2828
@test expand_derivatives(d2) == simplify_constants(cos(t)*cos(t)+sin(t)*(-1*sin(t)))
2929

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,5 @@ using ModelingToolkit, Test
77
@testset "Domain Test" begin include("domains.jl") end
88
@testset "Simplify Test" begin include("simplify.jl") end
99
@testset "Ambiguity Test" begin include("ambiguity.jl") end
10-
@testset "Componets Test" begin include("components.jl") end
10+
@testset "Components Test" begin include("components.jl") end
1111
@testset "System Construction Test" begin include("system_construction.jl") end

test/system_construction.jl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ using Test
1010
@Var a
1111

1212
# Define a differential equation
13-
eqs = [D*x ~ σ*(y-x),
14-
D*y ~ x*-z)-y,
15-
D*z ~ x*y - β*z]
13+
eqs = [D(x) ~ σ*(y-x),
14+
D(y) ~ x*-z)-y,
15+
D(z) ~ x*y - β*z]
1616
de = DiffEqSystem(eqs,[t],[x,y,z],Variable[],[σ,ρ,β])
1717
ModelingToolkit.generate_ode_function(de)
1818
ModelingToolkit.generate_ode_function(de;version=ModelingToolkit.SArrayFunction)
@@ -37,15 +37,15 @@ test_vars_extraction(de, de2)
3737
@Deriv D3'''~t
3838
@Deriv D2''~t
3939
@DVar u(t) u_tt(t) u_t(t) x_t(t)
40-
eqs = [D3*u ~ 2(D2*u) + D*u + D*x + 1
41-
D2*x ~ D*x + 2]
40+
eqs = [D3(u) ~ 2(D2(u)) + D(u) + D(x) + 1
41+
D2(x) ~ D(x) + 2]
4242
de = DiffEqSystem(eqs, [t])
4343
de1 = ode_order_lowering(de)
44-
lowered_eqs = [D*u_tt ~ 2u_tt + u_t + x_t + 1
45-
D*x_t ~ x_t + 2
46-
D*u_t ~ u_tt
47-
D*u ~ u_t
48-
D*x ~ x_t]
44+
lowered_eqs = [D(u_tt) ~ 2u_tt + u_t + x_t + 1
45+
D(x_t) ~ x_t + 2
46+
D(u_t) ~ u_tt
47+
D(u) ~ u_t
48+
D(x) ~ x_t]
4949
function test_eqs(eqs1, eqs2)
5050
eq = true
5151
for i in eachindex(eqs1)
@@ -61,9 +61,9 @@ test_eqs(de1.eqs, lowered_eqs)
6161

6262
# Internal calculations
6363
eqs = [a ~ y-x,
64-
D*x ~ σ*a,
65-
D*y ~ x*-z)-y,
66-
D*z ~ x*y - β*z]
64+
D(x) ~ σ*a,
65+
D(y) ~ x*-z)-y,
66+
D(z) ~ x*y - β*z]
6767
de = DiffEqSystem(eqs,[t],[x,y,z],[a],[σ,ρ,β])
6868
ModelingToolkit.generate_ode_function(de)
6969
jac = ModelingToolkit.calculate_jacobian(de)
@@ -87,8 +87,8 @@ ModelingToolkit.generate_nlsys_function(ns)
8787
@Deriv D'~t
8888
@Param A B C
8989
eqs = [_x ~ y/C,
90-
D*x ~ -A*x,
91-
D*y ~ A*x - B*_x]
90+
D(x) ~ -A*x,
91+
D(y) ~ A*x - B*_x]
9292
de = DiffEqSystem(eqs,[t],[x,y],Variable[_x],[A,B,C])
9393
@test eval(ModelingToolkit.generate_ode_function(de))([0.0,0.0],[1.0,2.0],[1,2,3],0.0) -1/3
9494

0 commit comments

Comments
 (0)