Skip to content

Commit 33f4d60

Browse files
committed
save ΔH accept to chain (again)
1 parent a07ab02 commit 33f4d60

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

src/sampling.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -382,11 +382,12 @@ end
382382
@⌛ function gibbs_sample_ϕ!(state, ds::DataSet)
383383
@unpack f°, ϕ°, θ, symp_kwargs, progress, step, nburnin_always_accept = state
384384
U = ϕ° -> lnP(:mix, f°, ϕ°, θ, ds)
385-
ϕ° = hmc_step(U, ϕ°, mass_matrix_ϕ(θ,ds); symp_kwargs, progress, always_accept=(step<nburnin_always_accept))
386-
@pack! state = ϕ°
385+
ϕ°, ΔH, accept = hmc_step(U, ϕ°, mass_matrix_ϕ(θ,ds); symp_kwargs, progress, always_accept=(step<nburnin_always_accept))
386+
@pack! state = ϕ°, ΔH, accept
387387
end
388388

389389
function hmc_step(U::Function, x, Λ; symp_kwargs, progress, always_accept)
390+
local ΔH, accept
390391
for kwargs in symp_kwargs
391392
p = simulate(Λ)
392393
(ΔH, xtest) = symplectic_integrate(
@@ -397,7 +398,7 @@ function hmc_step(U::Function, x, Λ; symp_kwargs, progress, always_accept)
397398
accept = batch(@. always_accept | (log(rand()) < $unbatch(ΔH)))
398399
@. x = accept * xtest + (1 - accept) * x
399400
end
400-
x
401+
x, ΔH, accept
401402
end
402403

403404
@⌛ function mass_matrix_ϕ(θ, ds)
@@ -434,8 +435,9 @@ end
434435
## postprocessing
435436

436437
@⌛ function gibbs_postprocess!(state, ds::DataSet)
437-
@unpack f, ϕ, θ, pbar_dict = state
438+
@unpack f, ϕ, θ, pbar_dict, ΔH = state
438439
lnP = pbar_dict["lnP"] = CMBLensing.lnP(0, select(state, lnP_arg_names(0, ds))..., ds)
440+
pbar_dict["ΔH"] = ΔH
439441
= ds.L(ϕ) * f
440442
@pack! state = f̃, lnP
441443
end

0 commit comments

Comments
 (0)