Skip to content

Commit c90f261

Browse files
simsuracedevmotion
andauthored
Add iszero(x) branches to xlogy and xlog1py (#57)
* Add `iszero(x)` branches to `xlogy` and `xlog1py` * Import FiniteDifferences for tests * Remove additional NaN branch * Fix tests * White space Co-authored-by: David Widmann <[email protected]> * White space Co-authored-by: David Widmann <[email protected]> * Simplify expression Co-authored-by: David Widmann <[email protected]> * Improve formatting Co-authored-by: David Widmann <[email protected]> * Adjust and fix tests * Correct mistake * Revert `frule` change, update test accordingly Co-authored-by: David Widmann <[email protected]>
1 parent 153b54e commit c90f261

File tree

4 files changed

+26
-5
lines changed

4 files changed

+26
-5
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@ julia = "1"
2121

2222
[extras]
2323
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
24+
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
2425
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
2526
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2627
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2728

2829
[targets]
29-
test = ["ChainRulesTestUtils", "OffsetArrays", "Random", "Test"]
30+
test = ["ChainRulesTestUtils", "FiniteDifferences", "OffsetArrays", "Random", "Test"]

src/chainrules.jl

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,15 @@ end
1919
function _Ω_∂_xlogy(x::Real, y::Real)
2020
logy = log(y)
2121
z = x * logy
22-
Ω = iszero(x) && !isnan(y) ? zero(z) : z
22+
w = x / y
23+
if iszero(x) && !isnan(y)
24+
Ω = zero(z)
25+
∂y = zero(w)
26+
else
27+
Ω = z
28+
∂y = w
29+
end
2330
∂x = logy
24-
∂y = x / y
2531
return Ω, ∂x, ∂y
2632
end
2733
function ChainRulesCore.frule((_, Δx, Δy), ::typeof(xlogy), x::Real, y::Real)
@@ -38,9 +44,15 @@ end
3844
function _Ω_∂_xlog1py(x::Real, y::Real)
3945
log1py = log1p(y)
4046
z = x * log1py
41-
Ω = iszero(x) && !isnan(y) ? zero(z) : z
47+
w = x / (1 + y)
48+
if iszero(x) && !isnan(y)
49+
Ω = zero(z)
50+
∂y = zero(w)
51+
else
52+
Ω = z
53+
∂y = w
54+
end
4255
∂x = log1py
43-
∂y = x / (1 + y)
4456
return Ω, ∂x, ∂y
4557
end
4658
function ChainRulesCore.frule((_, Δx, Δy), ::typeof(xlog1py), x::Real, y::Real)

test/chainrules.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@
22
x = rand()
33
test_frule(xlogx, x)
44
test_rrule(xlogx, x)
5+
6+
# Test `iszero(x)` branches
7+
test_frule(xlogy, 0.0, 1.0; fdm = forward_fdm(5, 1), nans = true)
8+
test_rrule(xlogy, 0.0, 1.0; fdm = forward_fdm(5, 1), nans = true)
9+
@test iszero(last(frule((NoTangent(), ZeroTangent(), 1.), xlog1py, 0.0, -1.0)))
10+
@test iszero(last(last(rrule(xlog1py, 0.0, -1.0))(1.)))
11+
512
for x in (-x, 0.0, x)
613
y = rand()
714
test_frule(xlogy, x, y)

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ using LogExpFunctions
22
using ChainRulesTestUtils
33
using ChainRulesCore
44
using ChangesOfVariables
5+
using FiniteDifferences
56
using InverseFunctions
67
using OffsetArrays
78

0 commit comments

Comments
 (0)