Skip to content

Commit b70761f

Browse files
authored
allow rand! with explicit SIMD to be used for various dense arrays (JuliaLang#57101)
1 parent cb55389 commit b70761f

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

stdlib/Random/src/XoshiroSimd.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -292,20 +292,21 @@ end
292292
return i
293293
end
294294

295+
const MutableDenseArray = Union{Base.MutableDenseArrayType{T}, UnsafeView{T}} where {T}
295296

296-
function rand!(rng::Union{TaskLocalRNG, Xoshiro}, dst::Array{T}, ::SamplerTrivial{CloseOpen01{T}}) where {T<:Union{Float16,Float32,Float64}}
297+
function rand!(rng::Union{TaskLocalRNG, Xoshiro}, dst::MutableDenseArray{T}, ::SamplerTrivial{CloseOpen01{T}}) where {T<:Union{Float16,Float32,Float64}}
297298
GC.@preserve dst xoshiro_bulk(rng, convert(Ptr{UInt8}, pointer(dst)), length(dst)*sizeof(T), T, xoshiroWidth(), _bits2float)
298299
dst
299300
end
300301

301302
for T in BitInteger_types
302-
@eval function rand!(rng::Union{TaskLocalRNG, Xoshiro}, dst::Union{Array{$T}, UnsafeView{$T}}, ::SamplerType{$T})
303+
@eval function rand!(rng::Union{TaskLocalRNG, Xoshiro}, dst::MutableDenseArray{$T}, ::SamplerType{$T})
303304
GC.@preserve dst xoshiro_bulk(rng, convert(Ptr{UInt8}, pointer(dst)), length(dst)*sizeof($T), UInt8, xoshiroWidth())
304305
dst
305306
end
306307
end
307308

308-
function rand!(rng::Union{TaskLocalRNG, Xoshiro}, dst::Array{Bool}, ::SamplerType{Bool})
309+
function rand!(rng::Union{TaskLocalRNG, Xoshiro}, dst::MutableDenseArray{Bool}, ::SamplerType{Bool})
309310
GC.@preserve dst xoshiro_bulk(rng, convert(Ptr{UInt8}, pointer(dst)), length(dst), Bool, xoshiroWidth())
310311
dst
311312
end

stdlib/Random/test/runtests.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,9 +370,10 @@ for rng in ([], [MersenneTwister(0)], [RandomDevice()], [Xoshiro()])
370370
a8 = rand!(rng..., GenericArray{T}(undef, 2, 3), cc) ::GenericArray{T, 2}
371371
a9 = rand!(rng..., OffsetArray(Array{T}(undef, 5), 9), cc) ::OffsetArray{T, 1}
372372
a10 = rand!(rng..., OffsetArray(Array{T}(undef, 2, 3), (-2, 4)), cc) ::OffsetArray{T, 2}
373+
a11 = rand!(rng..., Memory{T}(undef, 5), cc) ::Memory{T}
373374
@test size(a1) == (5,)
374375
@test size(a2) == size(a3) == (2, 3)
375-
for a in [a0, a1..., a2..., a3..., a4..., a5..., a6..., a7..., a8..., a9..., a10...]
376+
for a in [a0, a1..., a2..., a3..., a4..., a5..., a6..., a7..., a8..., a9..., a10..., a11...]
376377
if C isa Type
377378
@test a isa C
378379
else
@@ -392,6 +393,7 @@ for rng in ([], [MersenneTwister(0)], [RandomDevice()], [Xoshiro()])
392393
(T <: Tuple || T <: Pair) && continue
393394
X = T == Bool ? T[0,1] : T[0,1,2]
394395
for A in (Vector{T}(undef, 5),
396+
Memory{T}(undef, 5),
395397
Matrix{T}(undef, 2, 3),
396398
GenericArray{T}(undef, 5),
397399
GenericArray{T}(undef, 2, 3),

0 commit comments

Comments
 (0)