Skip to content

Commit 93b0cc1

Browse files
authored
Add xlog1py (#27)
* Add `xlog1py` * Bump version
1 parent a6158f9 commit 93b0cc1

File tree

7 files changed

+41
-3
lines changed

7 files changed

+41
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LogExpFunctions"
22
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
33
authors = ["StatsFun.jl contributors, Tamas K. Papp <[email protected]>"]
4-
version = "0.3.1"
4+
version = "0.3.2"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

docs/src/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ The original authors of these functions are the StatsFuns.jl contributors.
77
```@docs
88
xlogx
99
xlogy
10+
xlog1py
1011
logistic
1112
logit
1213
log1psq

src/LogExpFunctions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import ChainRulesCore
77
import IrrationalConstants
88
import LinearAlgebra
99

10-
export xlogx, xlogy, logistic, logit, log1psq, log1pexp, log1mexp, log2mexp, logexpm1,
10+
export xlogx, xlogy, xlog1py, logistic, logit, log1psq, log1pexp, log1mexp, log2mexp, logexpm1,
1111
softplus, invsoftplus, log1pmx, logmxp1, logaddexp, logsubexp, logsumexp, softmax,
1212
softmax!
1313

src/basicfuns.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,21 @@ function xlogy(x::Number, y::Number)
2929
return iszero(x) && !isnan(y) ? zero(result) : result
3030
end
3131

32+
"""
33+
$(SIGNATURES)
34+
35+
Return `x * log(1 + y)` for `y ≥ -1` with correct limit at ``x = 0``.
36+
37+
```jldoctest
38+
julia> xlog1py(0, -1)
39+
0.0
40+
```
41+
"""
42+
function xlog1py(x::Number, y::Number)
43+
result = x * log1p(y)
44+
return iszero(x) && !isnan(y) ? zero(result) : result
45+
end
46+
3247
# The following bounds are precomputed versions of the following abstract
3348
# function, but the implicit interface for AbstractFloat doesn't uniformly
3449
# enforce that all floating point types implement nextfloat and prevfloat.

src/chainrules.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
ChainRulesCore.@scalar_rule(xlogx(x::Real), (1 + log(x),))
22
ChainRulesCore.@scalar_rule(xlogy(x::Real, y::Real), (log(y), x / y,))
3+
ChainRulesCore.@scalar_rule(xlog1py(x::Real, y::Real), (log1p(y), x / (1 + y),))
34

45
ChainRulesCore.@scalar_rule(logistic(x::Real), (Ω * (1 - Ω),))
56
ChainRulesCore.@scalar_rule(logit(x::Real), (inv(x * (1 - x)),))

test/basicfuns.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
@testset "xlogx & xlogy" begin
1+
@testset "xlogx, xlogy, and xlog1py" begin
22
@test iszero(xlogx(0))
33
@test xlogx(2) 2.0 * log(2.0)
44
@test_throws DomainError xlogx(-1)
@@ -11,6 +11,13 @@
1111
@test isnan(xlogy(2, NaN))
1212
@test isnan(xlogy(0, NaN))
1313

14+
@test iszero(xlog1py(0, 0))
15+
@test xlog1py(2, 3) 2.0 * log1p(3.0)
16+
@test_throws DomainError xlog1py(1, -2)
17+
@test isnan(xlog1py(NaN, 2))
18+
@test isnan(xlog1py(2, NaN))
19+
@test isnan(xlog1py(0, NaN))
20+
1421
# Since we allow complex/negative values, test for them. See comments in:
1522
# https://github.com/JuliaStats/StatsFuns.jl/pull/95
1623

@@ -26,6 +33,15 @@
2633
@test isnan(xlogy(Inf + im * NaN, 1))
2734
@test isnan(xlogy(0 + im * 0, NaN))
2835
@test iszero(xlogy(0 + im * 0, 0 + im * Inf))
36+
37+
@test xlog1py(-2, 3) == -xlog1py(2, 3)
38+
@test xlog1py(1 + im, 3) == (1 + im) * log1p(3)
39+
@test xlog1py(1 + im, 2 + im) == (1 + im) * log1p(2 + im)
40+
@test isnan(xlog1py(1 + NaN * im, -1 + im))
41+
@test isnan(xlog1py(0, -1 + NaN * im))
42+
@test isnan(xlog1py(Inf + im * NaN, 1))
43+
@test isnan(xlog1py(0 + im * 0, NaN))
44+
@test iszero(xlog1py(0 + im * 0, -1 + im * Inf))
2945
end
3046

3147
@testset "logistic & logit" begin

test/chainrules.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@
66
y = rand()
77
test_frule(xlogy, x, y)
88
test_rrule(xlogy, x, y)
9+
10+
for z in (-y, y)
11+
test_frule(xlog1py, x, z)
12+
test_rrule(xlog1py, x, z)
13+
end
914
end
1015

1116
test_frule(logit, x)

0 commit comments

Comments
 (0)