Skip to content

fix accuracy of logit #99

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/basicfuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,13 @@ for ``0 < x < 1``.

Its inverse is the [`logistic`](@ref) function.
"""
logit(x::Real) = log(x / (one(x) - x))
function logit(x::Real)
if 4 * x < 1
-log(inv(x) - 1)
else
2 * atanh(2*x - 1)
end
end

"""
$(SIGNATURES)
Expand Down
7 changes: 7 additions & 0 deletions test/basicfuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,13 @@ end
@test logistic(+750.0) === 1.0
@test iszero(logit(0.5))
@test logit(logistic(2)) ≈ 2.0
@testset "accuracy of `logit`" begin
for t in (Float16, Float32, Float64)
for x in range(start = t(0), stop = t(1), length = 500)
@test 2 * ulp_error(logit, x) < 3
end
end
end
end

@testset "logcosh and logabssinh" begin
Expand Down
48 changes: 48 additions & 0 deletions test/common/ULPError.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
module ULPError
export ulp_error, ulp_error_maximum
function ulp_error(accurate::AbstractFloat, approximate::AbstractFloat)
# the ULP error is usually not required to great accuracy, so `Float32` should be precise enough
zero_return = 0f0
inf_return = Inf32
# handle floating-point edge cases
if !(isfinite(accurate) && isfinite(approximate))
accur_is_nan = isnan(accurate)
approx_is_nan = isnan(approximate)
if accur_is_nan || approx_is_nan
return if accur_is_nan === approx_is_nan
zero_return
else
inf_return
end
end
if isinf(approximate)
return if isinf(accurate) && (signbit(accurate) == signbit(approximate))
zero_return
else
inf_return
end
end
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The case isfinite(approximate) and isinf(accurate) is not handled here. Generally, I think it would be clearer to just have a single flat chain of if/elseif statements to handle all cases, e.g.,

if isinan(accurate) || isnan(approximate)
    return accur_is_nan === approx_is_nan ? zero_return : inf_return
elseif isinf(approximate)
    return isinf(accurate) && (signbit(accurate) == signbit(approximate)) ? zero_return : inf_return
else
...
end

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The case isfinite(approximate) and isinf(accurate) is not handled here.

Indeed. Only the cases not handled correctly by the general formula below are handled specially here.

Generally, I think it would be clearer to just have a single flat chain of if/elseif statements to handle all cases

That's basically what the current situation is, except that the branches are wrapped into the top-level single branch (if !(isfinite(accurate) && isfinite(approximate))) to prevent adverse effect on the performance of the general case.

acc = if accurate isa Union{Float16, Float32}
# widen for better accuracy when doing so does not impact performance too much
widen(accurate)
else
accurate
end
abs(Float32((approximate - acc) / eps(approximate))::Float32)
end
function ulp_error(accurate::Acc, approximate::App, x::AbstractFloat) where {Acc, App}
acc = accurate(x)
app = approximate(x)
ulp_error(acc, app)
end
function ulp_error(func::Func, x::AbstractFloat) where {Func}
ulp_error(func ∘ BigFloat, func, x)
end
function ulp_error_maximum(func::Func, iterator) where {Func}
function f(x::AbstractFloat)
ulp_error(func, x)
end
maximum(f, iterator)
end
end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ using OffsetArrays
using Random
using Test

include("common/ULPError.jl")
using .ULPError

Random.seed!(1234)

include("basicfuns.jl")
Expand Down