Skip to content

Commit f39ec3d

Browse files
authored
fix exponent and top_set_bit fallbacks for Integer (#59508)
fixes #53887
1 parent 7dc50b3 commit f39ec3d

File tree

4 files changed

+103
-6
lines changed

4 files changed

+103
-6
lines changed

base/gmp.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -605,10 +605,19 @@ Number of ones in the binary representation of abs(x).
605605
"""
606606
count_ones_abs(x::BigInt) = iszero(x) ? 0 : MPZ.mpn_popcount(x)
607607

608+
# all uses of _bit_magnitude MUST ensure at callsite that `x` is strictly positive, otherwise it is UB
609+
_bit_magnitude(x::BigInt) = x.size * sizeof(Limb) << 3 - leading_zeros(GC.@preserve x unsafe_load(x.d, x.size))
610+
611+
function exponent(x::BigInt)
612+
iszero(x) && throw(DomainError(x, "cannot be zero"))
613+
ux = abs(x)
614+
return _bit_magnitude(ux) - 1
615+
end
616+
608617
function top_set_bit(x::BigInt)
609618
isnegative(x) && throw(DomainError(x, "top_set_bit only supports negative arguments when they have type BitSigned."))
610619
iszero(x) && return 0
611-
x.size * sizeof(Limb) << 3 - leading_zeros(GC.@preserve x unsafe_load(x.d, x.size))
620+
return _bit_magnitude(x)
612621
end
613622

614623
divrem(x::BigInt, y::BigInt, ::typeof(RoundToZero) = RoundToZero) = MPZ.tdiv_qr(x, y)

base/intfuncs.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -772,7 +772,8 @@ function ndigits0z(x::Integer, b::Integer)
772772
end
773773

774774
# Extends the definition in base/int.jl
775-
top_set_bit(x::Integer) = ceil(Integer, log2(x + oneunit(x)))
775+
# assume x >= 0. result is implementation-defined for negative values
776+
top_set_bit(x::Integer) = iszero(x) ? 0 : exponent(x) + 1
776777

777778
"""
778779
ndigits(n::Integer; base::Integer=10, pad::Integer=1)

base/math.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -977,6 +977,32 @@ function _exponent_finite_nonzero(x::T) where T<:IEEEFloat
977977
return k - exponent_bias(T)
978978
end
979979

980+
function _ilog2_step(y::T, d::T, s) where {T<:Integer}
981+
if (y >> s) >= d
982+
y, n = _ilog2_step(y, d*d, s+s)
983+
else
984+
n = 0
985+
end
986+
if y >= d
987+
y >>= s
988+
n = Base.checked_add(n, s)
989+
end
990+
return y, n
991+
end
992+
993+
function exponent(x::Integer)
994+
iszero(x) && throw(DomainError(x, "cannot be zero"))
995+
ux = Base.uabs(x)
996+
_, n = _ilog2_step(ux, one(ux) + one(ux), 1)
997+
return n
998+
end
999+
1000+
function exponent(x::Base.BitInteger)
1001+
iszero(x) && throw(DomainError(x, "cannot be zero"))
1002+
ux = Base.uabs(x)
1003+
return 8sizeof(ux) - leading_zeros(ux) - 1
1004+
end
1005+
9801006
"""
9811007
significand(x)
9821008

test/intfuncs.jl

Lines changed: 65 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -563,7 +563,7 @@ end
563563
x::Int
564564
end
565565
MyInt(x::MyInt) = x
566-
Base.:+(a::MyInt, b::MyInt) = a.x + b.x
566+
Base.uabs(x::MyInt) = Base.uabs(x.x)
567567

568568
for n in 0:100
569569
x = ceil(Int, log2(n + 1))
@@ -579,9 +579,6 @@ end
579579
@test 32 == Base.top_set_bit(Int32(n)) == Base.top_set_bit(unsigned(Int32(n)))
580580
@test 8 == Base.top_set_bit(Int8(n)) == Base.top_set_bit(unsigned(Int8(n)))
581581
@test_throws DomainError Base.top_set_bit(big(n))
582-
# This error message should never be exposed to the end user anyway.
583-
err = n == -1 ? InexactError : DomainError
584-
@test_throws err Base.top_set_bit(MyInt(n))
585582
end
586583

587584
@test count_zeros(Int64(1)) == 63
@@ -601,6 +598,70 @@ end
601598
@test isqrt(Int8(5)) === Int8(2)
602599
end
603600

601+
@testset "exponent and top_set_bit consistency" begin
602+
for _T in (Int8, Int16, Int32, Int64, Int128)
603+
for issigned in (false, true)
604+
T = issigned ? _T : unsigned(_T)
605+
nbits = 8sizeof(T)
606+
@test_throws DomainError exponent(T(0))
607+
@test Base.top_set_bit(T(0)) == 0
608+
@test Base.top_set_bit(T(0)) == invoke(Base.top_set_bit, Tuple{Integer}, T(0))
609+
610+
for i in 0:(nbits - (issigned ? 2 : 1))
611+
p2 = T(1) << i
612+
@test exponent(p2) == i
613+
@test exponent(p2) == invoke(exponent, Tuple{Integer}, p2)
614+
@test Base.top_set_bit(p2) == i + 1
615+
@test Base.top_set_bit(p2) == invoke(Base.top_set_bit, Tuple{Integer}, p2)
616+
617+
p2m1 = p2 - T(1)
618+
if p2m1 != 0
619+
@test exponent(p2m1) == i - 1
620+
@test exponent(p2m1) == invoke(exponent, Tuple{Integer}, p2m1)
621+
@test Base.top_set_bit(p2m1) == i
622+
@test Base.top_set_bit(p2m1) == invoke(Base.top_set_bit, Tuple{Integer}, p2m1)
623+
end
624+
625+
p2p1 = p2 + T(1)
626+
if p2p1 != 0
627+
@test exponent(p2p1) == max(i, 1)
628+
@test exponent(p2p1) == invoke(exponent, Tuple{Integer}, p2p1)
629+
@test Base.top_set_bit(p2p1) == max(i, 1) + 1
630+
@test Base.top_set_bit(p2p1) == invoke(Base.top_set_bit, Tuple{Integer}, p2p1)
631+
end
632+
end
633+
634+
@test exponent(typemax(T)) == nbits - (issigned ? 2 : 1)
635+
@test exponent(typemax(T)) == invoke(exponent, Tuple{Integer}, typemax(T))
636+
expected_max = !issigned ? nbits : nbits - 1
637+
@test Base.top_set_bit(typemax(T)) == expected_max
638+
@test Base.top_set_bit(typemax(T)) == invoke(Base.top_set_bit, Tuple{Integer}, typemax(T))
639+
640+
if issigned
641+
for val in [T(-1), T(-2), T(-17), typemin(T)]
642+
expected = exponent(abs(BigInt(val)))
643+
@test exponent(val) == expected
644+
@test exponent(val) == invoke(exponent, Tuple{Integer}, val)
645+
@test Base.top_set_bit(val) == nbits
646+
@test invoke(Base.top_set_bit, Tuple{Integer}, val) == expected + 1
647+
end
648+
end
649+
end
650+
651+
@test exponent(big(2)^100) == 100
652+
@test exponent(big(2)^100 - 1) == 99
653+
@test exponent(big(2)^100 + 1) == 100
654+
@test exponent(big(-1)) == 0
655+
@test_throws DomainError exponent(big(0))
656+
657+
@test Base.top_set_bit(big(0)) == 0
658+
@test Base.top_set_bit(big(2)^100) == 101
659+
@test Base.top_set_bit(big(2)^100 - 1) == 100
660+
@test Base.top_set_bit(big(2)^100 + 1) == 101
661+
@test_throws DomainError Base.top_set_bit(big(-1))
662+
end
663+
end
664+
604665
@testset "issue #4884" begin
605666
@test isqrt(9223372030926249000) == 3037000498
606667
@test isqrt(typemax(Int128)) == parse(Int128,"13043817825332782212")

0 commit comments

Comments
 (0)