Skip to content

Commit f16a5cf

Browse files
committed
Remove context argument from LogDensityFunction
1 parent c7c4638 commit f16a5cf

File tree

9 files changed

+58
-66
lines changed

9 files changed

+58
-66
lines changed

ext/TuringDynamicHMCExt.jl

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,16 +58,11 @@ function DynamicPPL.initialstep(
5858
# Ensure that initial sample is in unconstrained space.
5959
if !DynamicPPL.islinked(vi)
6060
vi = DynamicPPL.link!!(vi, model)
61-
vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl)))
61+
vi = last(DynamicPPL.evaluate!!(model, vi))
6262
end
6363

6464
# Define log-density function.
65-
ℓ = DynamicPPL.LogDensityFunction(
66-
model,
67-
vi,
68-
DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext());
69-
adtype=spl.alg.adtype,
70-
)
65+
ℓ = DynamicPPL.LogDensityFunction(model, vi; adtype=spl.alg.adtype)
7166

7267
# Perform initial step.
7368
results = DynamicHMC.mcmc_keep_warmup(

src/mcmc/ess.jl

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ function AbstractMCMC.step(
4949
rng,
5050
EllipticalSliceSampling.ESSModel(
5151
ESSPrior(model, spl, vi),
52-
DynamicPPL.LogDensityFunction{:LogLikelihood}(
53-
model, vi, DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext())
52+
ESSLikelihood(
53+
DynamicPPL.LogDensityFunction(model, DynamicPPL.getloglikelihood, vi)
5454
),
5555
),
5656
EllipticalSliceSampling.ESS(),
@@ -63,7 +63,7 @@ function AbstractMCMC.step(
6363

6464
return Transition(model, vi), vi
6565
end
66-
66+
f
6767
# Prior distribution of considered random variable
6868
struct ESSPrior{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo,T}
6969
model::M
@@ -97,6 +97,10 @@ function Base.rand(rng::Random.AbstractRNG, p::ESSPrior)
9797
sampler = p.sampler
9898
varinfo = p.varinfo
9999
# TODO: Surely there's a better way of doing this now that we have `SamplingContext`?
100+
# TODO(DPPL0.37/penelopeysm): This can be replaced with `init!!(p.model,
101+
# p.varinfo, PriorInit())` after TuringLang/DynamicPPL.jl#984. The reason
102+
# why we had to use the 'del' flag before this was because
103+
# SampleFromPrior() wouldn't overwrite existing variables.
100104
vns = keys(varinfo)
101105
for vn in vns
102106
set_flag!(varinfo, vn, "del")
@@ -108,14 +112,9 @@ end
108112
# Mean of prior distribution
109113
Distributions.mean(p::ESSPrior) = p.μ
110114

111-
function DynamicPPL.tilde_assume(
112-
rng::Random.AbstractRNG, ctx::DefaultContext, ::Sampler{<:ESS}, right, vn, vi
113-
)
114-
return DynamicPPL.tilde_assume(rng, ctx, SampleFromPrior(), right, vn, vi)
115+
# Evaluate log-likelihood of proposals
116+
struct ESSLogLikelihood{M<:Model,V<:AbstractVarInfo,AD<:ADTypes.AbstractADType}
117+
ldf::DynamicPPL.LogDensityFunction{M,V,AD}
115118
end
116119

117-
function DynamicPPL.tilde_observe!!(
118-
ctx::DefaultContext, ::Sampler{<:ESS}, right, left, vn, vi
119-
)
120-
return DynamicPPL.tilde_observe!!(ctx, SampleFromPrior(), right, left, vn, vi)
121-
end
120+
(ℓ::ESSLogLikelihood)(f::AbstractVector) = LogDensityProblems.logdensity(ℓ.ldf, f)

src/mcmc/external_sampler.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,12 @@ getlogp_external(::Any, ::Any) = missing
9696
getlogp_external(mh::AdvancedMH.Transition, ::AdvancedMH.Transition) = mh.lp
9797
getlogp_external(hmc::AdvancedHMC.Transition, ::AdvancedHMC.HMCState) = hmc.stat.log_density
9898

99-
struct TuringState{S,V1<:AbstractVarInfo,M,V,C}
99+
struct TuringState{S,V1<:AbstractVarInfo,M,V}
100100
state::S
101101
# Note that this varinfo has the correct parameters and logp obtained from
102102
# the state, whereas `ldf.varinfo` will in general have junk inside it.
103103
varinfo::V1
104-
ldf::DynamicPPL.LogDensityFunction{M,V,C}
104+
ldf::DynamicPPL.LogDensityFunction{M,V}
105105
end
106106

107107
varinfo(state::TuringState) = state.varinfo

src/mcmc/gibbs.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -559,7 +559,7 @@ function setparams_varinfo!!(
559559
params::AbstractVarInfo,
560560
)
561561
logdensity = DynamicPPL.LogDensityFunction(
562-
model, state.ldf.varinfo, state.ldf.context; adtype=sampler.alg.adtype
562+
model, state.ldf.varinfo; adtype=sampler.alg.adtype
563563
)
564564
new_inner_state = setparams_varinfo!!(
565565
AbstractMCMC.LogDensityModel(logdensity), sampler, state.state, params

src/mcmc/hmc.jl

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -190,16 +190,7 @@ function DynamicPPL.initialstep(
190190
# Create a Hamiltonian.
191191
metricT = getmetricT(spl.alg)
192192
metric = metricT(length(theta))
193-
ldf = DynamicPPL.LogDensityFunction(
194-
model,
195-
vi,
196-
# TODO(penelopeysm): Can we just use leafcontext(model.context)? Do we
197-
# need to pass in the sampler? (In fact LogDensityFunction defaults to
198-
# using leafcontext(model.context) so could we just remove the argument
199-
# entirely?)
200-
DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context));
201-
adtype=spl.alg.adtype,
202-
)
193+
ldf = DynamicPPL.LogDensityFunction(model, vi; adtype=spl.alg.adtype)
203194
lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf)
204195
lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ldf)
205196
hamiltonian = AHMC.Hamiltonian(metric, lp_func, lp_grad_func)
@@ -305,16 +296,7 @@ end
305296

306297
function get_hamiltonian(model, spl, vi, state, n)
307298
metric = gen_metric(n, spl, state)
308-
ldf = DynamicPPL.LogDensityFunction(
309-
model,
310-
vi,
311-
# TODO(penelopeysm): Can we just use leafcontext(model.context)? Do we
312-
# need to pass in the sampler? (In fact LogDensityFunction defaults to
313-
# using leafcontext(model.context) so could we just remove the argument
314-
# entirely?)
315-
DynamicPPL.SamplingContext(spl, DynamicPPL.leafcontext(model.context));
316-
adtype=spl.alg.adtype,
317-
)
299+
ldf = DynamicPPL.LogDensityFunction(model, vi; adtype=spl.alg.adtype)
318300
lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf)
319301
lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ldf)
320302
return AHMC.Hamiltonian(metric, lp_func, lp_grad_func)

src/mcmc/mh.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ A log density function for the MH sampler.
189189
This variant uses the `set_namedtuple!` function to update the `VarInfo`.
190190
"""
191191
const MHLogDensityFunction{M<:Model,S<:Sampler{<:MH},V<:AbstractVarInfo} =
192-
DynamicPPL.LogDensityFunction{M,V,<:DynamicPPL.SamplingContext{<:S},AD} where {AD}
192+
DynamicPPL.LogDensityFunction{M,V,AD} where {AD}
193193

194194
function LogDensityProblems.logdensity(f::MHLogDensityFunction, x::NamedTuple)
195195
vi = deepcopy(f.varinfo)

src/mcmc/particle_mcmc.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,9 @@ function DynamicPPL.initialstep(
206206
)
207207

208208
# Perform particle sweep.
209+
@info "Hello!"
209210
logevidence = AdvancedPS.sweep!(rng, particles, spl.alg.resampler, spl)
211+
@info "Goodbye!"
210212

211213
# Extract the first particle and its weight.
212214
particle = particles.vals[1]
@@ -222,6 +224,7 @@ end
222224
function AbstractMCMC.step(
223225
::AbstractRNG, model::AbstractModel, spl::Sampler{<:SMC}, state::SMCState; kwargs...
224226
)
227+
@info "helloooooo from step"
225228
# Extract the index of the current particle.
226229
index = state.particleindex
227230

src/optimisation/Optimisation.jl

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,9 @@ or
169169
OptimLogDensity(model; adtype=adtype)
170170
```
171171
172+
Here, `ctx` must be a context that contains an `OptimizationContext` as its
173+
leaf.
174+
172175
If not specified, `adtype` defaults to `AutoForwardDiff()`.
173176
174177
An OptimLogDensity does not, in itself, obey the LogDensityProblems interface.
@@ -189,24 +192,40 @@ optim_ld(z) # returns -logp
189192
```
190193
"""
191194
struct OptimLogDensity{
192-
M<:DynamicPPL.Model,
193-
F<:Function,
194-
V<:DynamicPPL.AbstractVarInfo,
195-
C<:DynamicPPL.AbstractContext,
196-
AD<:ADTypes.AbstractADType,
195+
M<:DynamicPPL.Model,F<:Function,V<:DynamicPPL.AbstractVarInfo,AD<:ADTypes.AbstractADType
197196
}
198-
ldf::DynamicPPL.LogDensityFunction{M,F,V,C,AD}
199-
end
197+
ldf::DynamicPPL.LogDensityFunction{M,F,V,AD}
200198

201-
function OptimLogDensity(
202-
model::DynamicPPL.Model,
203-
getlogdensity::Function,
204-
vi::DynamicPPL.AbstractVarInfo=DynamicPPL.ldf_default_varinfo(model, getlogdensity);
205-
adtype=AutoForwardDiff(),
206-
)
207-
return OptimLogDensity(
208-
DynamicPPL.LogDensityFunction(model, getlogdensity, vi; adtype=adtype)
199+
# Inner constructors enforce that the model has an OptimizationContext as
200+
# its leaf context.
201+
function OptimLogDensity(
202+
model::DynamicPPL.Model,
203+
getlogdensity::Function,
204+
vi::DynamicPPL.VarInfo,
205+
ctx::OptimizationContext;
206+
adtype::ADTypes.AbstractADType=Turing.DEFAULT_ADTYPE,
209207
)
208+
new_context = DynamicPPL.setleafcontext(model, ctx)
209+
new_model = contextualize(model, new_context)
210+
return new{typeof(new_model),typeof(getlogdensity),typeof(vi),typeof(adtype)}(
211+
DynamicPPL.LogDensityFunction(new_model, getlogdensity, vi; adtype=adtype)
212+
)
213+
end
214+
function OptimLogDensity(
215+
model::DynamicPPL.Model,
216+
getlogdensity::Function,
217+
ctx::OptimizationContext;
218+
adtype::ADTypes.AbstractADType=Turing.DEFAULT_ADTYPE,
219+
)
220+
# No varinfo
221+
return OptimLogDensity(
222+
model,
223+
getlogdensity,
224+
DynamicPPL.ldf_default_varinfo(model, getlogdensity),
225+
ctx;
226+
adtype=adtype,
227+
)
228+
end
210229
end
211230

212231
"""

src/variational/VariationalInference.jl

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,6 @@ export vi, q_locationscale, q_meanfield_gaussian, q_fullrank_gaussian
1717

1818
include("deprecated.jl")
1919

20-
function make_logdensity(model::DynamicPPL.Model)
21-
weight = 1.0
22-
ctx = DynamicPPL.MiniBatchContext(DynamicPPL.DefaultContext(), weight)
23-
return DynamicPPL.LogDensityFunction(model, DynamicPPL.VarInfo(model), ctx)
24-
end
25-
2620
"""
2721
q_initialize_scale(
2822
[rng::Random.AbstractRNG,]
@@ -68,7 +62,7 @@ function q_initialize_scale(
6862
num_max_trials::Int=10,
6963
reduce_factor::Real=one(eltype(scale)) / 2,
7064
)
71-
prob = make_logdensity(model)
65+
prob = LogDensityFunction(model)
7266
ℓπ = Base.Fix1(LogDensityProblems.logdensity, prob)
7367
varinfo = DynamicPPL.VarInfo(model)
7468

@@ -309,7 +303,7 @@ function vi(
309303
)
310304
return AdvancedVI.optimize(
311305
rng,
312-
make_logdensity(model),
306+
LogDensityFunction(model),
313307
objective,
314308
q,
315309
n_iterations;

0 commit comments

Comments (0)