Skip to content

Commit f998a02

Browse files
committed
leave D(foo(t)) untouched if foo is not a known function.
1 parent 580aa6d commit f998a02

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

src/differentials.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,11 @@ function expand_derivatives(O::Term, simplify=true; occurances=nothing)
101101
x = if _iszero(t2)
102102
t2
103103
elseif _isone(t2)
104-
derivative(o, i)
104+
d = derivative(o, i)
105+
d isa NoDeriv ? D(o) : d
105106
else
106107
t1 = derivative(o, i)
108+
t1 = t1 isa NoDeriv ? D(o) : t1
107109
make_operation(*, [t1, t2])
108110
end
109111

@@ -181,6 +183,11 @@ sin(x())
181183
derivative(O::Term, idx) = derivative(O.op, (O.args...,), Val(idx))
182184
derivative(O::Any, ::Any) = 0
183185

186+
# Indicate that no derivative is defined.
187+
struct NoDeriv
188+
end
189+
derivative(f, args::Tuple, v::Val) = NoDeriv()
190+
184191
# Pre-defined derivatives
185192
import DiffRules
186193
for (modu, fun, arity) DiffRules.diffrules()

test/direct.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,3 +167,14 @@ test_worldage()
167167
@test_nowarn muladd(x, y, 0)
168168
@test promote(x, 0) == (x, identity(0))
169169
@test_nowarn [x, y, z]'
170+
171+
let
172+
@register foo(x)
173+
@variables t
174+
@derivatives D'~t
175+
176+
177+
@test isequal(expand_derivatives(D(foo(t))), D(foo(t)))
178+
@test isequal(expand_derivatives(D(sin(t) * foo(t))), cos(t) * foo(t) + sin(t) * D(foo(t)))
179+
180+
end

0 commit comments

Comments
 (0)