Skip to content

Commit e903d1c

Browse files
committed
fix ldf
1 parent aac93f1 commit e903d1c

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

src/mcmc/hmc.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,9 @@ function DynamicPPL.initialstep(
190190
# Create a Hamiltonian.
191191
metricT = getmetricT(spl.alg)
192192
metric = metricT(length(theta))
193-
ldf = DynamicPPL.LogDensityFunction(model, vi; adtype=spl.alg.adtype)
193+
ldf = DynamicPPL.LogDensityFunction(
194+
model, DynamicPPL.getlogjoint, vi; adtype=spl.alg.adtype
195+
)
194196
lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf)
195197
lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ldf)
196198
hamiltonian = AHMC.Hamiltonian(metric, lp_func, lp_grad_func)
@@ -296,7 +298,9 @@ end
296298

297299
function get_hamiltonian(model, spl, vi, state, n)
298300
metric = gen_metric(n, spl, state)
299-
ldf = DynamicPPL.LogDensityFunction(model, vi; adtype=spl.alg.adtype)
301+
ldf = DynamicPPL.LogDensityFunction(
302+
model, DynamicPPL.getlogjoint, vi; adtype=spl.alg.adtype
303+
)
300304
lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf)
301305
lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ldf)
302306
return AHMC.Hamiltonian(metric, lp_func, lp_grad_func)

0 commit comments

Comments
 (0)