Skip to content

Commit c636dfd

Browse files
authored
Extend _randn to vectors
1 parent 6498e65 commit c636dfd

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

src/utilities.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,17 @@ const AbstractScalarOrVec{T} = Union{T,AbstractVector{T}} where {T<:AbstractFloa
22

33
# Support of passing a vector of RNGs
44

5+
function _randn(rng::AbstractRNG, ::Type{T}, n_chains::Int) where {T}
6+
return randn(rng, T, n_chains)
7+
end
58
function _randn(rng::AbstractRNG, ::Type{T}, dim::Int, n_chains::Int) where {T}
69
return randn(rng, T, dim, n_chains)
710
end
11+
12+
function _randn(rngs::AbstractVector{<:AbstractRNG}, ::Type{T}, n_chains::Int) where {T}
13+
@argcheck length(rngs) == n_chains
14+
return map(Base.Fix2(randn, T), rngs)
15+
end
816
function _randn(
917
rngs::AbstractVector{<:AbstractRNG},
1018
::Type{T},
@@ -13,9 +21,7 @@ function _randn(
1321
) where {T}
1422
@argcheck length(rngs) == n_chains
1523
out = similar(rngs, T, dim, n_chains)
16-
for (x, rng) in zip(eachcol(out), rngs)
17-
randn!(rng, x)
18-
end
24+
map!(randn!, eachcol(out), rngs)
1925
return out
2026
end
2127

0 commit comments

Comments
 (0)