Skip to content

Commit d1a6adc

Browse files
Merge pull request #84 from JuliaDiffEq/hg/fix/82
Reduce potential for nontermination in Derivative
2 parents e0fad95 + 50faac0 commit d1a6adc

File tree

3 files changed

+6
-10
lines changed

3 files changed

+6
-10
lines changed

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,8 @@ to better scale to larger systems. You can define derivatives for your own
234234
function via the dispatch:
235235

236236
```julia
237-
ModelingToolkit.Derivative(::typeof(my_function),args,::Val{i})
237+
# `N` arguments are accepted by the relevant method of `my_function`
238+
ModelingToolkit.Derivative(::typeof(my_function), args::NTuple{N,Any}, ::Val{i})
238239
```
239240

240241
where `i` means that it's the derivative of the `i`th argument. `args` is the
@@ -244,7 +245,7 @@ You should return an `Operation` for the derivative of your function.
244245
For example, `sin(t)`'s derivative (by `t`) is given by the following:
245246

246247
```julia
247-
ModelingToolkit.Derivative(::typeof(sin),args,::Val{1}) = cos(args[1])
248+
ModelingToolkit.Derivative(::typeof(sin), args::NTuple{1,Any}, ::Val{1}) = cos(args[1])
248249
```
249250

250251
### Macro-free Usage

src/differentials.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ Differential(x) = Differential(x,1)
77
Base.show(io::IO, D::Differential) = print(io,"($(D.x),$(D.order))")
88
Base.convert(::Type{Expr}, D::Differential) = D
99

10-
function Derivative end
1110
(D::Differential)(x::Operation) = Operation(D, Expression[x])
1211
function (D::Differential)(x::Variable)
1312
D.x === x && return Constant(1)
@@ -33,11 +32,7 @@ end
3332
expand_derivatives(x) = x
3433

3534
# Don't specialize on the function here
36-
function Derivative(O::Operation,idx)
37-
# This calls the Derivative dispatch from the user or pre-defined code
38-
Derivative(O.op, O.args, Val(idx))
39-
end
40-
Derivative(op, args, idx) = Derivative(op, (args...,), idx)
35+
Derivative(O::Operation, idx) = Derivative(O.op, (O.args...,), Val(idx))
4136

4237
# Pre-defined derivatives
4338
import DiffRules, SpecialFunctions, NaNMath

test/derivatives.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ dcsch = D(csch(t))
1717
dsinsin = D(sin(sin(t)))
1818
@test expand_derivatives(dsinsin) == cos(sin(t))*cos(t)
1919
# Binary
20-
dpow1 = Derivative(^,[x, y],Val(1))
21-
dpow2 = Derivative(^,[x, y],Val(2))
20+
dpow1 = Derivative(^, (x, y), Val(1))
21+
dpow2 = Derivative(^, (x, y), Val(2))
2222
@test dpow1 == y*x^(y-1)
2323
@test dpow2 == x^y*log(x)
2424

0 commit comments

Comments
 (0)