Skip to content

Commit afb44f5

Browse files
authored
change AbstractString and Integer hashing to use generic hashing interface (#59691)
Now that hashing has 3 interfaces (pointer (unsafe), array (indexable), iterable) in decreasing levels of typical optimization and performance, use those instead of making custom implementations for specific types. This automatically opts all AbstractString into fast hashing if they've correctly defined the `codeunit` string interface.
1 parent 5c294de commit afb44f5

File tree

5 files changed

+139
-96
lines changed

5 files changed

+139
-96
lines changed

base/gmp.jl

Lines changed: 45 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -864,21 +864,48 @@ if Limb === UInt64 === UInt
864864

865865
using .Base: HASH_SECRET, hash_bytes, hash_finalizer
866866

867+
# UnsafeLimbView provides a safe iterator interface to BigInt limb data
868+
struct UnsafeLimbView <: AbstractVector{UInt8}
869+
bigint::BigInt
870+
start_byte::Int
871+
num_bytes::Int
872+
end
873+
874+
function Base.size(view::UnsafeLimbView)
875+
return (view.num_bytes,)
876+
end
877+
878+
function Base.getindex(view::UnsafeLimbView, i::Int)
879+
@boundscheck checkbounds(view, i)
880+
GC.@preserve view begin
881+
limb_index = div(view.start_byte + i - 2, 8) + 1
882+
byte_in_limb = (view.start_byte + i - 2) % 8
883+
limb = unsafe_load(view.bigint.d, limb_index)
884+
return UInt8((limb >> (8 * byte_in_limb)) & 0xff)
885+
end
886+
end
887+
888+
function Base.iterate(view::UnsafeLimbView, state::Int = 1)
889+
state > view.num_bytes && return nothing
890+
return @inbounds(view[state]), state + 1
891+
end
892+
893+
function Base.length(view::UnsafeLimbView)
894+
return view.num_bytes
895+
end
896+
867897
function hash_integer(n::BigInt, h::UInt)
868898
iszero(n) && return hash_integer(0, h)
869-
GC.@preserve n begin
870-
s = n.size
871-
h ⊻= (s < 0)
872-
873-
us = abs(s)
874-
leading_zero_bytes = div(leading_zeros(unsafe_load(n.d, us)), 8)
875-
hash_bytes(
876-
Ptr{UInt8}(n.d),
877-
8 * us - leading_zero_bytes,
878-
h,
879-
HASH_SECRET
880-
)
881-
end
899+
s = n.size
900+
h ⊻= (s < 0)
901+
902+
us = abs(s)
903+
leading_zero_bytes = div(leading_zeros(unsafe_load(n.d, us)), 8)
904+
num_bytes = 8 * us - leading_zero_bytes
905+
906+
# Use UnsafeLimbView for safe iterator-based access
907+
limb_view = UnsafeLimbView(n, 1, num_bytes)
908+
return hash_bytes(limb_view, h, HASH_SECRET)
882909
end
883910

884911
function hash(x::BigInt, h::UInt)
@@ -913,12 +940,11 @@ if Limb === UInt64 === UInt
913940
h ⊻= (sz < 0)
914941
leading_zero_bytes = div(leading_zeros(unsafe_load(x.d, asz)), 8)
915942
trailing_zero_bytes = div(pow, 8)
916-
return hash_bytes(
917-
Ptr{UInt8}(x.d) + trailing_zero_bytes,
918-
8 * asz - (leading_zero_bytes + trailing_zero_bytes),
919-
h,
920-
HASH_SECRET
921-
)
943+
num_bytes = 8 * asz - (leading_zero_bytes + trailing_zero_bytes)
944+
945+
# Use UnsafeLimbView for safe iterator-based access
946+
limb_view = UnsafeLimbView(x, trailing_zero_bytes + 1, num_bytes)
947+
return hash_bytes(limb_view, h, HASH_SECRET)
922948
end
923949
end
924950
end

base/hashing.jl

Lines changed: 89 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -70,80 +70,100 @@ hash(x::UInt64, h::UInt) = hash_uint64(hash_mix_linear(x, h))
7070
hash(x::Int64, h::UInt) = hash(bitcast(UInt64, x), h)
7171
hash(x::Union{Bool, Int8, UInt8, Int16, UInt16, Int32, UInt32}, h::UInt) = hash(Int64(x), h)
7272

73+
# IntegerCodeUnits provides a little-endian byte representation of integers
74+
struct IntegerCodeUnits{T<:Integer} <: AbstractVector{UInt8}
75+
value::T
76+
num_bytes::Int
77+
78+
function IntegerCodeUnits(x::T) where {T<:Integer}
79+
# Calculate number of bytes needed (always pad to full byte)
80+
u = abs(x)
81+
num_bytes = max(cld(top_set_bit(u), 8), 1)
82+
return new{T}(x, num_bytes)
83+
end
84+
end
85+
86+
function Base.size(units::IntegerCodeUnits)
87+
return (units.num_bytes,)
88+
end
89+
90+
function Base.length(units::IntegerCodeUnits)
91+
return units.num_bytes
92+
end
93+
94+
function Base.getindex(units::IntegerCodeUnits, i::Int)
95+
@boundscheck checkbounds(units, i)
96+
u = abs(units.value)
97+
byte_pos = i - 1
98+
return UInt8((u >>> (8 * byte_pos)) & 0xff)
99+
end
100+
101+
function Base.iterate(units::IntegerCodeUnits, state::Int = 1)
102+
state > units.num_bytes && return nothing
103+
return units[state], state + 1
104+
end
105+
106+
# Main interface function to get little-endian byte representation of integers
107+
codeunits(x::Integer) = IntegerCodeUnits(x)
108+
109+
# UTF8Units provides UTF-8 byte iteration for any AbstractString
110+
struct UTF8Units{T<:AbstractString}
111+
string::T
112+
end
113+
114+
utf8units(s::AbstractString) = codeunit(s) <: UInt8 ? codeunits(s) : UTF8Units(s)
115+
116+
# Iterator state: (char_iter_state, remaining_utf8_bytes)
117+
function Base.iterate(units::UTF8Units)
118+
char_result = iterate(units.string)
119+
char_result === nothing && return nothing
120+
char, char_state = char_result
121+
122+
# Decode char to UTF-8 bytes (similar to the write function)
123+
u = bswap(reinterpret(UInt32, char))
124+
125+
# Return first byte and set up state for remaining bytes
126+
first_byte = u % UInt8
127+
remaining_bytes = u >> 8
128+
return first_byte, (char_state, remaining_bytes)
129+
end
130+
131+
function Base.iterate(units::UTF8Units, state)
132+
char_state, remaining_bytes = state
133+
# If we have more bytes from current char, return next byte
134+
if remaining_bytes != 0
135+
byte = remaining_bytes % UInt8
136+
new_remaining = remaining_bytes >> 8
137+
return byte, (char_state, new_remaining)
138+
end
139+
140+
# Move to next char
141+
char_result = iterate(units.string, char_state)
142+
char_result === nothing && return nothing
143+
char, new_char_state = char_result
144+
145+
# Decode new char to UTF-8 bytes
146+
u = bswap(reinterpret(UInt32, char))
147+
148+
# Return first byte and set up state for remaining bytes
149+
first_byte = u % UInt8
150+
remaining_bytes = u >> 8
151+
152+
return first_byte, (new_char_state, remaining_bytes)
153+
end
154+
73155
hash_integer(x::Integer, h::UInt) = _hash_integer(x, UInt64(h)) % UInt
74156
function _hash_integer(
75157
x::Integer,
76158
seed::UInt64,
77159
secret::NTuple{4, UInt64} = HASH_SECRET
78160
)
161+
# Handle sign by XOR-ing with seed
79162
seed ⊻= (x < 0)
80-
u0 = abs(x) # n.b.: this hashes typemin(IntN) correctly even if abs fails
81-
u = u0
82-
83-
# always left-pad to full byte
84-
buflen = UInt(max(cld(top_set_bit(u), 8), 1))
85-
seed = seed ⊻ hash_mix(seed ⊻ secret[3], secret[2])
86-
87-
a = zero(UInt64)
88-
b = zero(UInt64)
89-
i = buflen
90-
91-
if buflen ≤ 16
92-
if buflen ≥ 4
93-
seed ⊻= buflen
94-
if buflen ≥ 8
95-
a = UInt64(u % UInt64)
96-
b = UInt64((u >>> (8 * (buflen - 8))) % UInt64)
97-
else
98-
a = UInt64(u % UInt32)
99-
b = UInt64((u >>> (8 * (buflen - 4))) % UInt32)
100-
end
101-
else # buflen > 0
102-
b0 = u % UInt8
103-
b1 = (u >>> (8 * div(buflen, 2))) % UInt8
104-
b2 = (u >>> (8 * (buflen - 1))) % UInt8
105-
a = (UInt64(b0) << 45) | UInt64(b2)
106-
b = UInt64(b1)
107-
end
108-
else
109-
if i > 48
110-
see1 = seed
111-
see2 = seed
112-
while i > 48
113-
l0 = u % UInt64; u >>>= 64
114-
l1 = u % UInt64; u >>>= 64
115-
l2 = u % UInt64; u >>>= 64
116-
l3 = u % UInt64; u >>>= 64
117-
l4 = u % UInt64; u >>>= 64
118-
l5 = u % UInt64; u >>>= 64
119-
120-
seed = hash_mix(l0 ⊻ secret[1], l1 ⊻ seed)
121-
see1 = hash_mix(l2 ⊻ secret[2], l3 ⊻ see1)
122-
see2 = hash_mix(l4 ⊻ secret[3], l5 ⊻ see2)
123-
i -= 48
124-
end
125-
seed ⊻= see1
126-
seed ⊻= see2
127-
end
128-
if i > 16
129-
l0 = u % UInt64; u >>>= 64
130-
l1 = u % UInt64; u >>>= 64
131-
seed = hash_mix(l0 ⊻ secret[3], l1 ⊻ seed)
132-
if i > 32
133-
l2 = u % UInt64; u >>>= 64
134-
l3 = u % UInt64; u >>>= 64
135-
seed = hash_mix(l2 ⊻ secret[3], l3 ⊻ seed)
136-
end
137-
end
138-
139-
a = (u0 >>> 8(buflen - 16)) % UInt64 ⊻ i
140-
b = (u0 >>> 8(buflen - 8)) % UInt64
141-
end
142-
143-
a = a ⊻ secret[2]
144-
b = b ⊻ seed
145-
b, a = mul_parts(a, b)
146-
return hash_mix(a ⊻ secret[4], b ⊻ secret[2] ⊻ i)
163+
# Get little-endian byte representation of absolute value
164+
# and hash using the new safe hash_bytes function
165+
u = abs(x) # n.b.: this hashes typemin(IntN) correctly even if abs fails
166+
return hash_bytes(codeunits(u), seed, secret)
147167
end
148168

149169

@@ -619,6 +639,8 @@ end
619639
return hash_mix(a ⊻ secret[4], b ⊻ secret[2] ⊻ bytes_chunk)
620640
end
621641

642+
hash(data::AbstractString, h::UInt) =
643+
hash_bytes(utf8units(data), UInt64(h), HASH_SECRET) % UInt
622644
@assume_effects :total hash(data::String, h::UInt) =
623645
GC.@preserve data hash_bytes(pointer(data), sizeof(data), UInt64(h), HASH_SECRET) % UInt
624646

base/strings/basic.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -362,10 +362,6 @@ end
362362

363363
isless(a::Symbol, b::Symbol) = cmp(a, b) < 0
364364

365-
# hashing
366-
367-
hash(s::AbstractString, h::UInt) = hash(String(s)::String, h)
368-
369365
## character index arithmetic ##
370366

371367
"""

base/strings/lazy.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ iterate(s::LazyString, i::Integer) = iterate(String(s), i)
9696
isequal(a::LazyString, b::LazyString) = isequal(String(a), String(b))
9797
==(a::LazyString, b::LazyString) = (String(a) == String(b))
9898
ncodeunits(s::LazyString) = ncodeunits(String(s))
99-
codeunit(s::LazyString) = codeunit(String(s))
99+
codeunit(s::LazyString) = codeunit("") # returns UInt8
100100
codeunit(s::LazyString, i::Integer) = codeunit(String(s), i)
101+
codeunits(s::LazyString) = codeunits(String(s))
101102
isvalid(s::LazyString, i::Integer) = isvalid(String(s), i)

test/strings/basic.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1193,12 +1193,10 @@ end
11931193
apple_uint8 = Vector{UInt8}("Apple")
11941194
@test apple_uint8 == [0x41, 0x70, 0x70, 0x6c, 0x65]
11951195

1196-
apple_uint8 = Array{UInt8}("Apple")
1197-
@test apple_uint8 == [0x41, 0x70, 0x70, 0x6c, 0x65]
1198-
1199-
Base.String(::tstStringType) = "Test"
1196+
Base.codeunit(::tstStringType) = UInt8
1197+
Base.codeunits(t::tstStringType) = t.data
12001198
abstract_apple = tstStringType(apple_uint8)
1201-
@test hash(abstract_apple, UInt(1)) == hash("Test", UInt(1))
1199+
@test hash(abstract_apple, UInt(1)) == hash("Apple", UInt(1))
12021200

12031201
@test length("abc", 1, 3) == length("abc", UInt(1), UInt(3))
12041202

0 commit comments

Comments
 (0)