Skip to content

Commit ab45092

Browse files
Merge pull request #103 from JuliaDiffEq/hg/fix/cleanup
Minor cleanup
2 parents 02edf04 + 6379485 commit ab45092

15 files changed

+159
-149
lines changed

README.md

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,16 @@ to manipulate.
2121
### Example: ODE
2222

2323
Let's build an ODE. First we define some variables. In a differential equation
24-
system, we need to differentiate between our unknown (dependent) variables
24+
system, we need to differentiate between our (dependent) variables
2525
and parameters. Therefore we label them as follows:
2626

2727
```julia
2828
using ModelingToolkit
2929

3030
# Define some variables
31-
@Param t σ ρ β
32-
@Unknown x(t) y(t) z(t)
33-
@Deriv D'~t
31+
@parameters t σ ρ β
32+
@variables x(t) y(t) z(t)
33+
@derivatives D'~t
3434
```
3535

3636
Then we build the system:
@@ -78,11 +78,11 @@ f = ODEFunction(de)
7878
7979
We can also build nonlinear systems. Let's say we wanted to solve for the steady
8080
state of the previous ODE. This is the nonlinear system defined by where the
81-
derivatives are zero. We use unknown variables for our nonlinear system.
81+
derivatives are zero. We use (unknown) variables for our nonlinear system.
8282
8383
```julia
84-
@Unknown x y z
85-
@Param σ ρ β
84+
@variables x y z
85+
@parameters σ ρ β
8686

8787
# Define a nonlinear system
8888
eqs = [0 ~ σ*(y-x),
@@ -173,7 +173,7 @@ structure is as follows:
173173
the system of equations.
174174
- Name to subtype mappings: these describe how variable `subtype`s are mapped
175175
to the contexts of the system. For example, for a differential equation,
176-
the unknown variable corresponds to given subtypes and then the `eqs` can
176+
the variable corresponds to given subtypes and then the `eqs` can
177177
be analyzed knowing what the state variables are.
178178
- Variable names which do not fall into one of the system's core subtypes are
179179
treated as intermediates which can be used for holding subcalculations and
@@ -223,7 +223,7 @@ function via the dispatch:
223223
224224
```julia
225225
# `N` arguments are accepted by the relevant method of `my_function`
226-
ModelingToolkit.Derivative(::typeof(my_function), args::NTuple{N,Any}, ::Val{i})
226+
ModelingToolkit.derivative(::typeof(my_function), args::NTuple{N,Any}, ::Val{i})
227227
```
228228
229229
where `i` means that it's the derivative of the `i`th argument. `args` is the
@@ -233,7 +233,7 @@ You should return an `Operation` for the derivative of your function.
233233
For example, `sin(t)`'s derivative (by `t`) is given by the following:
234234
235235
```julia
236-
ModelingToolkit.Derivative(::typeof(sin), args::NTuple{1,Any}, ::Val{1}) = cos(args[1])
236+
ModelingToolkit.derivative(::typeof(sin), args::NTuple{1,Any}, ::Val{1}) = cos(args[1])
237237
```
238238
239239
### Macro-free Usage
@@ -243,31 +243,31 @@ is accessible via a function-based interface. This means that all macros are
243243
syntactic sugar in some form. For example, the variable construction:
244244
245245
```julia
246-
@Param t σ ρ β
247-
@Unknown x(t) y(t) z(t)
248-
@Deriv D'~t
246+
@parameters t σ ρ β
247+
@variables x(t) y(t) z(t)
248+
@derivatives D'~t
249249
```
250250
251251
is syntactic sugar for:
252252
253253
```julia
254-
t = Parameter(:t)
255-
x = Unknown(:x, [t])
256-
y = Unknown(:y, [t])
257-
z = Unknown(:z, [t])
254+
t = Variable(:t; known = true)
255+
x = Variable(:x, [t])
256+
y = Variable(:y, [t])
257+
z = Variable(:z, [t])
258258
D = Differential(t)
259-
σ = Parameter()
260-
ρ = Parameter()
261-
β = Parameter()
259+
σ = Variable(; known = true)
260+
ρ = Variable(; known = true)
261+
β = Variable(; known = true)
262262
```
263263
264264
### Intermediate Calculations
265265
266266
The system building functions can handle intermediate calculations. For example,
267267
268268
```julia
269-
@Unknown x y z
270-
@Param σ ρ β
269+
@variables x y z
270+
@parameters σ ρ β
271271
a = y - x
272272
eqs = [0 ~ σ*a,
273273
0 ~ x*-z)-y,

src/ModelingToolkit.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
module ModelingToolkit
22

3+
export Operation, Expression
4+
export calculate_jacobian, generate_jacobian, generate_function
5+
export @register
6+
7+
38
using DiffEqBase
49
using StaticArrays, LinearAlgebra
510

@@ -30,9 +35,4 @@ include("function_registration.jl")
3035
include("simplify.jl")
3136
include("utils.jl")
3237

33-
export Operation, Expression, AbstractComponent
34-
export calculate_jacobian, generate_jacobian, generate_function
35-
export ArrayFunction, SArrayFunction
36-
export @register
37-
3838
end # module

src/differentials.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
export Differential, expand_derivatives, @derivatives
2+
3+
14
struct Differential <: Function
25
x::Expression
36
end
@@ -12,7 +15,7 @@ function (D::Differential)(x::Variable)
1215
return Operation(D, Expression[x])
1316
end
1417
(::Differential)(::Any) = Constant(0)
15-
Base.:(==)(D1::Differential, D2::Differential) = D1.x == D2.x
18+
Base.:(==)(D1::Differential, D2::Differential) = isequal(D1.x, D2.x)
1619

1720
function expand_derivatives(O::Operation)
1821
@. O.args = expand_derivatives(O.args)
@@ -21,21 +24,21 @@ function expand_derivatives(O::Operation)
2124
D = O.op
2225
o = O.args[1]
2326
isa(o, Operation) || return O
24-
return simplify_constants(sum(i->Derivative(o,i)*expand_derivatives(D(o.args[i])),1:length(o.args)))
27+
return simplify_constants(sum(i->derivative(o,i)*expand_derivatives(D(o.args[i])),1:length(o.args)))
2528
end
2629

2730
return O
2831
end
2932
expand_derivatives(x) = x
3033

3134
# Don't specialize on the function here
32-
Derivative(O::Operation, idx) = Derivative(O.op, (O.args...,), Val(idx))
35+
derivative(O::Operation, idx) = derivative(O.op, (O.args...,), Val(idx))
3336

3437
# Pre-defined derivatives
3538
import DiffRules, SpecialFunctions, NaNMath
3639
for (modu, fun, arity) DiffRules.diffrules()
3740
for i 1:arity
38-
@eval function Derivative(::typeof($modu.$fun), args::NTuple{$arity,Any}, ::Val{$i})
41+
@eval function derivative(::typeof($modu.$fun), args::NTuple{$arity,Any}, ::Val{$i})
3942
M, f = $(modu, fun)
4043
partials = DiffRules.diffrule(M, f, args...)
4144
dx = @static $arity == 1 ? partials : partials[$i]
@@ -60,7 +63,7 @@ function _differential_macro(x)
6063
lhss = Symbol[]
6164
x = flatten_expr!(x)
6265
for di in x
63-
@assert di isa Expr && di.args[1] == :~ "@Deriv expects a form that looks like `@Deriv D''~t E'~t`"
66+
@assert di isa Expr && di.args[1] == :~ "@derivatives expects a form that looks like `@derivatives D''~t E'~t`"
6467
lhs = di.args[2]
6568
rhs = di.args[3]
6669
order, lhs = count_order(lhs)
@@ -72,12 +75,10 @@ function _differential_macro(x)
7275
ex
7376
end
7477

75-
macro Deriv(x...)
78+
macro derivatives(x...)
7679
esc(_differential_macro(x))
7780
end
7881

7982
function calculate_jacobian(eqs,vars)
8083
Expression[Differential(vars[j])(eqs[i]) for i in 1:length(eqs), j in 1:length(vars)]
8184
end
82-
83-
export Differential, expand_derivatives, @Deriv, calculate_jacobian

src/equations.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@ struct Equation
55
lhs::Expression
66
rhs::Expression
77
end
8-
Base.:(==)(a::Equation, b::Equation) = (a.lhs, a.rhs) == (b.lhs, b.rhs)
8+
Base.:(==)(a::Equation, b::Equation) = isequal((a.lhs, a.rhs), (b.lhs, b.rhs))
99

1010
Base.:~(lhs::Expression, rhs::Expression) = Equation(lhs, rhs)
1111
Base.:~(lhs::Expression, rhs::Number ) = Equation(lhs, rhs)
1212
Base.:~(lhs::Number , rhs::Expression) = Equation(lhs, rhs)
1313

1414

1515
_is_dependent(x::Variable) = !x.known && !isempty(x.dependents)
16-
_is_parameter(iv) = x -> x.known && x iv
16+
_is_parameter(iv) = x -> x.known && !isequal(x, iv)
1717
_is_known(x::Variable) = x.known
1818
_is_unknown(x::Variable) = !x.known
1919

src/function_registration.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,13 @@ for (M, f, arity) in DiffRules.diffrules()
3232
@eval @register $sig
3333
end
3434

35-
for fun = (:<, :>, :(==), :!, :&, :|, :div)
35+
for fun [:!]
36+
basefun = Expr(:., Base, QuoteNode(fun))
37+
sig = :($basefun(x))
38+
@eval @register $sig
39+
end
40+
41+
for fun [:<, :>, :(==), :&, :|, :div]
3642
basefun = Expr(:., Base, QuoteNode(fun))
3743
sig = :($basefun(x,y))
3844
@eval @register $sig

src/operations.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@ struct Operation <: Expression
44
end
55

66
# Recursive ==
7-
function Base.:(==)(x::Operation,y::Operation)
7+
function Base.isequal(x::Operation,y::Operation)
88
x.op == y.op && length(x.args) == length(y.args) && all(isequal.(x.args,y.args))
99
end
10-
Base.:(==)(::Operation, ::Number ) = false
11-
Base.:(==)(::Number , ::Operation) = false
12-
Base.:(==)(::Operation, ::Variable ) = false
13-
Base.:(==)(::Variable , ::Operation) = false
14-
Base.:(==)(::Operation, ::Constant ) = false
15-
Base.:(==)(::Constant , ::Operation) = false
10+
Base.isequal(::Operation, ::Number ) = false
11+
Base.isequal(::Number , ::Operation) = false
12+
Base.isequal(::Operation, ::Variable ) = false
13+
Base.isequal(::Variable , ::Operation) = false
14+
Base.isequal(::Operation, ::Constant ) = false
15+
Base.isequal(::Constant , ::Operation) = false
1616

1717
Base.convert(::Type{Expr}, O::Operation) =
1818
build_expr(:call, Any[Symbol(O.op); convert.(Expr, O.args)])

src/simplify.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
export simplify_constants
2+
3+
14
function simplify_constants(O::Operation, shorten_tree)
25
while true
36
O′ = _simplify_constants(O, shorten_tree)
47
if is_operation(O′)
58
O′ = Operation(O′.op, simplify_constants.(O′.args, shorten_tree))
69
end
7-
O == O′ && return O
10+
isequal(O, O′) && return O
811
O = O′
912
end
1013
end
@@ -72,5 +75,3 @@ function _simplify_constants(O::Operation, shorten_tree)
7275
end
7376
_simplify_constants(x, shorten_tree) = x
7477
_simplify_constants(x) = _simplify_constants(x, true)
75-
76-
export simplify_constants

src/systems/diffeqs/diffeqsystem.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ function flatten_differential(O::Operation)
1010
@assert is_derivative(O) "invalid differential: $O"
1111
is_derivative(O.args[1]) || return (O.args[1], O.op.x, 1)
1212
(x, t, order) = flatten_differential(O.args[1])
13-
t == O.op.x || throw(ArgumentError("non-matching differentials on lhs: $t, $(O.op.x)"))
13+
isequal(t, O.op.x) || throw(ArgumentError("non-matching differentials on lhs: $t, $(O.op.x)"))
1414
return (x, t, order + 1)
1515
end
1616

@@ -26,7 +26,7 @@ function Base.convert(::Type{DiffEq}, eq::Equation)
2626
(x, t, n) = flatten_differential(eq.lhs)
2727
return DiffEq(x, t, n, eq.rhs)
2828
end
29-
Base.:(==)(a::DiffEq, b::DiffEq) = (a.x, a.t, a.n, a.rhs) == (b.x, b.t, b.n, b.rhs)
29+
Base.:(==)(a::DiffEq, b::DiffEq) = isequal((a.x, a.t, a.n, a.rhs), (b.x, b.t, b.n, b.rhs))
3030
get_args(eq::DiffEq) = Expression[eq.x, eq.t, eq.rhs]
3131

3232
struct DiffEqSystem <: AbstractSystem
@@ -79,7 +79,7 @@ end
7979
function generate_ode_iW(sys::DiffEqSystem, simplify=true; version::FunctionVersion = ArrayFunction)
8080
jac = calculate_jacobian(sys)
8181

82-
gam = Parameter(:gam)
82+
gam = Variable(:gam; known = true)
8383

8484
W = LinearAlgebra.I - gam*jac
8585
W = SMatrix{size(W,1),size(W,2)}(W)

src/systems/diffeqs/first_order_transform.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
export ode_order_lowering
2+
3+
14
function lower_varname(var::Variable, idv, order)
25
order == 0 && return var
36
name = Symbol(var.name, :_, string(idv.name)^order)
4-
return Variable(name, var.known, var.dependents)
7+
return Variable(name, var.dependents; known = var.known)
58
end
69

710
function ode_order_lowering(sys::DiffEqSystem)
@@ -17,7 +20,7 @@ function ode_order_lowering(eqs, iv)
1720
var, maxorder = eq.x, eq.n
1821
if maxorder > get(var_order, var, 0)
1922
var_order[var] = maxorder
20-
var vars || push!(vars, var)
23+
any(isequal(var), vars) || push!(vars, var)
2124
end
2225
var′ = lower_varname(eq.x, eq.t, eq.n - 1)
2326
rhs′ = rename(eq.rhs)
@@ -45,5 +48,3 @@ function rename(O::Expression)
4548
end
4649
return Operation(O.op, rename.(O.args))
4750
end
48-
49-
export ode_order_lowering

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,4 +65,4 @@ is_derivative(::Any) = false
6565

6666
has_dependent(t::Variable) = Base.Fix2(has_dependent, t)
6767
has_dependent(x::Variable, t::Variable) =
68-
t x.dependents || any(has_dependent(t), x.dependents)
68+
any(isequal(t), x.dependents) || any(has_dependent(t), x.dependents)

0 commit comments

Comments
 (0)