Skip to content

Commit 8fdecc0

Browse files
authored
Use accumulators to fix all logp calculations when sampling (#2630)
* Use new `getlogjoint` for optimisation
* Change getlogjoint -> getlogjoint_internal where needed
* Enforce re-evaluation when constructing `Transition`
* fix tests
* Remove extra evaluations from SGLD and SGHMC
* Remove dead `transitions_from_chain` method (used to be part of `predict`)
* metadata -> getstats_with_lp
* Clean up some stray getlogp
1 parent 7124864 commit 8fdecc0

21 files changed

+228
-472
lines changed

ext/TuringDynamicHMCExt.jl

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ function DynamicPPL.initialstep(
6363

6464
# Define log-density function.
6565
ℓ = DynamicPPL.LogDensityFunction(
66-
model, DynamicPPL.getlogjoint, vi; adtype=spl.alg.adtype
66+
model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.alg.adtype
6767
)
6868

6969
# Perform initial step.
@@ -73,14 +73,9 @@ function DynamicPPL.initialstep(
7373
steps = DynamicHMC.mcmc_steps(results.sampling_logdensity, results.final_warmup_state)
7474
Q, _ = DynamicHMC.mcmc_next_step(steps, results.final_warmup_state.Q)
7575

76-
# Update the variables.
77-
vi = DynamicPPL.unflatten(vi, Q.q)
78-
# TODO(DPPL0.37/penelopeysm): This is obviously incorrect. Fix this.
79-
vi = DynamicPPL.setloglikelihood!!(vi, Q.ℓq)
80-
vi = DynamicPPL.setlogprior!!(vi, 0.0)
81-
8276
# Create first sample and state.
83-
sample = Turing.Inference.Transition(model, vi)
77+
vi = DynamicPPL.unflatten(vi, Q.q)
78+
sample = Turing.Inference.Transition(model, vi, nothing)
8479
state = DynamicNUTSState(ℓ, vi, Q, steps.H.κ, steps.ϵ)
8580

8681
return sample, state
@@ -99,12 +94,9 @@ function AbstractMCMC.step(
9994
steps = DynamicHMC.mcmc_steps(rng, spl.alg.sampler, state.metric, ℓ, state.stepsize)
10095
Q, _ = DynamicHMC.mcmc_next_step(steps, state.cache)
10196

102-
# Update the variables.
103-
vi = DynamicPPL.unflatten(vi, Q.q)
104-
vi = DynamicPPL.setlogp!!(vi, Q.ℓq)
105-
10697
# Create next sample and state.
107-
sample = Turing.Inference.Transition(model, vi)
98+
vi = DynamicPPL.unflatten(vi, Q.q)
99+
sample = Turing.Inference.Transition(model, vi, nothing)
108100
newstate = DynamicNUTSState(ℓ, vi, Q, state.metric, state.stepsize)
109101

110102
return sample, newstate

ext/TuringOptimExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ function Optim.optimize(
102102
options::Optim.Options=Optim.Options();
103103
kwargs...,
104104
)
105-
f = Optimisation.OptimLogDensity(model, Optimisation.getlogjoint_without_jacobian)
105+
f = Optimisation.OptimLogDensity(model, DynamicPPL.getlogjoint)
106106
init_vals = DynamicPPL.getparams(f.ldf)
107107
optimizer = Optim.LBFGS()
108108
return _map_optimize(model, init_vals, optimizer, options; kwargs...)
@@ -124,7 +124,7 @@ function Optim.optimize(
124124
options::Optim.Options=Optim.Options();
125125
kwargs...,
126126
)
127-
f = Optimisation.OptimLogDensity(model, Optimisation.getlogjoint_without_jacobian)
127+
f = Optimisation.OptimLogDensity(model, DynamicPPL.getlogjoint)
128128
init_vals = DynamicPPL.getparams(f.ldf)
129129
return _map_optimize(model, init_vals, optimizer, options; kwargs...)
130130
end
@@ -140,7 +140,7 @@ function Optim.optimize(
140140
end
141141

142142
function _map_optimize(model::DynamicPPL.Model, args...; kwargs...)
143-
f = Optimisation.OptimLogDensity(model, Optimisation.getlogjoint_without_jacobian)
143+
f = Optimisation.OptimLogDensity(model, DynamicPPL.getlogjoint)
144144
return _optimize(f, args...; kwargs...)
145145
end
146146

src/mcmc/Inference.jl

Lines changed: 78 additions & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ using DynamicPPL:
1717
setindex!!,
1818
push!!,
1919
setlogp!!,
20-
getlogp,
2120
getlogjoint,
21+
getlogjoint_internal,
2222
VarName,
2323
getsym,
2424
getdist,
@@ -123,71 +123,94 @@ end
123123
######################
124124
# Default Transition #
125125
######################
126-
# Default
127-
getstats(t) = nothing
126+
getstats(::Any) = NamedTuple()
128127

128+
# TODO(penelopeysm): Remove this abstract type by converting SGLDTransition,
129+
# SMCTransition, and PGTransition to Turing.Inference.Transition instead.
129130
abstract type AbstractTransition end
130131

131-
struct Transition{T,F<:AbstractFloat,S<:Union{NamedTuple,Nothing}} <: AbstractTransition
132+
struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition
132133
θ::T
133-
lp::F # TODO: merge `lp` with `stat`
134-
stat::S
135-
end
136-
137-
Transition(θ, lp) = Transition(θ, lp, nothing)
138-
function Transition(model::DynamicPPL.Model, vi::AbstractVarInfo, t)
139-
θ = getparams(model, vi)
140-
lp = getlogjoint(vi)
141-
return Transition(θ, lp, getstats(t))
142-
end
134+
logprior::F
135+
loglikelihood::F
136+
stat::N
137+
138+
"""
139+
Transition(model::Model, vi::AbstractVarInfo, sampler_transition)
140+
141+
Construct a new `Turing.Inference.Transition` object using the outputs of a
142+
sampler step.
143+
144+
Here, `vi` represents a VarInfo _for which the appropriate parameters have
145+
already been set_. However, the accumulators (e.g. logp) may in general
146+
have junk contents. The role of this method is to re-evaluate `model` and
147+
thus set the accumulators to the correct values.
148+
149+
`sampler_transition` is the transition object returned by the sampler
150+
itself and is only used to extract statistics of interest.
151+
"""
152+
function Transition(model::DynamicPPL.Model, vi::AbstractVarInfo, sampler_transition)
153+
vi = DynamicPPL.setaccs!!(
154+
vi,
155+
(
156+
DynamicPPL.ValuesAsInModelAccumulator(true),
157+
DynamicPPL.LogPriorAccumulator(),
158+
DynamicPPL.LogLikelihoodAccumulator(),
159+
),
160+
)
161+
_, vi = DynamicPPL.evaluate!!(model, vi)
162+
163+
# Extract all the information we need
164+
vals_as_in_model = DynamicPPL.getacc(vi, Val(:ValuesAsInModel)).values
165+
logprior = DynamicPPL.getlogprior(vi)
166+
loglikelihood = DynamicPPL.getloglikelihood(vi)
167+
168+
# Get additional statistics
169+
stats = getstats(sampler_transition)
170+
return new{typeof(vals_as_in_model),typeof(logprior),typeof(stats)}(
171+
vals_as_in_model, logprior, loglikelihood, stats
172+
)
173+
end
143174

144-
function metadata(t::Transition)
145-
stat = t.stat
146-
if stat === nothing
147-
return (lp=t.lp,)
148-
else
149-
return merge((lp=t.lp,), stat)
175+
function Transition(
176+
model::DynamicPPL.Model,
177+
untyped_vi::DynamicPPL.VarInfo{<:DynamicPPL.Metadata},
178+
sampler_transition,
179+
)
180+
# Re-evaluating the model is unconscionably slow for untyped VarInfo. It's
181+
# much faster to convert it to a typed varinfo first, hence this method.
182+
# https://github.com/TuringLang/Turing.jl/issues/2604
183+
return Transition(model, DynamicPPL.typed_varinfo(untyped_vi), sampler_transition)
150184
end
151185
end
152186

153-
DynamicPPL.getlogjoint(t::Transition) = t.lp
154-
155-
# Metadata of VarInfo object
156-
metadata(vi::AbstractVarInfo) = (lp=getlogjoint(vi),)
187+
function getstats_with_lp(t::Transition)
188+
return merge(
189+
t.stat,
190+
(
191+
lp=t.logprior + t.loglikelihood,
192+
logprior=t.logprior,
193+
loglikelihood=t.loglikelihood,
194+
),
195+
)
196+
end
197+
function getstats_with_lp(vi::AbstractVarInfo)
198+
return (
199+
lp=DynamicPPL.getlogjoint(vi),
200+
logprior=DynamicPPL.getlogprior(vi),
201+
loglikelihood=DynamicPPL.getloglikelihood(vi),
202+
)
203+
end
157204

158205
##########################
159206
# Chain making utilities #
160207
##########################
161208

162-
"""
163-
getparams(model, t)
164-
165-
Return a named tuple of parameters.
166-
"""
167-
getparams(model, t) = t.θ
168-
function getparams(model::DynamicPPL.Model, vi::DynamicPPL.VarInfo)
169-
# NOTE: In the past, `invlink(vi, model)` + `values_as(vi, OrderedDict)` was used.
170-
# Unfortunately, using `invlink` can cause issues in scenarios where the constraints
171-
# of the parameters change depending on the realizations. Hence we have to use
172-
# `values_as_in_model`, which re-runs the model and extracts the parameters
173-
# as they are seen in the model, i.e. in the constrained space. Moreover,
174-
# this means that the code below will work both of linked and invlinked `vi`.
175-
# Ref: https://github.com/TuringLang/Turing.jl/issues/2195
176-
# NOTE: We need to `deepcopy` here to avoid modifying the original `vi`.
177-
return DynamicPPL.values_as_in_model(model, true, deepcopy(vi))
178-
end
179-
function getparams(
180-
model::DynamicPPL.Model, untyped_vi::DynamicPPL.VarInfo{<:DynamicPPL.Metadata}
181-
)
182-
# values_as_in_model is unconscionably slow for untyped VarInfo. It's
183-
# much faster to convert it to a typed varinfo before calling getparams.
184-
# https://github.com/TuringLang/Turing.jl/issues/2604
185-
return getparams(model, DynamicPPL.typed_varinfo(untyped_vi))
209+
getparams(::DynamicPPL.Model, t::AbstractTransition) = t.θ
210+
function getparams(model::DynamicPPL.Model, vi::AbstractVarInfo)
211+
t = Transition(model, vi, nothing)
212+
return getparams(model, t)
186213
end
187-
function getparams(::DynamicPPL.Model, ::DynamicPPL.VarInfo{NamedTuple{(),Tuple{}}})
188-
return Dict{VarName,Any}()
189-
end
190-
191214
function _params_to_array(model::DynamicPPL.Model, ts::Vector)
192215
names_set = OrderedSet{VarName}()
193216
# Extract the parameter names and values from each transition.
@@ -203,7 +226,6 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector)
203226
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
204227
mapreduce(collect, vcat, iters)
205228
end
206-
207229
nms = map(first, nms_and_vs)
208230
vs = map(last, nms_and_vs)
209231
for nm in nms
@@ -218,14 +240,9 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector)
218240
return names, vals
219241
end
220242

221-
function get_transition_extras(ts::AbstractVector{<:VarInfo})
222-
valmat = reshape([getlogjoint(t) for t in ts], :, 1)
223-
return [:lp], valmat
224-
end
225-
226243
function get_transition_extras(ts::AbstractVector)
227-
# Extract all metadata.
228-
extra_data = map(metadata, ts)
244+
# Extract stats + log probabilities from each transition or VarInfo
245+
extra_data = map(getstats_with_lp, ts)
229246
return names_values(extra_data)
230247
end
231248

@@ -334,7 +351,7 @@ function AbstractMCMC.bundle_samples(
334351
vals = map(values(sym_to_vns)) do vns
335352
map(Base.Fix1(getindex, params), vns)
336353
end
337-
return merge(NamedTuple(zip(keys(sym_to_vns), vals)), metadata(t))
354+
return merge(NamedTuple(zip(keys(sym_to_vns), vals)), getstats_with_lp(t))
338355
end
339356
end
340357

@@ -396,84 +413,4 @@ function DynamicPPL.get_matching_type(
396413
return Array{T,N}
397414
end
398415

399-
##############
400-
# Utilities #
401-
##############
402-
403-
"""
404-
405-
transitions_from_chain(
406-
[rng::AbstractRNG,]
407-
model::Model,
408-
chain::MCMCChains.Chains;
409-
sampler = DynamicPPL.SampleFromPrior()
410-
)
411-
412-
Execute `model` conditioned on each sample in `chain`, and return resulting transitions.
413-
414-
The returned transitions are represented in a `Vector{<:Turing.Inference.Transition}`.
415-
416-
# Details
417-
418-
In a bit more detail, the process is as follows:
419-
1. For every `sample` in `chain`
420-
1. For every `variable` in `sample`
421-
1. Set `variable` in `model` to its value in `sample`
422-
2. Execute `model` with variables fixed as above, sampling variables NOT present
423-
in `chain` using `SampleFromPrior`
424-
3. Return sampled variables and log-joint
425-
426-
# Example
427-
```julia-repl
428-
julia> using Turing
429-
430-
julia> @model function demo()
431-
m ~ Normal(0, 1)
432-
x ~ Normal(m, 1)
433-
end;
434-
435-
julia> m = demo();
436-
437-
julia> chain = Chains(randn(2, 1, 1), ["m"]); # 2 samples of `m`
438-
439-
julia> transitions = Turing.Inference.transitions_from_chain(m, chain);
440-
441-
julia> [Turing.Inference.getlogjoint(t) for t in transitions] # extract the logjoints
442-
2-element Array{Float64,1}:
443-
-3.6294991938628374
444-
-2.5697948166987845
445-
446-
julia> [first(t.θ.x) for t in transitions] # extract samples for `x`
447-
2-element Array{Array{Float64,1},1}:
448-
[-2.0844148956440796]
449-
[-1.704630494695469]
450-
```
451-
"""
452-
function transitions_from_chain(
453-
model::DynamicPPL.Model, chain::MCMCChains.Chains; kwargs...
454-
)
455-
return transitions_from_chain(Random.default_rng(), model, chain; kwargs...)
456-
end
457-
458-
function transitions_from_chain(
459-
rng::Random.AbstractRNG,
460-
model::DynamicPPL.Model,
461-
chain::MCMCChains.Chains;
462-
sampler=DynamicPPL.SampleFromPrior(),
463-
)
464-
vi = Turing.VarInfo(model)
465-
466-
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
467-
transitions = map(iters) do (sample_idx, chain_idx)
468-
# Set variables present in `chain` and mark those NOT present in chain to be resampled.
469-
DynamicPPL.setval_and_resample!(vi, chain, sample_idx, chain_idx)
470-
model(rng, vi, sampler)
471-
472-
# Convert `VarInfo` into `NamedTuple` and save.
473-
Transition(model, vi)
474-
end
475-
476-
return transitions
477-
end
478-
479416
end # module

src/mcmc/emcee.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,14 @@ function AbstractMCMC.step(
6565
end
6666

6767
# Compute initial transition and states.
68-
transition = map(Base.Fix1(Transition, model), vis)
68+
transition = [Transition(model, vi, nothing) for vi in vis]
6969

7070
# TODO: Make compatible with immutable `AbstractVarInfo`.
7171
state = EmceeState(
7272
vis[1],
7373
map(vis) do vi
7474
vi = DynamicPPL.link!!(vi, model)
75-
AMH.Transition(vi[:], DynamicPPL.getlogjoint(vi), false)
75+
AMH.Transition(vi[:], DynamicPPL.getlogjoint_internal(vi), false)
7676
end,
7777
)
7878

@@ -87,18 +87,17 @@ function AbstractMCMC.step(
8787
densitymodel = AMH.DensityModel(
8888
Base.Fix1(
8989
LogDensityProblems.logdensity,
90-
DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint, vi),
90+
DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint_internal, vi),
9191
),
9292
)
9393

9494
# Compute the next states.
95-
states = last(AbstractMCMC.step(rng, densitymodel, spl.alg.ensemble, state.states))
95+
t, states = AbstractMCMC.step(rng, densitymodel, spl.alg.ensemble, state.states)
9696

9797
# Compute the next transition and state.
9898
transition = map(states) do _state
9999
vi = DynamicPPL.unflatten(vi, _state.params)
100-
t = Transition(getparams(model, vi), _state.lp)
101-
return t
100+
return Transition(model, vi, t)
102101
end
103102
newstate = EmceeState(vi, states)
104103

src/mcmc/ess.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ function DynamicPPL.initialstep(
3131
EllipticalSliceSampling.isgaussian(typeof(dist)) ||
3232
error("ESS only supports Gaussian prior distributions")
3333
end
34-
return Transition(model, vi), vi
34+
return Transition(model, vi, nothing), vi
3535
end
3636

3737
function AbstractMCMC.step(
@@ -56,7 +56,7 @@ function AbstractMCMC.step(
5656
vi = DynamicPPL.unflatten(vi, sample)
5757
vi = DynamicPPL.setloglikelihood!!(vi, state.loglikelihood)
5858

59-
return Transition(model, vi), vi
59+
return Transition(model, vi, nothing), vi
6060
end
6161

6262
# Prior distribution of considered random variable

0 commit comments

Comments
 (0)