Skip to content

Commit 269409e

Browse files
Merge pull request #1400 from devmotion/dw/pow_forwarddiff
Fix ForwardDiff derivative of `NaNMath.pow`
2 parents da1af00 + e3bf321 commit 269409e

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

ext/SymbolicsForwardDiffExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ end
199199
# exponentiation #
200200
#----------------#
201201

202-
for f in (:(Base.:^), :(NaNMath.pow))
202+
for (f, log) in ((:(Base.:^), :(Base.log)), (:(NaNMath.pow), :(NaNMath.log)))
203203
@eval begin
204204
@define_binary_dual_op(
205205
$f,
@@ -212,7 +212,7 @@ for f in (:(Base.:^), :(NaNMath.pow))
212212
elseif iszero(vx) && vy > 0
213213
logval = zero(vx)
214214
else
215-
logval = expv * log(vx)
215+
logval = expv * ($log)(vx)
216216
end
217217
new_partials = _mul_partials(partials(x), partials(y), powval, logval)
218218
return Dual{Txy}(expv, new_partials)
@@ -230,7 +230,7 @@ for f in (:(Base.:^), :(NaNMath.pow))
230230
begin
231231
v = value(y)
232232
expv = ($f)(x, v)
233-
deriv = (iszero(x) && v > 0) ? zero(expv) : expv*log(x)
233+
deriv = (iszero(x) && v > 0) ? zero(expv) : expv*($log)(x)
234234
return Dual{Ty}(expv, deriv * partials(y))
235235
end,
236236
$AMBIGUOUS_TYPES

test/forwarddiff_symbolic_dual_ops.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,3 +114,9 @@ end
114114
y(x) = isequal(z, x) ? 0 : x
115115
@test ForwardDiff.derivative(y, 0) == 1 # expect ∂(x)/∂x
116116
end
117+
118+
@testset "NaNMath.pow (issue #1399)" begin
119+
@variables x
120+
@test_throws DomainError substitute(ForwardDiff.derivative(z -> x^z, 0.5), x => -1.0)
121+
@test isnan(Symbolics.value(substitute(ForwardDiff.derivative(z -> NaNMath.pow(x, z), 0.5), x => -1.0)))
122+
end

0 commit comments

Comments
 (0)