Skip to content

Commit f4be6d8

Browse files
Merge branch 'master' into hg/fix/simplify
2 parents 972b46f + fc4dffc commit f4be6d8

13 files changed

+158
-164
lines changed

README.md

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@ using ModelingToolkit
3737
Then we build the system:
3838

3939
```julia
40-
eqs = [D*x ~ σ*(y-x),
41-
D*y ~ x*-z)-y,
42-
D*z ~ x*y - β*z]
40+
eqs = [D(x) ~ σ*(y-x),
41+
D(y) ~ x*-z)-y,
42+
D(z) ~ x*y - β*z]
4343
```
4444

4545
Each operation builds an `Operation` type, and thus `eqs` is an array of
@@ -163,7 +163,7 @@ context-aware single variable of the IR. Its fields are described as follows:
163163
to set units or denote a variable as being of higher precision.
164164
- `subtype`: the main denotation of context. Variables within systems
165165
are grouped according to their `subtype`.
166-
- `diff`: the operator objects attached to the variable
166+
- `diff`: the `Differential` object representing the quantity the variable is differentiated with respect to, or `nothing`
167167
- `dependents`: the vector of variables on which the current variable
168168
is dependent. For example, `u(t,x)` has dependents `[t,x]`. Derivatives thus
169169
require this information in order to simplify down.
@@ -182,19 +182,13 @@ context-aware single variable of the IR. Its fields are described as follows:
182182
### Operations
183183

184184
Operations are the basic composition of variables and puts together the pieces
185-
with a function. The operator `~` is a special operator which denotes equality
186-
between the arguments.
185+
with a function. The `~` function denotes equality between the arguments.
187186

188-
### Operators
187+
### Differentials
189188

190-
An operator is an object which modifies variables via `*`. It adds the operator
191-
to the `diff` field of the variable and changes the interpretation of the variable.
192-
The current operators are:
193-
194-
- `Differential`: a differential denotes the derivative with respect to a given
195-
variable. It can be expanded via `expand_derivatives` which symbolically
196-
differentiates expressions recursively and cancels out appropriate constant
197-
variables.
189+
A `Differential` denotes the derivative with respect to a given variable. It can
190+
be expanded via `expand_derivatives`, which symbolically differentiates
191+
expressions recursively and cancels out appropriate constant variables.
198192

199193
### Systems
200194

@@ -255,7 +249,7 @@ to better scale to larger systems. You can define derivatives for your own
255249
function via the dispatch:
256250

257251
```julia
258-
ModelingToolkit.Derivative(::typeof(my_function),args,::Type{Val{i}})
252+
ModelingToolkit.Derivative(::typeof(my_function),args,::Val{i})
259253
```
260254

261255
where `i` means that it's the derivative of the `i`th argument. `args` is the
@@ -265,7 +259,7 @@ You should return an `Operation` for the derivative of your function.
265259
For example, `sin(t)`'s derivative (by `t`) is given by the following:
266260

267261
```julia
268-
ModelingToolkit.Derivative(::typeof(sin),args,::Type{Val{1}}) = cos(args[1])
262+
ModelingToolkit.Derivative(::typeof(sin),args,::Val{1}) = cos(args[1])
269263
```
270264

271265
### Macro-free Usage

src/ModelingToolkit.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import MacroTools: splitdef, combinedef
88

99
abstract type Expression <: Number end
1010
abstract type AbstractOperation <: Expression end
11-
abstract type AbstractOperator <: Expression end
1211
abstract type AbstractComponent <: Expression end
1312
abstract type AbstractSystem end
1413
abstract type AbstractDomain end
@@ -26,14 +25,14 @@ function caclulate_jacobian end
2625
@enum FunctionVersions ArrayFunction=1 SArrayFunction=2
2726

2827
include("operations.jl")
29-
include("operators.jl")
28+
include("differentials.jl")
3029
include("systems/diffeqs/diffeqsystem.jl")
3130
include("systems/diffeqs/first_order_transform.jl")
3231
include("systems/nonlinear/nonlinear_system.jl")
3332
include("function_registration.jl")
3433
include("simplify.jl")
3534
include("utils.jl")
3635

37-
export Operation, Expression, AbstractOperator, AbstractComponent, AbstractDomain
36+
export Operation, Expression, AbstractComponent, AbstractDomain
3837
export @register
3938
end # module

src/differentials.jl

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
struct Differential <: Function
2+
x::Expression
3+
order::Int
4+
end
5+
Differential(x) = Differential(x,1)
6+
7+
Base.show(io::IO, D::Differential) = print(io,"($(D.x),$(D.order))")
8+
Base.Expr(D::Differential) = D
9+
10+
function Derivative end
11+
(D::Differential)(x::Operation) = Operation(D, Expression[x])
12+
function (D::Differential)(x::Variable)
13+
D.x === x && return Constant(1)
14+
has_dependent(x, D.x) || return Constant(0)
15+
return Variable(x,D)
16+
end
17+
Base.:(==)(D1::Differential, D2::Differential) = D1.order == D2.order && D1.x == D2.x
18+
19+
Variable(x::Variable, D::Differential) = Variable(x.name,x.value,x.value_type,
20+
x.subtype,D,x.dependents,x.description,x.flow,x.domain,
21+
x.size,x.context)
22+
23+
function expand_derivatives(O::Operation)
24+
@. O.args = expand_derivatives(O.args)
25+
26+
if O.op isa Differential
27+
D = O.op
28+
o = O.args[1]
29+
return simplify_constants(sum(i->Derivative(o,i)*expand_derivatives(D(o.args[i])),1:length(o.args)))
30+
end
31+
32+
return O
33+
end
34+
expand_derivatives(x::Variable) = x
35+
36+
# Don't specialize on the function here
37+
function Derivative(O::Operation,idx)
38+
# This calls the Derivative dispatch from the user or pre-defined code
39+
Derivative(O.op, O.args, Val(idx))
40+
end
41+
Derivative(op, args, idx) = Derivative(op, (args...,), idx)
42+
43+
# Pre-defined derivatives
44+
import DiffRules, SpecialFunctions, NaNMath
45+
for (modu, fun, arity) DiffRules.diffrules()
46+
for i 1:arity
47+
@eval function Derivative(::typeof($modu.$fun), args::NTuple{$arity,Any}, ::Val{$i})
48+
M, f = $(modu, fun)
49+
partials = DiffRules.diffrule(M, f, args...)
50+
dx = @static $arity == 1 ? partials : partials[$i]
51+
parse(Operation,dx)
52+
end
53+
end
54+
end
55+
56+
function count_order(x)
57+
@assert !(x isa Symbol) "The variable $x must have an order of differentiation that is greater or equal to 1!"
58+
n = 1
59+
while !(x.args[1] isa Symbol)
60+
n = n+1
61+
x = x.args[1]
62+
end
63+
n, x.args[1]
64+
end
65+
66+
function _differential_macro(x)
67+
ex = Expr(:block)
68+
lhss = Symbol[]
69+
x = flatten_expr!(x)
70+
for di in x
71+
@assert di isa Expr && di.args[1] == :~ "@Deriv expects a form that looks like `@Deriv D''~t E'~t`"
72+
lhs = di.args[2]
73+
rhs = di.args[3]
74+
order, lhs = count_order(lhs)
75+
push!(lhss, lhs)
76+
expr = :($lhs = Differential($rhs, $order))
77+
push!(ex.args, expr)
78+
end
79+
push!(ex.args, Expr(:tuple, lhss...))
80+
ex
81+
end
82+
83+
macro Deriv(x...)
84+
esc(_differential_macro(x))
85+
end
86+
87+
function calculate_jacobian(eqs,vars)
88+
Expression[Differential(vars[j])(eqs[i]) for i in 1:length(eqs), j in 1:length(vars)]
89+
end
90+
91+
export Differential, Derivative, expand_derivatives, @Deriv, calculate_jacobian

src/operators.jl

Lines changed: 0 additions & 111 deletions
This file was deleted.

src/systems/diffeqs/first_order_transform.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ function ode_order_lowering!(eqs, naming_scheme)
3636
for sym in keys(sym_order)
3737
order = sym_order[sym]
3838
for o in (order-1):-1:1
39-
lhs = D*lower_varname(sym, idv, o-1, dv_name, naming_scheme)
39+
lhs = D(lower_varname(sym, idv, o-1, dv_name, naming_scheme))
4040
rhs = lower_varname(sym, idv, o, dv_name, naming_scheme)
4141
eq = Operation(==, [lhs, rhs])
4242
push!(eqs, eq)
@@ -46,7 +46,7 @@ function ode_order_lowering!(eqs, naming_scheme)
4646
end
4747

4848
function lhs_renaming!(eq, D, naming_scheme)
49-
eq.args[1] = D*lower_varname(eq.args[1], naming_scheme, lower=true)
49+
eq.args[1] = D(lower_varname(eq.args[1], naming_scheme, lower=true))
5050
return eq
5151
end
5252
function rhs_renaming!(eq, naming_scheme)

src/utils.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,7 @@ is_constant(::Any) = false
2424

2525
is_operation(::Operation) = true
2626
is_operation(::Any) = false
27+
28+
has_dependent(t::Variable) = Base.Fix2(has_dependent, t)
29+
has_dependent(x::Variable, t::Variable) =
30+
t x.dependents || any(has_dependent(t), x.dependents)

src/variables.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ mutable struct Variable <: Expression
33
value
44
value_type::DataType
55
subtype::Symbol
6-
diff::Union{AbstractOperator,Nothing}
6+
diff::Union{Function,Nothing} # FIXME
77
dependents::Vector{Variable}
88
description::String
99
flow::Bool

test/basic_variables_and_operations.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ s = JumpVariable(:s,3,dependents=[t])
1717
n = NoiseVariable(:n,dependents=[t])
1818

1919
σ*(y-x)
20-
D*x
21-
D*x ~ -σ*(y-x)
22-
D*y ~ x*-z)-sin(y)
20+
D(x)
21+
D(x) ~ -σ*(y-x)
22+
D(y) ~ x*-z)-sin(y)
2323

24-
@test D*t == Constant(1)
24+
@test D(t) == Constant(1)

test/components.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ struct Lorenz <: AbstractComponent
1212
end
1313
function generate_lorenz_eqs(t,x,y,z,σ,ρ,β)
1414
D = Differential(t)
15-
[D*x ~ σ*(y-x)
16-
D*y ~ x*-z)-y
17-
D*z ~ x*y - β*z]
15+
[D(x) ~ σ*(y-x)
16+
D(y) ~ x*-z)-y
17+
D(z) ~ x*y - β*z]
1818
end
1919
Lorenz(t) = Lorenz(first(@DVar(x(t))),first(@DVar(y(t))),first(@DVar(z(t))),first(@Param(σ)),first(@Param(ρ)),first(@Param(β)),generate_lorenz_eqs(t,x,y,z,σ,ρ,β))
2020

test/derivatives.jl

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,24 +6,24 @@ using Test
66
@Var x(t) y(t) z(t)
77
@Param σ ρ β
88
@Deriv D'~t
9-
dsin = D*sin(t)
9+
dsin = D(sin(t))
1010
expand_derivatives(dsin)
1111

1212
@test expand_derivatives(dsin) == cos(t)
13-
dcsch = D*csch(t)
13+
dcsch = D(csch(t))
1414
@test expand_derivatives(dcsch) == simplify_constants(Operation(coth(t)*csch(t)*-1))
1515

1616
# Chain rule
17-
dsinsin = D*sin(sin(t))
17+
dsinsin = D(sin(sin(t)))
1818
@test expand_derivatives(dsinsin) == cos(sin(t))*cos(t)
1919
# Binary
20-
dpow1 = Derivative(^,[x, y],Val{1})
21-
dpow2 = Derivative(^,[x, y],Val{2})
20+
dpow1 = Derivative(^,[x, y],Val(1))
21+
dpow2 = Derivative(^,[x, y],Val(2))
2222
@test dpow1 == y*x^(y-1)
2323
@test dpow2 == x^y*log(x)
2424

25-
d1 = D*(sin(t)*t)
26-
d2 = D*(sin(t)*cos(t))
25+
d1 = D(sin(t)*t)
26+
d2 = D(sin(t)*cos(t))
2727
@test expand_derivatives(d1) == t*cos(t)+sin(t)
2828
@test expand_derivatives(d2) == simplify_constants(cos(t)*cos(t)+sin(t)*(-1*sin(t)))
2929

@@ -41,3 +41,7 @@ jac = ModelingToolkit.calculate_jacobian(sys)
4141
@test jac[3,1] == y
4242
@test jac[3,2] == x
4343
@test jac[3,3] == -1*β
44+
45+
# Variable dependence checking in differentiation
46+
@Var a(t) b(a)
47+
@test D(b) Constant(0)

0 commit comments

Comments
 (0)