diff --git a/docs/src/index.md b/docs/src/index.md index 3c0a7f3..9454e0c 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -16,6 +16,7 @@ logistic logit logcosh logabssinh +logabstanh log1psq log1pexp softplus diff --git a/ext/LogExpFunctionsChainRulesCoreExt.jl b/ext/LogExpFunctionsChainRulesCoreExt.jl index 397603a..28fe8fe 100644 --- a/ext/LogExpFunctionsChainRulesCoreExt.jl +++ b/ext/LogExpFunctionsChainRulesCoreExt.jl @@ -118,6 +118,7 @@ ChainRulesCore.@scalar_rule(logistic(x::Real), (Ω * (1 - Ω),)) ChainRulesCore.@scalar_rule(logit(x::Real), (inv(x * (1 - x)),)) ChainRulesCore.@scalar_rule(logcosh(x::Real), tanh(x)) ChainRulesCore.@scalar_rule(logabssinh(x::Real), coth(x)) +ChainRulesCore.@scalar_rule(logabstanh(x::Real), inv(cosh(x) * sinh(x))) ChainRulesCore.@scalar_rule(log1psq(x::Real), (2 * x / (1 + x^2),)) ChainRulesCore.@scalar_rule(log1pexp(x::Real), (logistic(x),)) ChainRulesCore.@scalar_rule(log1mexp(x::Real), (-exp(x - Ω),)) diff --git a/src/LogExpFunctions.jl b/src/LogExpFunctions.jl index 05287eb..cdb917b 100644 --- a/src/LogExpFunctions.jl +++ b/src/LogExpFunctions.jl @@ -8,7 +8,7 @@ import LinearAlgebra export xlogx, xlogy, xlog1py, xexpx, xexpy, logistic, logit, log1psq, log1pexp, log1mexp, log2mexp, logexpm1, softplus, invsoftplus, log1pmx, logmxp1, logaddexp, logsubexp, logsumexp, logsumexp!, softmax, - softmax!, logcosh, logabssinh, cloglog, cexpexp, + softmax!, logcosh, logabssinh, logabstanh, cloglog, cexpexp, loglogistic, logitexp, log1mlogistic, logit1mexp include("basicfuns.jl") diff --git a/src/basicfuns.jl b/src/basicfuns.jl index a9d7f2c..bd20aeb 100644 --- a/src/basicfuns.jl +++ b/src/basicfuns.jl @@ -218,6 +218,22 @@ end """ $(SIGNATURES) +Return `log(abs(tanh(x)))`, evaluated carefully. + +The implementation ensures `logabstanh(-x) = logabstanh(x)`. +""" +function logabstanh(x::Real) + a = abs(x) + if 8*a < 3 + log(tanh(a)) + else + log1p(-2/(exp(2*a)+1)) + end +end + +""" +$(SIGNATURES) + Return `log(1+x^2)` evaluated carefully for `abs(x)` very small or very large. """ log1psq(x::Real) = log1p(abs2(x)) diff --git a/test/basicfuns.jl b/test/basicfuns.jl index c16e2d1..cc5bccd 100644 --- a/test/basicfuns.jl +++ b/test/basicfuns.jl @@ -93,7 +93,7 @@ end end end -@testset "logcosh and logabssinh" begin +@testset "logcosh and logabssinh and logabstanh" begin for x in (randn(), randn(Float32)) @test @inferred(logcosh(x)) isa typeof(x) @test logcosh(x) ≈ log(cosh(x)) @@ -101,21 +101,28 @@ end @test @inferred(logabssinh(x)) isa typeof(x) @test logabssinh(x) ≈ log(abs(sinh(x))) @test logabssinh(-x) == logabssinh(x) + @test @inferred(logabstanh(x)) isa typeof(x) + @test logabstanh(x) ≈ log(abs(tanh(x))) + @test logabstanh(-x) == logabstanh(x) end # special values for x in (-Inf, Inf, -Inf32, Inf32) @test @inferred(logcosh(x)) === oftype(x, Inf) @test @inferred(logabssinh(x)) === oftype(x, Inf) + @test @inferred(logabstanh(x)) === -oftype(x, 0) end for x in (NaN, NaN32) @test @inferred(logcosh(x)) === x @test @inferred(logabssinh(x)) === x + @test @inferred(logabstanh(x)) === x end - @testset "accuracy of `logcosh`" begin + @testset "accuracy" begin for t in (Float16, Float32, Float64) - @test ulp_error_maximum(logcosh, range(start = t(-3), stop = t(3), length = 1000)) < 3 + ran = range(start = t(-3), stop = t(3), length = 1000) + @test ulp_error_maximum(logcosh, ran) < 3 + @test ulp_error_maximum(logabstanh, ran) < 3 end end end diff --git a/test/chainrules.jl b/test/chainrules.jl index f844d0b..c19ae93 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -64,6 +64,8 @@ test_rrule(logcosh, x) test_frule(logabssinh, x) test_rrule(logabssinh, x) + test_frule(logabstanh, x) + test_rrule(logabstanh, x) end @testset "log1pexp" begin