Skip to content

Commit ade6a97

Browse files
authored
Fix missing randn method
1 parent 8a163da commit ade6a97

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

src/metric.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ function _rand(
9999
metric::UnitEuclideanMetric{T},
100100
kinetic::GaussianKinetic,
101101
) where {T}
102-
r = randn(rng, T, size(metric)...)
102+
r = _randn(rng, T, size(metric)...)
103103
return r
104104
end
105105

@@ -108,7 +108,7 @@ function _rand(
108108
metric::DiagEuclideanMetric{T},
109109
kinetic::GaussianKinetic,
110110
) where {T}
111-
r = randn(rng, T, size(metric)...)
111+
r = _randn(rng, T, size(metric)...)
112112
r ./= metric.sqrtM⁻¹
113113
return r
114114
end
@@ -118,7 +118,7 @@ function _rand(
118118
metric::DenseEuclideanMetric{T},
119119
kinetic::GaussianKinetic,
120120
) where {T}
121-
r = randn(rng, T, size(metric)...)
121+
r = _randn(rng, T, size(metric)...)
122122
ldiv!(metric.cholM⁻¹, r)
123123
return r
124124
end

src/utilities.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,19 @@
11
const AbstractScalarOrVec{T} = Union{T,AbstractVector{T}} where {T<:AbstractFloat}
22

3+
# Support of passing a vector of RNGs
4+
5+
function _randn(rng::AbstractRNG, ::Type{T}, dim::Int, n_chains::Int) where {T}
6+
return randn(rng, T, dim, n_chains)
7+
end
8+
function _randn(rngs::AbstractVector{<:AbstractRNG}, ::Type{T}, dim::Int, n_chains::Int) where {T}
9+
@argcheck length(rngs) == n_chains
10+
out = similar(rngs, T, dim, n_chains)
11+
for (x, rng) in zip(eachcol(out), rngs)
12+
randn!(rng, x)
13+
end
14+
return out
15+
end
16+
317
"""
418
`rand_coupled` produces coupled randomness given a vector of RNGs. For example,
519
when a vector of RNGs is provided, `rand_coupled` peforms a single `rand` call

0 commit comments

Comments
 (0)