Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 14 additions & 28 deletions stdlib/Random/src/MersenneTwister.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ const MT_CACHE_I = 501 << 4 # number of bytes in the UInt128 cache

mutable struct MersenneTwister <: AbstractRNG
seed::Any
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does a MersenneTwister really need to store it's seed?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes looks like this can be deleted.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's there to recreate the state of an instance from few values, so that what show outputs corresponds to a real constructor. E.g

julia> MersenneTwister("a seed")
MersenneTwister("a seed")

But I was thinking to perhaps normalize seeds into an UInt128 value and store that instead of the original seed. The example above would then look like

julia> MersenneTwister("a seed")
MersenneTwister(0xb4802685c420b29be64de36a1f90815f)

julia> MersenneTwister(0xb4802685c420b29be64de36a1f90815f)
MersenneTwister(0xb4802685c420b29be64de36a1f90815f)

This would involve an additional round with SeedHasher, but I'm precisely preparing another PR to make that much faster.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what about just having a HashedSeed type?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The user seed is removed in #60204, without an additional round with SeedHasher, by just storing two UInt128 values instead of one. Printing is uglier, but who cares.

state::DSFMT_state
vals::Vector{Float64}
ints::Vector{UInt128}
const state::DSFMT_state
const vals::Memory{Float64}
const ints::Vector{UInt128} # it's temporarily resized internally
idxF::Int
idxI::Int

Expand All @@ -21,25 +21,13 @@ mutable struct MersenneTwister <: AbstractRNG
adv_vals::Int64 # state of advance when vals is filled-up
adv_ints::Int64 # state of advance when ints is filled-up

function MersenneTwister(seed, state, vals, ints, idxF, idxI,
adv, adv_jump, adv_vals, adv_ints)
length(vals) == MT_CACHE_F && 0 <= idxF <= MT_CACHE_F ||
throw(DomainError((length(vals), idxF),
"`length(vals)` and `idxF` must be consistent with $MT_CACHE_F"))
length(ints) == MT_CACHE_I >> 4 && 0 <= idxI <= MT_CACHE_I ||
throw(DomainError((length(ints), idxI),
"`length(ints)` and `idxI` must be consistent with $MT_CACHE_I"))
new(seed, state, vals, ints, idxF, idxI,
adv, adv_jump, adv_vals, adv_ints)
end
global _MersenneTwister(::UndefInitializer) =
new(nothing, DSFMT_state(),
Memory{Float64}(undef, MT_CACHE_F),
Vector{UInt128}(undef, MT_CACHE_I >> 4),
MT_CACHE_F, 0, 0, Base.GMP.ZERO, -1, -1)
end

MersenneTwister(seed, state::DSFMT_state) =
MersenneTwister(seed, state,
Vector{Float64}(undef, MT_CACHE_F),
Vector{UInt128}(undef, MT_CACHE_I >> 4),
MT_CACHE_F, 0, 0, 0, -1, -1)

"""
MersenneTwister(seed)
MersenneTwister()
Expand Down Expand Up @@ -72,8 +60,7 @@ julia> x1 == x2
true
```
"""
MersenneTwister(seed=nothing) =
seed!(MersenneTwister(Vector{UInt32}(), DSFMT_state()), seed)
MersenneTwister(seed=nothing) = seed!(_MersenneTwister(undef), seed)


function copy!(dst::MersenneTwister, src::MersenneTwister)
Expand All @@ -90,10 +77,7 @@ function copy!(dst::MersenneTwister, src::MersenneTwister)
dst
end

copy(src::MersenneTwister) =
MersenneTwister(src.seed, copy(src.state), copy(src.vals), copy(src.ints),
src.idxF, src.idxI, src.adv, src.adv_jump, src.adv_vals, src.adv_ints)

copy(src::MersenneTwister) = copy!(_MersenneTwister(undef), src)

==(r1::MersenneTwister, r2::MersenneTwister) =
r1.seed == r2.seed && r1.state == r2.state &&
Expand Down Expand Up @@ -250,7 +234,7 @@ function initstate!(r::MersenneTwister, data::StridedVector, seed)
dsfmt_init_by_array(r.state, reinterpret(UInt32, data))
reset_caches!(r)
r.adv = 0
r.adv_jump = 0
r.adv_jump = Base.GMP.ZERO
return r
end

Expand Down Expand Up @@ -561,7 +545,9 @@ end
function _randjump(r::MersenneTwister, jumppoly::DSFMT.GF2X)
adv = r.adv
adv_jump = r.adv_jump
s = MersenneTwister(r.seed, DSFMT.dsfmt_jump(r.state, jumppoly))
s = _MersenneTwister(undef)
s.seed = r.seed
copy!(s.state, DSFMT.dsfmt_jump(r.state, jumppoly))
reset_caches!(s)
s.adv = adv
s.adv_jump = adv_jump
Expand Down
12 changes: 0 additions & 12 deletions stdlib/Random/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -646,18 +646,6 @@ end
# MersenneTwister initialization with invalid values
@test_throws DomainError DSFMT.DSFMT_state(zeros(Int32, rand(0:DSFMT.JN32-1)))

@test_throws DomainError MersenneTwister(zeros(UInt32, 1), DSFMT.DSFMT_state(),
zeros(Float64, 10), zeros(UInt128, MT_CACHE_I>>4), 0, 0, 0, 0, -1, -1)

@test_throws DomainError MersenneTwister(zeros(UInt32, 1), DSFMT.DSFMT_state(),
zeros(Float64, MT_CACHE_F), zeros(UInt128, MT_CACHE_I>>4), -1, 0, 0, 0, -1, -1)

@test_throws DomainError MersenneTwister(zeros(UInt32, 1), DSFMT.DSFMT_state(),
zeros(Float64, MT_CACHE_F), zeros(UInt128, MT_CACHE_I>>3), 0, 0, 0, 0, -1, -1)

@test_throws DomainError MersenneTwister(zeros(UInt32, 1), DSFMT.DSFMT_state(),
zeros(Float64, MT_CACHE_F), zeros(UInt128, MT_CACHE_I>>4), 0, -1, 0, 0, -1, -1)

# seed is private to MersenneTwister
let seed = rand(UInt32, 10)
r = MersenneTwister(seed)
Expand Down