Skip to content

Commit 9f482f3

Browse files
devmotion and torfjelde
authored
Unify log density function types (#1846)
* Unify log density function types * Some fixes * More fixes * Some more fixes * Another fix * Apply suggestions from code review Co-authored-by: Tor Erlend Fjelde <[email protected]> * Update src/contrib/inference/dynamichmc.jl * Update OptimInterface.jl * Fix implementation of Optim interface * Update ModeEstimation.jl * Fix tests * Update mh.jl Co-authored-by: Tor Erlend Fjelde <[email protected]>
1 parent bab91b3 commit 9f482f3

File tree

9 files changed

+88
-138
lines changed

9 files changed

+88
-138
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Turing"
22
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
3-
version = "0.21.6"
3+
version = "0.21.7"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -37,7 +37,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
3737
[compat]
3838
AbstractMCMC = "4"
3939
AdvancedHMC = "0.3.0"
40-
AdvancedMH = "0.6"
40+
AdvancedMH = "0.6.8"
4141
AdvancedPS = "0.3.4"
4242
AdvancedVI = "0.1"
4343
BangBang = "0.3"

src/contrib/inference/dynamichmc.jl

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,9 @@ DynamicNUTS{AD}(space::Symbol...) where AD = DynamicNUTS{AD, space}()
1919

2020
DynamicPPL.getspace(::DynamicNUTS{<:Any, space}) where {space} = space
2121

22-
struct DynamicHMCLogDensity{M<:Model,S<:Sampler{<:DynamicNUTS},V<:AbstractVarInfo}
23-
model::M
24-
sampler::S
25-
varinfo::V
26-
end
22+
# Only define traits for `DynamicNUTS` sampler to avoid type piracy and surprises
23+
# TODO: Implement generally with `LogDensityProblems`
24+
const DynamicHMCLogDensity{M<:Model,S<:Sampler{<:DynamicNUTS},V<:AbstractVarInfo} = Turing.LogDensityFunction{V,M,S,DynamicPPL.DefaultContext}
2725

2826
function DynamicHMC.dimension(ℓ::DynamicHMCLogDensity)
2927
return length(ℓ.varinfo[ℓ.sampler])
@@ -37,7 +35,7 @@ function DynamicHMC.logdensity_and_gradient(
3735
::DynamicHMCLogDensity,
3836
x::AbstractVector,
3937
)
40-
return gradient_logp(x, ℓ.varinfo, ℓ.model, ℓ.sampler)
38+
return gradient_logp(x, ℓ.varinfo, ℓ.model, ℓ.sampler, ℓ.context)
4139
end
4240

4341
"""
@@ -64,7 +62,7 @@ function gibbs_state(
6462
varinfo::AbstractVarInfo,
6563
)
6664
# Update the previous evaluation.
67-
ℓ = DynamicHMCLogDensity(model, spl, varinfo)
65+
ℓ = Turing.LogDensityFunction(varinfo, model, spl, DynamicPPL.DefaultContext())
6866
Q = DynamicHMC.evaluate_ℓ(ℓ, varinfo[spl])
6967
return DynamicNUTSState(varinfo, Q, state.metric, state.stepsize)
7068
end
@@ -87,7 +85,7 @@ function DynamicPPL.initialstep(
8785
# Perform initial step.
8886
results = DynamicHMC.mcmc_keep_warmup(
8987
rng,
90-
DynamicHMCLogDensity(model, spl, vi),
88+
Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext()),
9189
0;
9290
initialization = (q = vi[spl],),
9391
reporter = DynamicHMC.NoProgressReport(),
@@ -115,7 +113,7 @@ function AbstractMCMC.step(
115113
)
116114
# Compute next sample.
117115
vi = state.vi
118-
ℓ = DynamicHMCLogDensity(model, spl, vi)
116+
ℓ = Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext())
119117
steps = DynamicHMC.mcmc_steps(
120118
rng,
121119
DynamicHMC.NUTS(),

src/inference/emcee.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ function AbstractMCMC.step(
7474
)
7575
# Generate a log joint function.
7676
vi = state.vi
77-
densitymodel = AMH.DensityModel(gen_logπ(vi, SampleFromPrior(), model))
77+
densitymodel = AMH.DensityModel(Turing.LogDensityFunction(vi, model, SampleFromPrior(), DynamicPPL.DefaultContext()))
7878

7979
# Compute the next states.
8080
states = last(AbstractMCMC.step(rng, densitymodel, spl.alg.ensemble, state.states))

src/inference/ess.jl

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ function AbstractMCMC.step(
6464
sample, state = AbstractMCMC.step(
6565
rng,
6666
EllipticalSliceSampling.ESSModel(
67-
ESSPrior(model, spl, vi), ESSLogLikelihood(model, spl, vi),
67+
ESSPrior(model, spl, vi), Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext()),
6868
),
6969
EllipticalSliceSampling.ESS(),
7070
oldstate,
@@ -124,13 +124,9 @@ end
124124
Distributions.mean(p::ESSPrior) = p.μ
125125

126126
# Evaluate log-likelihood of proposals
127-
struct ESSLogLikelihood{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo}
128-
model::M
129-
sampler::S
130-
varinfo::V
131-
end
127+
const ESSLogLikelihood{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo} = Turing.LogDensityFunction{V,M,S,DynamicPPL.DefaultContext}
132128

133-
function (ℓ::ESSLogLikelihood)(f)
129+
function (ℓ::ESSLogLikelihood)(f::AbstractVector)
134130
sampler = ℓ.sampler
135131
varinfo = setindex!!(ℓ.varinfo, f, sampler)
136132
varinfo = last(DynamicPPL.evaluate!!(ℓ.model, varinfo, sampler))

src/inference/hmc.jl

Lines changed: 3 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ function DynamicPPL.initialstep(
160160
metricT = getmetricT(spl.alg)
161161
metric = metricT(length(theta))
162162
∂logπ∂θ = gen_∂logπ∂θ(vi, spl, model)
163-
logπ = gen_logπ(vi, spl, model)
163+
logπ = Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext())
164164
hamiltonian = AHMC.Hamiltonian(metric, logπ, ∂logπ∂θ)
165165

166166
# Compute phase point z.
@@ -262,7 +262,7 @@ end
262262

263263
function get_hamiltonian(model, spl, vi, state, n)
264264
metric = gen_metric(n, spl, state)
265-
ℓπ = gen_logπ(vi, spl, model)
265+
ℓπ = Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext())
266266
∂ℓπ∂θ = gen_∂logπ∂θ(vi, spl, model)
267267
return AHMC.Hamiltonian(metric, ℓπ, ∂ℓπ∂θ)
268268
end
@@ -435,28 +435,6 @@ function gen_∂logπ∂θ(vi, spl::Sampler, model)
435435
return ∂logπ∂θ
436436
end
437437

438-
"""
439-
gen_logπ(vi, spl::Sampler, model)
440-
441-
Generate a function that takes `θ` and returns logpdf at `θ` for the model specified by
442-
`(vi, spl, model)`.
443-
"""
444-
function gen_logπ(vi_base, spl::AbstractSampler, model)
445-
function logπ(x)::Float64
446-
vi = vi_base
447-
x_old, lj_old = vi[spl], getlogp(vi)
448-
vi = setindex!!(vi, x, spl)
449-
vi = last(DynamicPPL.evaluate!!(model, vi, spl))
450-
lj = getlogp(vi)
451-
# Don't really need to capture these will only be
452-
# necessary if `vi` is indeed mutable.
453-
setindex!!(vi, x_old, spl)
454-
setlogp!!(vi, lj_old)
455-
return lj
456-
end
457-
return logπ
458-
end
459-
460438
gen_metric(dim::Int, spl::Sampler{<:Hamiltonian}, state) = AHMC.UnitEuclideanMetric(dim)
461439
function gen_metric(dim::Int, spl::Sampler{<:AdaptiveHamiltonian}, state)
462440
return AHMC.renew(state.hamiltonian.metric, AHMC.getM⁻¹(state.adaptor.pc))
@@ -567,7 +545,7 @@ function HMCState(
567545

568546
# Get the initial log pdf and gradient functions.
569547
∂logπ∂θ = gen_∂logπ∂θ(vi, spl, model)
570-
logπ = gen_logπ(vi, spl, model)
548+
logπ = Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext())
571549

572550
# Get the metric type.
573551
metricT = getmetricT(spl.alg)

src/inference/mh.jl

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -242,19 +242,15 @@ A log density function for the MH sampler.
242242
243243
This variant uses the `set_namedtuple!` function to update the `VarInfo`.
244244
"""
245-
struct MHLogDensityFunction{M<:Model,S<:Sampler{<:MH},V<:AbstractVarInfo} <: Function # Relax AMH.DensityModel?
246-
model::M
247-
sampler::S
248-
vi::V
249-
end
245+
const MHLogDensityFunction{M<:Model,S<:Sampler{<:MH},V<:AbstractVarInfo} = Turing.LogDensityFunction{V,M,S,DynamicPPL.DefaultContext}
250246

251-
function (f::MHLogDensityFunction)(x)
247+
function (f::MHLogDensityFunction)(x::NamedTuple)
252248
sampler = f.sampler
253-
vi = f.vi
249+
vi = f.varinfo
254250

255251
x_old, lj_old = vi[sampler], getlogp(vi)
256252
set_namedtuple!(vi, x)
257-
vi_new = last(DynamicPPL.evaluate!!(f.model, vi, DynamicPPL.DefaultContext()))
253+
vi_new = last(DynamicPPL.evaluate!!(f.model, vi, f.context))
258254
lj = getlogp(vi_new)
259255

260256
# Reset old `vi`.
@@ -376,7 +372,7 @@ function propose!(
376372
prev_trans = AMH.Transition(vt, getlogp(vi))
377373

378374
# Make a new transition.
379-
densitymodel = AMH.DensityModel(MHLogDensityFunction(model, spl, vi))
375+
densitymodel = AMH.DensityModel(Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext()))
380376
trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans)
381377

382378
# TODO: Make this compatible with immutable `VarInfo`.
@@ -404,7 +400,7 @@ function propose!(
404400
prev_trans = AMH.Transition(vals, getlogp(vi))
405401

406402
# Make a new transition.
407-
densitymodel = AMH.DensityModel(gen_logπ(vi, spl, model))
403+
densitymodel = AMH.DensityModel(Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext()))
408404
trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans)
409405

410406
# TODO: Make this compatible with immutable `VarInfo`.

src/modes/ModeEstimation.jl

Lines changed: 37 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -78,74 +78,57 @@ end
7878
"""
7979
OptimLogDensity{M<:Model,C<:Context,V<:VarInfo}
8080
81-
A struct that stores the log density function of a `DynamicPPL` model.
81+
A struct that stores the negative log density function of a `DynamicPPL` model.
8282
"""
83-
struct OptimLogDensity{M<:Model,C<:AbstractContext,V<:VarInfo}
84-
"A `DynamicPPL.Model` constructed either with the `@model` macro or manually."
85-
model::M
86-
"A `DynamicPPL.AbstractContext` used to evaluate the model. `LikelihoodContext` or `DefaultContext` are typical for MAP/MLE."
87-
context::C
88-
"A `DynamicPPL.VarInfo` struct that will be used to update model parameters."
89-
vi::V
90-
end
83+
const OptimLogDensity{M<:Model,C<:OptimizationContext,V<:VarInfo} = Turing.LogDensityFunction{V,M,DynamicPPL.SampleFromPrior,C}
9184

9285
"""
93-
OptimLogDensity(model::Model, context::AbstractContext)
86+
OptimLogDensity(model::Model, context::OptimizationContext)
9487
9588
Create a callable `OptimLogDensity` struct that evaluates a model using the given `context`.
9689
"""
97-
function OptimLogDensity(model::Model, context::AbstractContext)
90+
function OptimLogDensity(model::Model, context::OptimizationContext)
9891
init = VarInfo(model)
99-
return OptimLogDensity(model, context, init)
92+
return Turing.LogDensityFunction(init, model, DynamicPPL.SampleFromPrior(), context)
10093
end
10194

10295
"""
10396
(f::OptimLogDensity)(z)
10497
105-
Evaluate the log joint (with `DefaultContext`) or log likelihood (with `LikelihoodContext`)
98+
Evaluate the negative log joint (with `DefaultContext`) or log likelihood (with `LikelihoodContext`)
10699
at the array `z`.
107100
"""
108-
function (f::OptimLogDensity)(z)
109-
spl = DynamicPPL.SampleFromPrior()
110-
111-
varinfo = DynamicPPL.VarInfo(f.vi, spl, z)
112-
f.model(varinfo, spl, f.context)
113-
return -DynamicPPL.getlogp(varinfo)
101+
function (f::OptimLogDensity)(z::AbstractVector)
102+
sampler = f.sampler
103+
varinfo = DynamicPPL.VarInfo(f.varinfo, sampler, z)
104+
return -getlogp(last(DynamicPPL.evaluate!!(f.model, varinfo, sampler, f.context)))
114105
end
115106

116-
function (f::OptimLogDensity)(F, G, H, z)
117-
# Throw an error if a second order method was used.
118-
if H !== nothing
119-
error("Second order optimization is not yet supported.")
120-
end
121-
122-
spl = DynamicPPL.SampleFromPrior()
123-
107+
function (f::OptimLogDensity)(F, G, z)
124108
if G !== nothing
125-
# Calculate log joint and the gradient
126-
l, g = Turing.gradient_logp(
109+
# Calculate negative log joint and its gradient.
110+
sampler = f.sampler
111+
neglogp, ∇neglogp = Turing.gradient_logp(
127112
z,
128-
DynamicPPL.VarInfo(f.vi, spl, z),
113+
DynamicPPL.VarInfo(f.varinfo, sampler, z),
129114
f.model,
130-
spl,
131-
f.context
115+
sampler,
116+
f.context,
132117
)
133118

134-
# Use the negative gradient because we are minimizing.
135-
G[:] = -g
119+
# Save the gradient to the pre-allocated array.
120+
copyto!(G, ∇neglogp)
136121

137-
# If F is something, return that since we already have the
138-
# log joint.
122+
# If F is something, the negative log joint is requested as well.
123+
# We have already computed it as a by-product above and hence return it directly.
139124
if F !== nothing
140-
F = -l
141-
return F
125+
return neglogp
142126
end
143127
end
144128

145-
# No gradient necessary, just return the log joint.
129+
# Only negative log joint requested but no gradient.
146130
if F !== nothing
147-
F = f(z)
148-
return F
131+
return f(z)
149132
end
150133

151134
return nothing
@@ -158,16 +141,16 @@ end
158141
#################################################
159142

160143
function transform!(f::OptimLogDensity)
161-
spl = DynamicPPL.SampleFromPrior()
144+
spl = f.sampler
162145

163146
## Check link status of vi in OptimLogDensity
164-
linked = DynamicPPL.islinked(f.vi, spl)
147+
linked = DynamicPPL.islinked(f.varinfo, spl)
165148

166149
## transform into constrained or unconstrained space depending on current state of vi
167150
if !linked
168-
DynamicPPL.link!(f.vi, spl)
151+
DynamicPPL.link!(f.varinfo, spl)
169152
else
170-
DynamicPPL.invlink!(f.vi, spl)
153+
DynamicPPL.invlink!(f.varinfo, spl)
171154
end
172155

173156
return nothing
@@ -249,8 +232,8 @@ function _optim_objective(model::DynamicPPL.Model, ::MAP, ::constrained_space{fa
249232
obj = OptimLogDensity(model, ctx)
250233

251234
transform!(obj)
252-
init = Init(obj.vi, constrained_space{false}())
253-
t = ParameterTransform(obj.vi, constrained_space{true}())
235+
init = Init(obj.varinfo, constrained_space{false}())
236+
t = ParameterTransform(obj.varinfo, constrained_space{true}())
254237

255238
return (obj=obj, init = init, transform=t)
256239
end
@@ -259,8 +242,8 @@ function _optim_objective(model::DynamicPPL.Model, ::MAP, ::constrained_space{tr
259242
ctx = OptimizationContext(DynamicPPL.DefaultContext())
260243
obj = OptimLogDensity(model, ctx)
261244

262-
init = Init(obj.vi, constrained_space{true}())
263-
t = ParameterTransform(obj.vi, constrained_space{true}())
245+
init = Init(obj.varinfo, constrained_space{true}())
246+
t = ParameterTransform(obj.varinfo, constrained_space{true}())
264247

265248
return (obj=obj, init = init, transform=t)
266249
end
@@ -270,8 +253,8 @@ function _optim_objective(model::DynamicPPL.Model, ::MLE, ::constrained_space{f
270253
obj = OptimLogDensity(model, ctx)
271254

272255
transform!(obj)
273-
init = Init(obj.vi, constrained_space{false}())
274-
t = ParameterTransform(obj.vi, constrained_space{true}())
256+
init = Init(obj.varinfo, constrained_space{false}())
257+
t = ParameterTransform(obj.varinfo, constrained_space{true}())
275258

276259
return (obj=obj, init = init, transform=t)
277260
end
@@ -280,8 +263,8 @@ function _optim_objective(model::DynamicPPL.Model, ::MLE, ::constrained_space{tr
280263
ctx = OptimizationContext(DynamicPPL.LikelihoodContext())
281264
obj = OptimLogDensity(model, ctx)
282265

283-
init = Init(obj.vi, constrained_space{true}())
284-
t = ParameterTransform(obj.vi, constrained_space{true}())
266+
init = Init(obj.varinfo, constrained_space{true}())
267+
t = ParameterTransform(obj.varinfo, constrained_space{true}())
285268

286269
return (obj=obj, init = init, transform=t)
287270
end
@@ -309,8 +292,7 @@ function optim_function(
309292
else
310293
OptimizationFunction(
311294
l;
312-
grad = (G,x,p) -> obj(nothing, G, nothing, x),
313-
hess = (H,x,p) -> obj(nothing, nothing, H, x),
295+
grad = (G,x,p) -> obj(nothing, G, x),
314296
)
315297
end
316298

0 commit comments

Comments
 (0)