Skip to content

Commit f2889a0

Browse files
committed
Merge remote-tracking branch 'origin/torfjelde/RAM' into torfjelde/RAM
2 parents 56ec717 + da431b4 commit f2889a0

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

src/RobustAdaptiveMetropolis.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ julia> # Set the seed so get some consistency.
5858
julia> # Sample!
5959
chain = sample(model, RAM(), 10_000; chain_type=Chains, num_warmup, progress=false, initial_params=zeros(2));
6060
61-
julia> norm(cov(Array(chain)) - [1.0 0.5; 0.5 1.0]) < 0.2
61+
julia> isapprox(cov(Array(chain)), model.A; rtol = 0.2)
6262
true
6363
```
6464
@@ -134,15 +134,14 @@ function step_inner(
134134

135135
# Sample the proposal.
136136
x = state.x
137-
U = randn(rng, d)
138-
x_new = x + state.S * U
137+
U = randn(rng, eltype(x), d)
138+
x_new = muladd(state.S, U, x)
139139

140140
# Compute the acceptance probability.
141141
lp = state.logprob
142142
lp_new = LogDensityProblems.logdensity(f, x_new)
143143
logα = min(lp_new - lp, zero(lp)) # `min` because we'll use this for updating
144-
# TODO: use `randexp` instead.
145-
isaccept = log(rand(rng)) < logα
144+
isaccept = randexp(rng) > -logα
146145

147146
return x_new, lp_new, U, logα, isaccept
148147
end
@@ -177,13 +176,14 @@ function AbstractMCMC.step(
177176
d = LogDensityProblems.dimension(f)
178177

179178
# Initial parameter state.
180-
x = initial_params === nothing ? rand(rng, d) : initial_params
179+
T = initial_params === nothing ? eltype(sampler.γ) : Base.promote_type(eltype(sampler.γ), eltype(initial_params))
180+
x = initial_params === nothing ? rand(rng, T, d) : convert(AbstractVector{T}, initial_params)
181181
# Initialize the Cholesky factor of the covariance matrix.
182-
S = LowerTriangular(sampler.S === nothing ? diagm(0 => ones(eltype(sampler.γ), d)) : sampler.S)
182+
S = LowerTriangular(sampler.S === nothing ? diagm(0 => ones(T, d)) : convert(AbstractMatrix{T}, sampler.S))
183183

184-
# Constuct the initial state.
184+
# Construct the initial state.
185185
lp = LogDensityProblems.logdensity(f, x)
186-
state = RAMState(x, lp, S, 0.0, 0, 1, true)
186+
state = RAMState(x, lp, S, zero(T), 0, 1, true)
187187

188188
return AdvancedMH.Transition(x, lp, true), state
189189
end
@@ -207,7 +207,7 @@ function valid_eigenvalues(S, lower_bound, upper_bound)
207207
(lower_bound == 0 && upper_bound == Inf) && return true
208208
# Note that this is just the diagonal when `S` is triangular.
209209
eigenvals = LinearAlgebra.eigvals(S)
210-
return all(lower_bound .<= eigenvals .<= upper_bound)
210+
return all(x -> lower_bound <= x <= upper_bound, eigenvals)
211211
end
212212

213213
function AbstractMCMC.step_warmup(

0 commit comments

Comments
 (0)