Skip to content

fix accuracy of logit #99

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/basicfuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,13 @@ for ``0 < x < 1``.

Its inverse is the [`logistic`](@ref) function.
"""
logit(x::Real) = log(x / (one(x) - x))
function logit(x::Real)
if 4 * x < 1
-log(inv(x) - 1)
else
2 * atanh(2*x - 1)
end
end

"""
$(SIGNATURES)
Expand Down
7 changes: 7 additions & 0 deletions test/basicfuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,13 @@ end
@test logistic(+750.0) === 1.0
@test iszero(logit(0.5))
@test logit(logistic(2)) ≈ 2.0
@testset "accuracy of `logit`" begin
for t in (Float16, Float32, Float64)
for x in range(start = t(0), stop = t(1), length = 500)
@test 2 * ulp_error(logit, x) < 3
end
end
end
end

@testset "logcosh and logabssinh" begin
Expand Down
57 changes: 57 additions & 0 deletions test/common/ULPError.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
module ULPError
export ulp_error, ulp_error_maximum
@noinline function throw_invalid()
throw(ArgumentError("invalid"))
end
function ulp_error(accurate::AbstractFloat, approximate::AbstractFloat)
# the ULP error is usually not required to great accuracy, so `Float32` should be precise enough
zero_return = Float32(0)
inf_return = Float32(Inf)
let accur_is_nan = isnan(accurate), approx_is_nan = isnan(approximate)
if accur_is_nan || approx_is_nan
if accur_is_nan === approx_is_nan
return zero_return
end
return inf_return
end
end
if isinf(accurate) || iszero(accurate) # handle floating-point edge cases
if isinf(accurate)
if isinf(approximate) && (signbit(accurate) == signbit(approximate))
return zero_return
end
return inf_return
end
# `iszero(accurate)`
if iszero(approximate)
return zero_return
end
return inf_return
end
# assuming `precision(BigFloat)` is great enough
acc = if accurate isa BigFloat
accurate
else
BigFloat(accurate)::BigFloat
end
err = abs(Float32((approximate - acc) / eps(approximate))::Float32)
if isnan(err)
@noinline throw_invalid() # unexpected
end
err
end
function ulp_error(accurate::Acc, approximate::App, x::AbstractFloat) where {Acc, App}
acc = accurate(x)
app = approximate(x)
ulp_error(acc, app)
end
function ulp_error(func::Func, x::AbstractFloat) where {Func}
ulp_error(func ∘ BigFloat, func, x)
end
function ulp_error_maximum(func::Func, iterator) where {Func}
function f(x::AbstractFloat)
ulp_error(func, x)
end
maximum(f, iterator)
end
end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ using OffsetArrays
using Random
using Test

include("common/ULPError.jl")
using .ULPError

Random.seed!(1234)

include("basicfuns.jl")
Expand Down