Skip to content

Commit 8ce6807

Browse files
authored
xexpx, xexpy (#35)
1 parent c8a4c28 commit 8ce6807

File tree

8 files changed

+126
-2
lines changed

8 files changed

+126
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LogExpFunctions"
22
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
33
authors = ["StatsFun.jl contributors, Tamas K. Papp <[email protected]>"]
4-
version = "0.3.6"
4+
version = "0.3.7"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

docs/src/index.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ LogExpFunctions supports [`InverseFunctions.inverse`](https://github.com/JuliaMa
1010
xlogx
1111
xlogy
1212
xlog1py
13+
xexpx
14+
xexpy
1315
logistic
1416
logit
1517
logcosh

src/LogExpFunctions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import InverseFunctions
99
import IrrationalConstants
1010
import LinearAlgebra
1111

12-
export xlogx, xlogy, xlog1py, logistic, logit, log1psq, log1pexp, log1mexp, log2mexp, logexpm1,
12+
export xlogx, xlogy, xlog1py, xexpx, xexpy, logistic, logit, log1psq, log1pexp, log1mexp, log2mexp, logexpm1,
1313
softplus, invsoftplus, log1pmx, logmxp1, logaddexp, logsubexp, logsumexp, softmax,
1414
softmax!, logcosh
1515

src/basicfuns.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,37 @@ function xlog1py(x::Number, y::Number)
4444
return iszero(x) && !isnan(y) ? zero(result) : result
4545
end
4646

47+
"""
48+
$(SIGNATURES)
49+
50+
Return `x * exp(x)` for `x > -Inf`, or zero if `x == -Inf`.
51+
52+
```jldoctest
53+
julia> xexpx(-Inf)
54+
0.0
55+
```
56+
"""
57+
function xexpx(x::Real)
58+
expx = exp(x)
59+
return iszero(expx) ? expx : x * expx
60+
end
61+
62+
"""
63+
$(SIGNATURES)
64+
65+
Return `x * exp(y)` for `y > -Inf`, or zero if `y == -Inf`.
66+
67+
```jldoctest
68+
julia> xexpy(1.0, -Inf)
69+
0.0
70+
```
71+
"""
72+
function xexpy(x::Real, y::Real)
73+
expy = exp(y)
74+
result = x * expy
75+
return iszero(expy) && !isnan(x) ? zero(result) : result
76+
end
77+
4778
# The following bounds are precomputed versions of the following abstract
4879
# function, but the implicit interface for AbstractFloat doesn't uniformly
4980
# enforce that all floating point types implement nextfloat and prevfloat.

src/chainrules.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,44 @@ ChainRulesCore.@scalar_rule(xlogx(x::Real), (1 + log(x),))
22
ChainRulesCore.@scalar_rule(xlogy(x::Real, y::Real), (log(y), x / y,))
33
ChainRulesCore.@scalar_rule(xlog1py(x::Real, y::Real), (log1p(y), x / (1 + y),))
44

5+
function ChainRulesCore.frule((_, Δx), ::typeof(xexpx), x::Real)
6+
expx = exp(x)
7+
if iszero(expx)
8+
Ω = expx
9+
ΔΩ = expx * Δx
10+
else
11+
Ω = x * expx
12+
ΔΩ = (1 + x) * expx * Δx
13+
end
14+
return Ω, ΔΩ
15+
end
16+
17+
function ChainRulesCore.rrule(::typeof(xexpx), x::Real)
18+
expx = exp(x)
19+
Ω = iszero(expx) ? expx : x * expx
20+
function xexpx_pullback(ΔΩ)
21+
Δx = iszero(expx) ? expx * ΔΩ : (1 + x) * expx * ΔΩ
22+
return (ChainRulesCore.NoTangent(), Δx)
23+
end
24+
return Ω, xexpx_pullback
25+
end
26+
27+
function ChainRulesCore.frule((_, Δx, Δy), ::typeof(xexpy), x::Real, y::Real)
28+
expy = exp(y)
29+
result = x * expy
30+
Ω = iszero(expy) && !isnan(x) ? zero(result) : result
31+
ΔΩ = expy * Δx + Ω * Δy
32+
return Ω, ΔΩ
33+
end
34+
35+
function ChainRulesCore.rrule(::typeof(xexpy), x::Real, y::Real)
36+
expy = exp(y)
37+
result = x * expy
38+
Ω = iszero(expy) && !isnan(x) ? zero(result) : result
39+
xexpy_pullback(ΔΩ) = (ChainRulesCore.NoTangent(), ΔΩ * expy, ΔΩ * Ω)
40+
return Ω, xexpy_pullback
41+
end
42+
543
ChainRulesCore.@scalar_rule(logistic(x::Real), (Ω * (1 - Ω),))
644
ChainRulesCore.@scalar_rule(logit(x::Real), (inv(x * (1 - x)),))
745
ChainRulesCore.@scalar_rule(logcosh(x::Real), tanh(x))

test/basicfuns.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,37 @@
4444
@test iszero(xlog1py(0 + im * 0, -1 + im * Inf))
4545
end
4646

47+
@testset "xexpx" begin
48+
for x in (false, 0, 0.0, 0f0, -Inf, -Inf32)
49+
@test (@inferred xexpx(x)) === zero(exp(x))
50+
end
51+
for x in (NaN16, NaN32, NaN64, Inf16, Inf32, Inf64)
52+
@test (@inferred xexpx(x)) === x
53+
end
54+
for x in (1, true, 1.0, 1f0)
55+
@test (@inferred xexpx(x)) === exp(x)
56+
end
57+
for a in (2, 2f0, 2.0), x in -a:a
58+
@test (@inferred xexpx(x)) === x * exp(x)
59+
end
60+
end
61+
62+
@testset "xexpy" begin
63+
for x in (0, 1, 1.0, 1f0, Inf, Inf32), y in (-Inf, -Inf32)
64+
@test (@inferred xexpy(x, y)) === zero(x * exp(y))
65+
end
66+
for x in (0, 1, 1.0, 1f0, Inf, Inf32, -Inf, -Inf32, NaN, NaN32), nan in (NaN, NaN32)
67+
@test (@inferred xexpy(x, nan)) === oftype(x * exp(nan), NaN)
68+
@test (@inferred xexpy(nan, x)) === oftype(nan * exp(x), NaN)
69+
end
70+
for x in (2, -2f0, 2.0), y in (1, -1f0, 1.0)
71+
@test (@inferred xexpy(x, y)) x * exp(y)
72+
end
73+
for x in (randn(), randn(Float32))
74+
@test xexpy(x, x) xexpx(x)
75+
end
76+
end
77+
4778
@testset "logistic & logit" begin
4879
@test logistic(2) 1.0 / (1.0 + exp(-2.0))
4980
@test logistic(-750.0) === 0.0

test/chainrules.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,27 @@
1313
end
1414
end
1515

16+
@testset "xexpx" begin
17+
# regular branch
18+
test_scalar(xexpx, randn())
19+
# special cases (manually since FiniteDifferences/ChainRulesTestUtils fails at -Inf)
20+
@test @inferred(frule((NoTangent(), rand()), xexpx, -Inf)) === (0.0, 0.0)
21+
Ω, back = @inferred(rrule(xexpx, -Inf))
22+
@test Ω === 0.0
23+
@test back(rand()) === (NoTangent(), 0.0)
24+
end
25+
26+
@testset "xexpy" begin
27+
# regular branch
28+
test_frule(xexpy, randn(), randn())
29+
test_rrule(xexpy, randn(), randn())
30+
# special cases (manually since FiniteDifferences/ChainRulesTestUtils fails at -Inf)
31+
@test @inferred(frule((NoTangent(), rand(), rand()), xexpy, x, -Inf)) === (0.0, 0.0)
32+
Ω, back = @inferred(rrule(xexpy, x, -Inf))
33+
@test Ω === 0.0
34+
@test back(rand()) === (NoTangent(), 0.0, 0.0)
35+
end
36+
1637
test_frule(logit, x)
1738
test_rrule(logit, x)
1839

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using LogExpFunctions
22
using ChainRulesTestUtils
3+
using ChainRulesCore
34
using ChangesOfVariables
45
using InverseFunctions
56
using OffsetArrays

0 commit comments

Comments
 (0)