Skip to content

Commit d7a46e1

Browse files
committed
Change getlogjoint -> getlogjoint_internal where needed
1 parent 26b8bff commit d7a46e1

File tree

11 files changed

+28
-31
lines changed

11 files changed

+28
-31
lines changed

ext/TuringDynamicHMCExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ function DynamicPPL.initialstep(
6363

6464
# Define log-density function.
6565
= DynamicPPL.LogDensityFunction(
66-
model, DynamicPPL.getlogjoint, vi; adtype=spl.alg.adtype
66+
model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.alg.adtype
6767
)
6868

6969
# Perform initial step.

src/mcmc/Inference.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ using DynamicPPL:
1919
setlogp!!,
2020
getlogp,
2121
getlogjoint,
22+
getlogjoint_internal,
2223
VarName,
2324
getsym,
2425
getdist,
@@ -136,11 +137,13 @@ end
136137

137138
Transition(θ, lp) = Transition(θ, lp, nothing)
138139
function Transition(model::DynamicPPL.Model, vi::AbstractVarInfo, t)
140+
# TODO(DPPL0.37/penelopeysm): Fix this
139141
θ = getparams(model, vi)
140-
lp = getlogjoint(vi)
142+
lp = getlogjoint_internal(vi)
141143
return Transition(θ, lp, getstats(t))
142144
end
143145

146+
# TODO(DPPL0.37/penelopeysm): Add log-prior and log-likelihood terms as well
144147
function metadata(t::Transition)
145148
stat = t.stat
146149
if stat === nothing
@@ -150,9 +153,11 @@ function metadata(t::Transition)
150153
end
151154
end
152155

156+
# TODO(DPPL0.37/penelopeysm): Fix this
153157
DynamicPPL.getlogjoint(t::Transition) = t.lp
154158

155159
# Metadata of VarInfo object
160+
# TODO(DPPL0.37/penelopeysm): Add log-prior and log-likelihood terms as well
156161
metadata(vi::AbstractVarInfo) = (lp=getlogjoint(vi),)
157162

158163
##########################

src/mcmc/emcee.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ function AbstractMCMC.step(
7272
vis[1],
7373
map(vis) do vi
7474
vi = DynamicPPL.link!!(vi, model)
75-
AMH.Transition(vi[:], DynamicPPL.getlogjoint(vi), false)
75+
AMH.Transition(vi[:], DynamicPPL.getlogjoint_internal(vi), false)
7676
end,
7777
)
7878

@@ -87,7 +87,7 @@ function AbstractMCMC.step(
8787
densitymodel = AMH.DensityModel(
8888
Base.Fix1(
8989
LogDensityProblems.logdensity,
90-
DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint, vi),
90+
DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint_internal, vi),
9191
),
9292
)
9393

src/mcmc/external_sampler.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ function AbstractMCMC.step(
163163

164164
# Construct LogDensityFunction
165165
f = DynamicPPL.LogDensityFunction(
166-
model, DynamicPPL.getlogjoint, varinfo; adtype=alg.adtype
166+
model, DynamicPPL.getlogjoint_internal, varinfo; adtype=alg.adtype
167167
)
168168

169169
# Then just call `AbstractMCMC.step` with the right arguments.

src/mcmc/gibbs.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -568,7 +568,7 @@ function setparams_varinfo!!(
568568
params::AbstractVarInfo,
569569
)
570570
logdensity = DynamicPPL.LogDensityFunction(
571-
model, DynamicPPL.getlogjoint, state.ldf.varinfo; adtype=sampler.alg.adtype
571+
model, DynamicPPL.getlogjoint_internal, state.ldf.varinfo; adtype=sampler.alg.adtype
572572
)
573573
new_inner_state = setparams_varinfo!!(
574574
AbstractMCMC.LogDensityModel(logdensity), sampler, state.state, params

src/mcmc/hmc.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ function DynamicPPL.initialstep(
193193
metricT = getmetricT(spl.alg)
194194
metric = metricT(length(theta))
195195
ldf = DynamicPPL.LogDensityFunction(
196-
model, DynamicPPL.getlogjoint, vi; adtype=spl.alg.adtype
196+
model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.alg.adtype
197197
)
198198
lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf)
199199
lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ldf)
@@ -308,7 +308,7 @@ end
308308
function get_hamiltonian(model, spl, vi, state, n)
309309
metric = gen_metric(n, spl, state)
310310
ldf = DynamicPPL.LogDensityFunction(
311-
model, DynamicPPL.getlogjoint, vi; adtype=spl.alg.adtype
311+
model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.alg.adtype
312312
)
313313
lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf)
314314
lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ldf)

src/mcmc/mh.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ function propose!!(
304304

305305
# Create a sampler and the previous transition.
306306
mh_sampler = AMH.MetropolisHastings(dt)
307-
prev_trans = AMH.Transition(vt, DynamicPPL.getlogjoint(vi), false)
307+
prev_trans = AMH.Transition(vt, DynamicPPL.getlogjoint_internal(vi), false)
308308

309309
# Make a new transition.
310310
spl_model = DynamicPPL.contextualize(
@@ -313,7 +313,7 @@ function propose!!(
313313
densitymodel = AMH.DensityModel(
314314
Base.Fix1(
315315
LogDensityProblems.logdensity,
316-
DynamicPPL.LogDensityFunction(spl_model, DynamicPPL.getlogjoint, vi),
316+
DynamicPPL.LogDensityFunction(spl_model, DynamicPPL.getlogjoint_internal, vi),
317317
),
318318
)
319319
trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans)
@@ -341,7 +341,7 @@ function propose!!(
341341

342342
# Create a sampler and the previous transition.
343343
mh_sampler = AMH.MetropolisHastings(spl.alg.proposals)
344-
prev_trans = AMH.Transition(vals, DynamicPPL.getlogjoint(vi), false)
344+
prev_trans = AMH.Transition(vals, DynamicPPL.getlogjoint_internal(vi), false)
345345

346346
# Make a new transition.
347347
spl_model = DynamicPPL.contextualize(
@@ -350,7 +350,7 @@ function propose!!(
350350
densitymodel = AMH.DensityModel(
351351
Base.Fix1(
352352
LogDensityProblems.logdensity,
353-
DynamicPPL.LogDensityFunction(spl_model, DynamicPPL.getlogjoint, vi),
353+
DynamicPPL.LogDensityFunction(spl_model, DynamicPPL.getlogjoint_internal, vi),
354354
),
355355
)
356356
trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans)

src/mcmc/particle_mcmc.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ end
147147

148148
function SMCTransition(model::DynamicPPL.Model, vi::AbstractVarInfo, weight)
149149
theta = getparams(model, vi)
150-
lp = DynamicPPL.getlogjoint(vi)
150+
lp = DynamicPPL.getlogjoint_internal(vi)
151151
return SMCTransition(theta, lp, weight)
152152
end
153153

@@ -323,7 +323,7 @@ varinfo(state::PGState) = state.vi
323323

324324
function PGTransition(model::DynamicPPL.Model, vi::AbstractVarInfo, logevidence)
325325
theta = getparams(model, vi)
326-
lp = DynamicPPL.getlogjoint(vi)
326+
lp = DynamicPPL.getlogjoint_internal(vi)
327327
return PGTransition(theta, lp, logevidence)
328328
end
329329

src/mcmc/sghmc.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ end
200200

201201
function SGLDTransition(model::DynamicPPL.Model, vi::AbstractVarInfo, stepsize)
202202
theta = getparams(model, vi)
203-
lp = DynamicPPL.getlogjoint(vi)
203+
lp = DynamicPPL.getlogjoint_internal(vi)
204204
return SGLDTransition(theta, lp, stepsize)
205205
end
206206

test/mcmc/external_sampler.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ function initialize_nuts(model::DynamicPPL.Model)
2121

2222
# Create a LogDensityFunction
2323
f = DynamicPPL.LogDensityFunction(
24-
model, DynamicPPL.getlogjoint, linked_vi; adtype=Turing.DEFAULT_ADTYPE
24+
model, DynamicPPL.getlogjoint_internal, linked_vi; adtype=Turing.DEFAULT_ADTYPE
2525
)
2626

2727
# Choose parameter dimensionality and initial parameter value

0 commit comments

Comments
 (0)