Skip to content

Commit 72b0fe2

Browse files
sumiya11LilithHafneroscardssmith
authored
Add mul_hi function for bit integers (#57276)
Move the `_mul_high` function from base/multinverses.jl to base/int.jl. Rename it to `mul_hi`. Addresses #53855. I chose to reuse existing implementation over using the one proposed in #53855 by @LilithHafner because their performance is similar on my PC. ```julia # existing julia> @Btime Base.MultiplicativeInverses._mul_high(x, y) setup=((x,y)=(rand(UInt128),rand(UInt128))); 1.808 ns (0 allocations: 0 bytes) # 53855 julia> @Btime mul_hi_li2(x, y) setup=((x,y)=(rand(UInt128),rand(UInt128))); 1.800 ns (0 allocations: 0 bytes) ``` --------- Co-authored-by: Lilith Orion Hafner <[email protected]> Co-authored-by: Oscar Smith <[email protected]>
1 parent 7267793 commit 72b0fe2

File tree

4 files changed

+59
-24
lines changed

4 files changed

+59
-24
lines changed

base/int.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,44 @@ inv(x::Integer) = float(one(x)) / float(x)
9696
# skip promotion for system integer types
9797
(/)(x::BitInteger, y::BitInteger) = float(x) / float(y)
9898

99+
100+
"""
101+
mul_hi(a::T, b::T) where {T<:Base.Integer}
102+
103+
Returns the higher half of the product of `a` and `b` where `T` is a fixed size integer.
104+
105+
# Examples
106+
```jldoctest
107+
julia> Base.mul_hi(12345678987654321, 123456789)
108+
82624
109+
110+
julia> (widen(12345678987654321) * 123456789) >> 64
111+
82624
112+
113+
julia> Base.mul_hi(0xff, 0xff)
114+
0xfe
115+
```
116+
"""
117+
function mul_hi(a::T, b::T) where {T<:Integer}
118+
((widen(a)*b) >>> Base.top_set_bit(-1 % T)) % T
119+
end
120+
121+
function mul_hi(a::UInt128, b::UInt128)
122+
shift = sizeof(a)*4
123+
mask = typemax(UInt128) >> shift
124+
a1, a2 = a >>> shift, a & mask
125+
b1, b2 = b >>> shift, b & mask
126+
a1b1, a1b2, a2b1, a2b2 = a1*b1, a1*b2, a2*b1, a2*b2
127+
carry = ((a1b2 & mask) + (a2b1 & mask) + (a2b2 >>> shift)) >>> shift
128+
a1b1 + (a1b2 >>> shift) + (a2b1 >>> shift) + carry
129+
end
130+
131+
function mul_hi(a::Int128, b::Int128)
132+
shift = sizeof(a)*8 - 1
133+
t1, t2 = (a >> shift) & b % UInt128, (b >> shift) & a % UInt128
134+
(mul_hi(a % UInt128, b % UInt128) - t1 - t2) % Int128
135+
end
136+
99137
"""
100138
isodd(x::Number)::Bool
101139

base/multinverses.jl

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
module MultiplicativeInverses
44

5-
import Base: div, divrem, rem, unsigned
5+
import Base: div, divrem, mul_hi, rem, unsigned
66
using Base: IndexLinear, IndexCartesian, tail
77
export multiplicativeinverse
88

@@ -134,33 +134,13 @@ struct UnsignedMultiplicativeInverse{T<:Unsigned} <: MultiplicativeInverse{T}
134134
end
135135
UnsignedMultiplicativeInverse(x::Unsigned) = UnsignedMultiplicativeInverse{typeof(x)}(x)
136136

137-
# Returns the higher half of the product a*b
138-
function _mul_high(a::T, b::T) where {T<:Union{Signed, Unsigned}}
139-
((widen(a)*b) >>> (sizeof(a)*8)) % T
140-
end
141-
142-
function _mul_high(a::UInt128, b::UInt128)
143-
shift = sizeof(a)*4
144-
mask = typemax(UInt128) >> shift
145-
a1, a2 = a >>> shift, a & mask
146-
b1, b2 = b >>> shift, b & mask
147-
a1b1, a1b2, a2b1, a2b2 = a1*b1, a1*b2, a2*b1, a2*b2
148-
carry = ((a1b2 & mask) + (a2b1 & mask) + (a2b2 >>> shift)) >>> shift
149-
a1b1 + (a1b2 >>> shift) + (a2b1 >>> shift) + carry
150-
end
151-
function _mul_high(a::Int128, b::Int128)
152-
shift = sizeof(a)*8 - 1
153-
t1, t2 = (a >> shift) & b % UInt128, (b >> shift) & a % UInt128
154-
(_mul_high(a % UInt128, b % UInt128) - t1 - t2) % Int128
155-
end
156-
157137
function div(a::T, b::SignedMultiplicativeInverse{T}) where T
158-
x = _mul_high(a, b.multiplier)
138+
x = mul_hi(a, b.multiplier)
159139
x += (a*b.addmul) % T
160140
ifelse(abs(b.divisor) == 1, a*b.divisor, (signbit(x) + (x >> b.shift)) % T)
161141
end
162142
function div(a::T, b::UnsignedMultiplicativeInverse{T}) where T
163-
x = _mul_high(a, b.multiplier)
143+
x = mul_hi(a, b.multiplier)
164144
x = ifelse(b.add, convert(T, convert(T, (convert(T, a - x) >>> 1)) + x), x)
165145
ifelse(b.divisor == 1, a, x >>> b.shift)
166146
end

base/public.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,9 @@ public
8080
isoperator,
8181
isunaryoperator,
8282

83-
# scalar math
83+
# Integer math
8484
uabs,
85+
mul_hi,
8586

8687
# C interface
8788
cconvert,

test/numbers.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2579,6 +2579,22 @@ Base.:(==)(x::TestNumber, y::TestNumber) = x.inner == y.inner
25792579
Base.abs(x::TestNumber) = TestNumber(abs(x.inner))
25802580
@test abs2(TestNumber(3+4im)) == TestNumber(25)
25812581

2582+
@testset "mul_hi" begin
2583+
n = 1000
2584+
ground_truth(x, y) = ((widen(x)*y) >> (8*sizeof(typeof(x)))) % typeof(x)
2585+
for T in [UInt8, UInt16, UInt32, UInt64, UInt128, Int8, Int16, Int32, Int64, Int128]
2586+
for trait1 in [typemin, typemax]
2587+
for trait2 in [typemin, typemax]
2588+
x, y = trait1(T), trait2(T)
2589+
@test Base.mul_hi(x, y) === ground_truth(x, y)
2590+
end
2591+
end
2592+
for (x, y) in zip(rand(T, n), rand(T, n))
2593+
@test Base.mul_hi(x, y) === ground_truth(x, y)
2594+
end
2595+
end
2596+
end
2597+
25822598
@testset "multiplicative inverses" begin
25832599
function testmi(numrange, denrange)
25842600
for d in denrange

0 commit comments

Comments
 (0)