Skip to content

Commit ce3c3e4

Browse files
authored
Merge pull request #682 from SciML/s/deriv-generic
leave D(foo(t)) untouched if foo is not a known function
2 parents 580aa6d + 87bc6cd commit ce3c3e4

File tree

3 files changed

+26
-8
lines changed

3 files changed

+26
-8
lines changed

src/differentials.jl

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,11 @@ function expand_derivatives(O::Term, simplify=true; occurances=nothing)
101101
x = if _iszero(t2)
102102
t2
103103
elseif _isone(t2)
104-
derivative(o, i)
104+
d = derivative_idx(o, i)
105+
d isa NoDeriv ? D(o) : d
105106
else
106-
t1 = derivative(o, i)
107+
t1 = derivative_idx(o, i)
108+
t1 = t1 isa NoDeriv ? D(o) : t1
107109
make_operation(*, [t1, t2])
108110
end
109111

@@ -157,7 +159,7 @@ julia> using ModelingToolkit
157159
158160
julia> @variables x y;
159161
160-
julia> ModelingToolkit.derivative(sin(x), 1)
162+
julia> ModelingToolkit.derivative_idx(sin(x), 1)
161163
cos(x())
162164
```
163165
@@ -171,15 +173,20 @@ sin(x()) * y() ^ 2
171173
julia> typeof(myop.op) # Op is multiplication function
172174
typeof(*)
173175
174-
julia> ModelingToolkit.derivative(myop, 1) # wrt. sin(x)
176+
julia> ModelingToolkit.derivative_idx(myop, 1) # wrt. sin(x)
175177
y() ^ 2
176178
177-
julia> ModelingToolkit.derivative(myop, 2) # wrt. y^2
179+
julia> ModelingToolkit.derivative_idx(myop, 2) # wrt. y^2
178180
sin(x())
179181
```
180182
"""
181-
derivative(O::Term, idx) = derivative(O.op, (O.args...,), Val(idx))
182-
derivative(O::Any, ::Any) = 0
183+
derivative_idx(O::Any, ::Any) = 0
184+
derivative_idx(O::Term, idx) = derivative(O.op, (O.args...,), Val(idx))
185+
186+
# Indicate that no derivative is defined.
187+
struct NoDeriv
188+
end
189+
derivative(f, args, v) = NoDeriv()
183190

184191
# Pre-defined derivatives
185192
import DiffRules

test/derivatives.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ z = t-2t
8080
# isequal(ModelingToolkit.derivative(Term(*, [x, y, z*ρ]), 1), y*(z*ρ))
8181
# isequal(ModelingToolkit.derivative(Term(+, [x*y, y, z]), 1), 1)
8282

83-
@test iszero(ModelingToolkit.derivative(42, x))
83+
@test iszero(expand_derivatives(D(42)))
8484
@test all(iszero, ModelingToolkit.gradient(42, [t, x, y, z]))
8585
@test all(iszero, ModelingToolkit.hessian(42, [t, x, y, z]))
8686
@test isequal(ModelingToolkit.jacobian([t, x, 42], [t, x]),

test/direct.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,3 +167,14 @@ test_worldage()
167167
@test_nowarn muladd(x, y, 0)
168168
@test promote(x, 0) == (x, identity(0))
169169
@test_nowarn [x, y, z]'
170+
171+
let
172+
@register foo(x)
173+
@variables t
174+
@derivatives D'~t
175+
176+
177+
@test isequal(expand_derivatives(D(foo(t))), D(foo(t)))
178+
@test isequal(expand_derivatives(D(sin(t) * foo(t))), cos(t) * foo(t) + sin(t) * D(foo(t)))
179+
180+
end

0 commit comments

Comments
 (0)