Skip to content

Commit 856025e

Browse files
committed
Random: make a new "strong" RNG using SHA
The most convenient way to define `seed!` for new RNGs is via an another RNG, with `seed!(rng::AbstractRNG, seeder::AbstractRNG)`. But RNGs want to also support more usual seeds. In order to allow them to only define the method above, a new `SeedHasher` RNG is implemented, whose purpose is to convert an initial given seed into a stream of random numbers. Given that it's not always "safe" to seed an RNG from another RNG, `SeedHasher` uses a strong cryptographic hash (SHA2) to produces random streams. The generic `seed!(rng::AbstractRNG, seed)` method now takes care of forwarding the call to `seed!(rng, SeedHasher(seed))`.
1 parent f2e223f commit 856025e

File tree

4 files changed

+128
-42
lines changed

4 files changed

+128
-42
lines changed

stdlib/Random/src/RNGs.jl

Lines changed: 88 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -285,10 +285,81 @@ end
285285

286286
### seeding
287287

288+
"""
289+
Random.SeedHasher(seed=nothing)
290+
291+
Create a `Random.SeedHasher` RNG object, which generates random bytes with the help
292+
of a cryptographic hash function (SHA2), via calls to [`Random.hash_seed`](@ref).
293+
294+
Given two seeds `s1` and `s2`, the random streams generated by
295+
`SeedHasher(s1)` and `SeedHasher(s2)` should be distinct if and only if
296+
`s1` and `s2` are distinct.
297+
298+
This RNG is used by default in `Random.seed!(::AbstractRNG, seed::Any)`, such that
299+
RNGs usually need only to implement `seed!(rng, ::AbstractRNG)`.
300+
301+
This is an internal type, subject to change.
302+
"""
303+
mutable struct SeedHasher <: AbstractRNG
304+
bytes::Vector{UInt8}
305+
idx::Int
306+
cnt::Int64
307+
308+
SeedHasher(seed=nothing) = seed!(new(), seed)
309+
end
310+
311+
seed!(rng::SeedHasher, seeder::AbstractRNG) = seed!(rng, rand(seeder, UInt64, 4))
312+
seed!(rng::SeedHasher, ::Nothing) = seed!(rng, RandomDevice())
313+
314+
function seed!(rng::SeedHasher, seed)
315+
# typically, no more than 256 bits will be needed, so use
316+
# SHA2_256 because it's faster
317+
ctx = SHA2_256_CTX()
318+
hash_seed(seed, ctx)
319+
rng.bytes = SHA.digest!(ctx)::Vector{UInt8}
320+
rng.idx = 0
321+
rng.cnt = 0
322+
rng
323+
end
324+
325+
@noinline function rehash!(rng::SeedHasher)
326+
# more random bytes are necessary, from now on use SHA2_512 to generate
327+
# more bytes at once
328+
ctx = SHA2_512_CTX()
329+
SHA.update!(ctx, rng.bytes)
330+
# also hash the counter, just for the extremely unlikely case where the hash of
331+
# rng.bytes is equal to rng.bytes (i.e. rng.bytes is a "fixed point"), or more generally
332+
# if there is a small cycle
333+
SHA.update!(ctx, reinterpret(NTuple{8, UInt8}, rng.cnt += 1))
334+
rng.bytes = SHA.digest!(ctx)
335+
rng.idx = 0
336+
rng
337+
end
338+
339+
function rand(rng::SeedHasher, ::SamplerType{UInt8})
340+
rng.idx < length(rng.bytes) || rehash!(rng)
341+
rng.bytes[rng.idx += 1]
342+
end
343+
344+
for TT = Base.BitInteger_types
345+
TT === UInt8 && continue
346+
@eval function rand(rng::SeedHasher, ::SamplerType{$TT})
347+
xx = zero($TT)
348+
for ii = 0:sizeof($TT)-1
349+
xx |= (rand(rng, UInt8) % $TT) << (8 * ii)
350+
end
351+
xx
352+
end
353+
end
354+
355+
rand(rng::SeedHasher, ::SamplerType{Bool}) = rand(rng, UInt8) % Bool
356+
357+
rng_native_52(::SeedHasher) = UInt64
358+
359+
288360
#### hash_seed()
289361

290-
function hash_seed(seed::Integer)
291-
ctx = SHA.SHA2_256_CTX()
362+
function hash_seed(seed::Integer, ctx::SHA_CTX)
292363
neg = signbit(seed)
293364
if neg
294365
seed = ~seed
@@ -302,21 +373,18 @@ function hash_seed(seed::Integer)
302373
end
303374
# make sure the hash of negative numbers is different from the hash of positive numbers
304375
neg && SHA.update!(ctx, (0x01,))
305-
SHA.digest!(ctx)
376+
nothing
306377
end
307378

308-
function hash_seed(seed::Union{AbstractArray{UInt32}, AbstractArray{UInt64}})
309-
ctx = SHA.SHA2_256_CTX()
379+
function hash_seed(seed::Union{AbstractArray{UInt32}, AbstractArray{UInt64}}, ctx::SHA_CTX)
310380
for xx in seed
311381
SHA.update!(ctx, reinterpret(NTuple{8, UInt8}, UInt64(xx)))
312382
end
313383
# discriminate from hash_seed(::Integer)
314384
SHA.update!(ctx, (0x10,))
315-
SHA.digest!(ctx)
316385
end
317386

318-
function hash_seed(str::AbstractString)
319-
ctx = SHA.SHA2_256_CTX()
387+
function hash_seed(str::AbstractString, ctx::SHA_CTX)
320388
# convert to String such that `codeunits(str)` below is consistent between equal
321389
# strings of different types
322390
str = String(str)
@@ -331,25 +399,29 @@ function hash_seed(str::AbstractString)
331399
SHA.update!(ctx, (pad % UInt8,))
332400
end
333401
SHA.update!(ctx, (0x05,))
334-
SHA.digest!(ctx)
335402
end
336403

337404

338405
"""
339-
hash_seed(seed)::AbstractVector{UInt8}
406+
Random.hash_seed(seed, ctx::SHA_CTX)::AbstractVector{UInt8}
407+
408+
Update `ctx` via `SHA.update!` with the content of `seed`.
409+
This function is used by the [`SeedHasher`](@ref) RNG to produce
410+
random bytes.
340411
341-
Return a cryptographic hash of `seed` of size 256 bits (32 bytes).
342412
`seed` can currently be of type
343413
`Union{Integer, AbstractString, AbstractArray{UInt32}, AbstractArray{UInt64}}`,
344414
but modules can extend this function for types they own.
345415
346-
`hash_seed` is "injective" : if `n != m`, then `hash_seed(n) != `hash_seed(m)`.
347-
Moreover, if `n == m`, then `hash_seed(n) == hash_seed(m)`.
348-
349-
This is an internal function subject to change.
416+
`hash_seed` is "injective" : for two equivalent context objects `cn` and `cm`,
417+
if `n != m`, then `cn` and `cm` will be distinct after calling
418+
`hash_seed(n, cn); hash_seed(m, cm)`.
419+
Moreover, if `n == m`, then `cn` and `cm` remain equivalent after calling
420+
`hash_seed(n, cn); hash_seed(m, cm)`.
350421
"""
351422
hash_seed
352423

424+
353425
#### seed!()
354426

355427
function initstate!(r::MersenneTwister, data::StridedVector, seed)
@@ -372,7 +444,7 @@ end
372444
# seeds, while having them printed reasonably tersely.
373445
seed!(r::MersenneTwister, seeder::AbstractRNG) = seed!(r, rand(seeder, UInt128))
374446
seed!(r::MersenneTwister, ::Nothing) = seed!(r, RandomDevice())
375-
seed!(r::MersenneTwister, seed) = initstate!(r, hash_seed(seed), seed)
447+
seed!(r::MersenneTwister, seed) = initstate!(r, rand(SeedHasher(seed), UInt32, 8), seed)
376448

377449

378450
### Global RNG

stdlib/Random/src/Random.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ include("DSFMT.jl")
1313
using .DSFMT
1414
using Base.GMP.MPZ
1515
using Base.GMP: Limb
16-
import SHA
16+
using SHA: SHA, SHA2_256_CTX, SHA2_512_CTX, SHA_CTX
1717

1818
using Base: BitInteger, BitInteger_types, BitUnsigned, require_one_based_indexing
1919
import Base: copymutable, copy, copy!, ==, hash, convert,
@@ -457,11 +457,14 @@ julia> rand(Xoshiro(), Bool) # not reproducible either
457457
true
458458
```
459459
"""
460-
function seed!(rng::AbstractRNG, seed=nothing)
460+
function seed!(rng::AbstractRNG, seed::Any=nothing)
461461
if seed === nothing
462462
seed!(rng, RandomDevice())
463-
else
463+
elseif seed isa AbstractRNG
464+
# avoid getting into an infinite recursive call from the other branches
464465
throw(MethodError(seed!, (rng, seed)))
466+
else
467+
seed!(rng, SeedHasher(seed))
465468
end
466469
end
467470

stdlib/Random/src/Xoshiro.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -246,11 +246,6 @@ hash(x::Union{TaskLocalRNG, Xoshiro}, h::UInt) = hash(getstate(x), h + 0x49a62c2
246246
seed!(rng::Union{TaskLocalRNG, Xoshiro}, seeder::AbstractRNG) =
247247
initstate!(rng, rand(seeder, NTuple{4, UInt64}))
248248

249-
seed!(rng::Union{TaskLocalRNG, Xoshiro}, ::Nothing) = @invoke seed!(rng::AbstractRNG, nothing::Any)
250-
251-
seed!(rng::Union{TaskLocalRNG, Xoshiro}, seed) =
252-
initstate!(rng, reinterpret(UInt64, hash_seed(seed)))
253-
254249

255250
@inline function rand(x::Union{TaskLocalRNG, Xoshiro}, ::SamplerType{UInt64})
256251
s0, s1, s2, s3 = getstate(x)

stdlib/Random/test/runtests.jl

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@ using Random
1111
using Random.DSFMT
1212

1313
using Random: default_rng, Sampler, SamplerRangeFast, SamplerRangeInt, SamplerRangeNDL, MT_CACHE_F, MT_CACHE_I
14-
using Random: jump_128, jump_192, jump_128!, jump_192!
14+
using Random: jump_128, jump_192, jump_128!, jump_192!, SeedHasher
1515

16+
import SHA
1617
import Future # randjump
1718

1819
function test_uniform(xs::AbstractArray{T}) where {T<:AbstractFloat}
@@ -297,7 +298,7 @@ for f in (:<, :<=, :>, :>=, :(==), :(!=))
297298
end
298299

299300
# test all rand APIs
300-
for rng in ([], [MersenneTwister(0)], [RandomDevice()], [Xoshiro()])
301+
for rng in ([], [MersenneTwister(0)], [RandomDevice()], [Xoshiro(0)], [SeedHasher(0)])
301302
realrng = rng == [] ? default_rng() : only(rng)
302303
ftypes = [Float16, Float32, Float64, FakeFloat64, BigFloat]
303304
cftypes = [ComplexF16, ComplexF32, ComplexF64, ftypes...]
@@ -453,7 +454,8 @@ function hist(X, n)
453454
end
454455

455456
@testset "uniform distribution of floats" begin
456-
for rng in [MersenneTwister(), RandomDevice(), Xoshiro()],
457+
seed = rand(UInt128)
458+
for rng in [MersenneTwister(seed), RandomDevice(), Xoshiro(seed), SeedHasher(seed)],
457459
T in [Float16, Float32, Float64, BigFloat],
458460
prec in (T == BigFloat ? [3, 53, 64, 100, 256, 1000] : [256])
459461

@@ -480,7 +482,8 @@ end
480482
# but also for 3 linear combinations of positions (for the array version)
481483
lcs = unique!.([rand(1:n, 2), rand(1:n, 3), rand(1:n, 5)])
482484
aslcs = zeros(Int, 3)
483-
for rng = (MersenneTwister(), RandomDevice(), Xoshiro())
485+
seed = rand(UInt128)
486+
for rng = (MersenneTwister(seed), RandomDevice(), Xoshiro(seed), SeedHasher(seed))
484487
for scalar = [false, true]
485488
fill!(a, 0)
486489
fill!(as, 0)
@@ -662,7 +665,7 @@ end
662665
@testset "Random.seed!(rng, ...) returns rng" begin
663666
# issue #21248
664667
seed = rand(UInt)
665-
for m = ([MersenneTwister(seed)], [Xoshiro(seed)], [])
668+
for m = ([MersenneTwister(seed)], [Xoshiro(seed)], [SeedHasher(seed)], [])
666669
m2 = m == [] ? default_rng() : m[1]
667670
@test Random.seed!(m...) === m2
668671
@test Random.seed!(m..., rand(UInt)) === m2
@@ -708,7 +711,7 @@ end
708711
# this shouldn't crash (#22403)
709712
@test_throws MethodError rand!(Union{UInt,Int}[1, 2, 3])
710713

711-
@testset "$RNG() & Random.seed!(rng::$RNG) initializes randomly" for RNG in (MersenneTwister, RandomDevice, Xoshiro)
714+
@testset "$RNG() & Random.seed!(rng::$RNG) initializes randomly" for RNG in (MersenneTwister, RandomDevice, Xoshiro, SeedHasher)
712715
m = RNG()
713716
a = rand(m, Int)
714717
m = RNG()
@@ -729,7 +732,7 @@ end
729732
@test rand(m, Int) (a, b, c, d)
730733
end
731734

732-
@testset "$RNG(seed) & Random.seed!(m::$RNG, seed) produce the same stream" for RNG=(MersenneTwister, Xoshiro)
735+
@testset "$RNG(seed) & Random.seed!(m::$RNG, seed) produce the same stream" for RNG=(MersenneTwister, Xoshiro, SeedHasher)
733736
seeds = Any[0, 1, 2, 10000, 10001, rand(UInt32, 8), randstring(), randstring(), rand(UInt128, 3)...]
734737
if RNG == Xoshiro
735738
push!(seeds, rand(UInt64, rand(1:4)))
@@ -769,7 +772,10 @@ struct RandomStruct23964 end
769772
@test_throws MethodError rand(RandomStruct23964())
770773
end
771774

772-
@testset "rand(::$(typeof(RNG)), ::UnitRange{$T}" for RNG (MersenneTwister(rand(UInt128)), RandomDevice(), Xoshiro()),
775+
@testset "rand(::$(typeof(RNG)), ::UnitRange{$T}" for RNG (MersenneTwister(rand(UInt128)),
776+
RandomDevice(),
777+
Xoshiro(rand(UInt128)),
778+
SeedHasher(rand(UInt128))),
773779
T (Bool, Int8, Int16, Int32, UInt32, Int64, Int128, UInt128)
774780
if T === Bool
775781
@test rand(RNG, false:true) (false, true)
@@ -912,8 +918,11 @@ end
912918
@test rand(rng) == rand(GLOBAL_RNG)
913919
end
914920

915-
@testset "RNGs broadcast as scalars: T" for T in (MersenneTwister, RandomDevice)
916-
@test length.(rand.(T(), 1:3)) == 1:3
921+
@testset "RNGs broadcast as scalars: $(typeof(RNG))" for RNG in (MersenneTwister(0),
922+
RandomDevice(),
923+
Xoshiro(0),
924+
SeedHasher(0))
925+
@test length.(rand.(RNG, 1:3)) == 1:3
917926
end
918927

919928
@testset "generated scalar integers do not overlap" begin
@@ -1211,7 +1220,14 @@ end
12111220
end
12121221
end
12131222

1223+
12141224
@testset "seed! and hash_seed" begin
1225+
function hash_seed(seed)
1226+
ctx = SHA.SHA2_256_CTX()
1227+
Random.hash_seed(seed, ctx)
1228+
bytes2hex(SHA.digest!(ctx))
1229+
end
1230+
12151231
# Test that:
12161232
# 1) if n == m, then hash_seed(n) == hash_seed(m)
12171233
# 2) if n != m, then hash_seed(n) != hash_seed(m)
@@ -1224,12 +1240,12 @@ end
12241240
T <: Signed && push!(seeds, T(0), T(1), T(2), T(-1), T(-2))
12251241
end
12261242

1227-
vseeds = Dict{Vector{UInt8}, BigInt}()
1243+
vseeds = Dict{String, BigInt}()
12281244
for seed = seeds
12291245
bigseed = big(seed)
1230-
vseed = Random.hash_seed(bigseed)
1246+
vseed = hash_seed(bigseed)
12311247
# test property 1) above
1232-
@test Random.hash_seed(seed) == vseed
1248+
@test hash_seed(seed) == vseed
12331249
# test property 2) above
12341250
@test bigseed == get!(vseeds, vseed, bigseed)
12351251
# test that the property 1) is actually inherited by `seed!`
@@ -1241,16 +1257,16 @@ end
12411257
end
12421258

12431259
seed32 = rand(UInt32, rand(1:9))
1244-
hash32 = Random.hash_seed(seed32)
1245-
@test Random.hash_seed(map(UInt64, seed32)) == hash32
1260+
hash32 = hash_seed(seed32)
1261+
@test hash_seed(map(UInt64, seed32)) == hash32
12461262
@test hash32 keys(vseeds)
12471263

12481264
seed_str = randstring()
12491265
seed_gstr = GenericString(seed_str)
1250-
@test Random.hash_seed(seed_str) == Random.hash_seed(seed_gstr)
1251-
string_seeds = Set{Vector{UInt8}}()
1266+
@test hash_seed(seed_str) == hash_seed(seed_gstr)
1267+
string_seeds = Set{String}()
12521268
for ch = 'A':'z'
1253-
vseed = Random.hash_seed(string(ch))
1269+
vseed = hash_seed(string(ch))
12541270
@test vseed keys(vseeds)
12551271
@test vseed string_seeds
12561272
push!(string_seeds, vseed)

0 commit comments

Comments
 (0)