Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ logistic
logit
logcosh
logabssinh
logabstanh
log1psq
log1pexp
softplus
Expand Down
1 change: 1 addition & 0 deletions ext/LogExpFunctionsChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ ChainRulesCore.@scalar_rule(logistic(x::Real), (Ω * (1 - Ω),))
ChainRulesCore.@scalar_rule(logit(x::Real), (inv(x * (1 - x)),))
ChainRulesCore.@scalar_rule(logcosh(x::Real), tanh(x))
ChainRulesCore.@scalar_rule(logabssinh(x::Real), coth(x))
ChainRulesCore.@scalar_rule(logabstanh(x::Real), inv(cosh(x) * sinh(x)))
ChainRulesCore.@scalar_rule(log1psq(x::Real), (2 * x / (1 + x^2),))
ChainRulesCore.@scalar_rule(log1pexp(x::Real), (logistic(x),))
ChainRulesCore.@scalar_rule(log1mexp(x::Real), (-exp(x - Ω),))
Expand Down
2 changes: 1 addition & 1 deletion src/LogExpFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import LinearAlgebra

export xlogx, xlogy, xlog1py, xexpx, xexpy, logistic, logit, log1psq, log1pexp, log1mexp, log2mexp, logexpm1,
softplus, invsoftplus, log1pmx, logmxp1, logaddexp, logsubexp, logsumexp, logsumexp!, softmax,
softmax!, logcosh, logabssinh, cloglog, cexpexp,
softmax!, logcosh, logabssinh, logabstanh, cloglog, cexpexp,
loglogistic, logitexp, log1mlogistic, logit1mexp

include("basicfuns.jl")
Expand Down
16 changes: 16 additions & 0 deletions src/basicfuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,22 @@ end
"""
$(SIGNATURES)

Return `log(abs(tanh(x)))`, evaluated carefully.

The implementation ensures `logabstanh(-x) = logabstanh(x)`.
"""
function logabstanh(x::Real)
a = abs(x)
if 8*a < 3
log(tanh(a))
else
log1p(-2/(exp(2*a)+1))
end
end

"""
$(SIGNATURES)

Return `log(1+x^2)` evaluated carefully for `abs(x)` very small or very large.
"""
log1psq(x::Real) = log1p(abs2(x))
Expand Down
13 changes: 10 additions & 3 deletions test/basicfuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,29 +93,36 @@ end
end
end

@testset "logcosh and logabssinh" begin
@testset "logcosh and logabssinh and logabstanh" begin
for x in (randn(), randn(Float32))
@test @inferred(logcosh(x)) isa typeof(x)
@test logcosh(x) ≈ log(cosh(x))
@test logcosh(-x) == logcosh(x)
@test @inferred(logabssinh(x)) isa typeof(x)
@test logabssinh(x) ≈ log(abs(sinh(x)))
@test logabssinh(-x) == logabssinh(x)
@test @inferred(logabstanh(x)) isa typeof(x)
@test logabstanh(x) ≈ log(abs(tanh(x)))
@test logabstanh(-x) == logabstanh(x)
end

# special values
for x in (-Inf, Inf, -Inf32, Inf32)
@test @inferred(logcosh(x)) === oftype(x, Inf)
@test @inferred(logabssinh(x)) === oftype(x, Inf)
@test @inferred(logabstanh(x)) === -oftype(x, 0)
end
for x in (NaN, NaN32)
@test @inferred(logcosh(x)) === x
@test @inferred(logabssinh(x)) === x
@test @inferred(logabstanh(x)) === x
end

@testset "accuracy of `logcosh`" begin
@testset "accuracy" begin
for t in (Float16, Float32, Float64)
@test ulp_error_maximum(logcosh, range(start = t(-3), stop = t(3), length = 1000)) < 3
ran = range(start = t(-3), stop = t(3), length = 1000)
@test ulp_error_maximum(logcosh, ran) < 3
@test ulp_error_maximum(logabstanh, ran) < 3
end
end
end
Expand Down
2 changes: 2 additions & 0 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@
test_rrule(logcosh, x)
test_frule(logabssinh, x)
test_rrule(logabssinh, x)
test_frule(logabstanh, x)
test_rrule(logabstanh, x)
end

@testset "log1pexp" begin
Expand Down