diff --git a/src/basicfuns.jl b/src/basicfuns.jl index 5cc0df2..dbb46f9 100644 --- a/src/basicfuns.jl +++ b/src/basicfuns.jl @@ -120,7 +120,13 @@ for ``0 < x < 1``. Its inverse is the [`logistic`](@ref) function. """ -logit(x::Real) = log(x / (one(x) - x)) +function logit(x::Real) + if 4 * x < 1 + -log(inv(x) - 1) + else + 2 * atanh(2*x - 1) + end +end """ $(SIGNATURES) diff --git a/test/basicfuns.jl b/test/basicfuns.jl index 6a84b54..af0ee66 100644 --- a/test/basicfuns.jl +++ b/test/basicfuns.jl @@ -86,6 +86,11 @@ end @test logistic(+750.0) === 1.0 @test iszero(logit(0.5)) @test logit(logistic(2)) ≈ 2.0 + @testset "accuracy of `logit`" begin + for t in (Float16, Float32, Float64) + @test 2 * ulp_error_maximum(logit, range(start = t(0), stop = t(1), length = 500)) < 3 + end + end end @testset "logcosh and logabssinh" begin diff --git a/test/common/ULPError.jl b/test/common/ULPError.jl new file mode 100644 index 0000000..ee62eca --- /dev/null +++ b/test/common/ULPError.jl @@ -0,0 +1,51 @@ +module ULPError + +export ulp_error, ulp_error_maximum + +function ulp_error(accurate::AbstractFloat, approximate::AbstractFloat) + # the ULP error is usually not required to great accuracy, so `Float32` should be precise enough + zero_return = 0f0 + inf_return = Inf32 + # handle floating-point edge cases + if !(isfinite(accurate) && isfinite(approximate)) + accur_is_nan = isnan(accurate) + approx_is_nan = isnan(approximate) + if accur_is_nan || approx_is_nan + return if accur_is_nan === approx_is_nan + zero_return + else + inf_return + end + end + if isinf(approximate) + return if isinf(accurate) && (signbit(accurate) == signbit(approximate)) + zero_return + else + inf_return + end + end + end + acc = if accurate isa Union{Float16, Float32} + # widen for better accuracy when doing so does not impact performance too much + widen(accurate) + else + accurate + end + abs(Float32((approximate - acc) / eps(approximate))::Float32) +end + +function ulp_error(accurate, approximate, x::AbstractFloat) + acc = accurate(x) + app = approximate(x) + ulp_error(acc, app) +end + +function ulp_error(func::Func, x::AbstractFloat) where {Func} + ulp_error(func ∘ BigFloat, func, x) +end + +function ulp_error_maximum(func::Func, iterator) where {Func} + maximum(Base.Fix1(ulp_error, func), iterator) +end + +end diff --git a/test/runtests.jl b/test/runtests.jl index 27b247f..5fcc6ad 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,6 +10,9 @@ using OffsetArrays using Random using Test +include("common/ULPError.jl") +using .ULPError + Random.seed!(1234) include("basicfuns.jl")