Skip to content

Commit 595682c

Browse files
committed
Real tests
1 parent cd53a35 commit 595682c

File tree

4 files changed

+9
-5
lines changed

4 files changed

+9
-5
lines changed

src/abstractmcmc.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,14 +290,15 @@ function AbstractMCMC.step(
290290
logdensity_and_gradient = Base.Fix1(
291291
LogDensityProblems.logdensity_and_gradient, model.logdensity
292292
)
293-
θ = t_old.z.θ
293+
θ = copy(t_old.z.θ)
294294
grad = last(logdensity_and_gradient(θ))
295295

296296
stepsize = spl.stepsize(i)
297297
θ .+= (stepsize / 2) .* grad .+ sqrt(stepsize) .* randn(rng, eltype(θ), length(θ))
298298

299299
# Make new transition.
300-
t = transition(rng, h, κ, t_old.z)
300+
z = phasepoint(h, θ, t_old.z.r)
301+
t = transition(rng, h, κ, z)
301302

302303
# Adapt h and spl.
303304
tstat = stat(t)

src/constructors.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,3 +198,5 @@ end
198198
function SGLD(stepsize, n_leapfrog; integrator=:leapfrog, metric=:diagonal)
199199
return SGLD(stepsize, n_leapfrog, integrator, metric)
200200
end
201+
202+
sampler_eltype(sampler::SGLD) = eltype(sampler.stepsize)

src/utilities.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ struct PolynomialStepsize{T<:Real}
115115
return new{T}(a, b, γ)
116116
end
117117
end
118+
Base.eltype(p::PolynomialStepsize{T}) where {T} = T
118119

119120
"""
120121
PolynomialStepsize(a[, b=0, γ=0.55])

test/abstractmcmc.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ using Statistics: mean
115115
samples_sgld = AbstractMCMC.sample(
116116
rng,
117117
model,
118-
hmc,
118+
sgld,
119119
n_adapts + n_samples;
120120
n_adapts=n_adapts,
121121
initial_params=θ_init,
@@ -129,11 +129,11 @@ using Statistics: mean
129129
for t in samples_sgld
130130
t.z.θ .= invlink_gdemo(t.z.θ)
131131
end
132-
m_est_hmc = mean(samples_sgld) do t
132+
m_est_sgld = mean(samples_sgld) do t
133133
t.z.θ
134134
end
135135

136-
@test m_est_hmc [49 / 24, 7 / 6] atol = RNDATOL
136+
@test m_est_sgld [49 / 24, 7 / 6] atol = RNDATOL
137137

138138
samples_custom = AbstractMCMC.sample(
139139
rng,

0 commit comments

Comments
 (0)