
Commit e60eab0

mhauru and penelopeysm authored
Accumulators stage 2 (#925)
* Give LogDensityFunction the getlogdensity field
* Allow missing LogPriorAccumulator when linking
* Trim whitespace
* Run formatter
* Fix a few typos
* Fix comma -> semicolon
* Fix `LogDensityAt` invocation
* Fix one last test
* Fix tests

---------

Co-authored-by: Penelope Yong <[email protected]>
1 parent f4dd46a commit e60eab0
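The headline change ("Give LogDensityFunction the getlogdensity field") shows up in every file below. A minimal before/after sketch of the constructor call, assuming `getlogjoint` is exported as the updated doctests suggest (the `demo` model is a placeholder, not part of this commit):

```julia
using DynamicPPL, Distributions

@model demo() = x ~ Normal()
model = demo()
vi = VarInfo(model)

# Old call (pre-#925): what was returned depended on the varinfo's accumulators.
# f = LogDensityFunction(model, vi)

# New call: the extractor (getlogjoint, getlogprior, or getloglikelihood) is a
# positional argument; getlogjoint is the default, so it only needs spelling
# out when a varinfo is also passed explicitly.
f = LogDensityFunction(model, getlogjoint, vi)
```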

File tree

8 files changed: 144 additions & 70 deletions


benchmarks/src/DynamicPPLBenchmarks.jl

Lines changed: 1 addition & 1 deletion
@@ -86,7 +86,7 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked:
         vi = DynamicPPL.link(vi, model)
     end
 
-    f = DynamicPPL.LogDensityFunction(model, vi; adtype=adbackend)
+    f = DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint, vi; adtype=adbackend)
     # The parameters at which we evaluate f.
     θ = vi[:]
src/logdensityfunction.jl

Lines changed: 91 additions & 50 deletions
@@ -18,7 +18,8 @@ is_supported(::ADTypes.AutoReverseDiff) = true
 """
     LogDensityFunction(
         model::Model,
-        varinfo::AbstractVarInfo=VarInfo(model);
+        getlogdensity::Function=getlogjoint,
+        varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity);
         adtype::Union{ADTypes.AbstractADType,Nothing}=nothing
     )
 
@@ -28,9 +29,10 @@ A struct which contains a model, along with all the information necessary to:
 - and if `adtype` is provided, calculate the gradient of the log density at
   that point.
 
-At its most basic level, a LogDensityFunction wraps the model together with the
-type of varinfo to be used. These must be known in order to calculate the log
-density (using [`DynamicPPL.evaluate!!`](@ref)).
+At its most basic level, a LogDensityFunction wraps the model together with a
+function that specifies how to extract the log density, and the type of
+VarInfo to be used. These must be known in order to calculate the log density
+(using [`DynamicPPL.evaluate!!`](@ref)).
 
 If the `adtype` keyword argument is provided, then this struct will also store
 the adtype along with other information for efficient calculation of the
@@ -72,13 +74,13 @@ julia> LogDensityProblems.dimension(f)
 1
 
 julia> # By default it uses `VarInfo` under the hood, but this is not necessary.
-       f = LogDensityFunction(model, SimpleVarInfo(model));
+       f = LogDensityFunction(model, getlogjoint, SimpleVarInfo(model));
 
 julia> LogDensityProblems.logdensity(f, [0.0])
 -2.3378770664093453
 
-julia> # LogDensityFunction respects the accumulators in VarInfo:
-       f_prior = LogDensityFunction(model, setaccs!!(VarInfo(model), (LogPriorAccumulator(),)));
+julia> # One can also specify evaluating e.g. the log prior only:
+       f_prior = LogDensityFunction(model, getlogprior);
 
 julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0)
 true
@@ -93,11 +95,13 @@ julia> LogDensityProblems.logdensity_and_gradient(f, [0.0])
 ```
 """
 struct LogDensityFunction{
-    M<:Model,V<:AbstractVarInfo,AD<:Union{Nothing,ADTypes.AbstractADType}
+    M<:Model,F<:Function,V<:AbstractVarInfo,AD<:Union{Nothing,ADTypes.AbstractADType}
 } <: AbstractModel
     "model used for evaluation"
     model::M
-    "varinfo used for evaluation"
+    "function to be called on `varinfo` to extract the log density. By default `getlogjoint`."
+    getlogdensity::F
+    "varinfo used for evaluation. If not specified, generated with `ldf_default_varinfo`."
     varinfo::V
     "AD type used for evaluation of log density gradient. If `nothing`, no gradient can be calculated"
     adtype::AD
@@ -106,7 +110,8 @@ struct LogDensityFunction{
 
     function LogDensityFunction(
         model::Model,
-        varinfo::AbstractVarInfo=VarInfo(model);
+        getlogdensity::Function=getlogjoint,
+        varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity);
         adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
     )
         if adtype === nothing
@@ -120,15 +125,22 @@ struct LogDensityFunction{
             # Get a set of dummy params to use for prep
             x = map(identity, varinfo[:])
            if use_closure(adtype)
-                prep = DI.prepare_gradient(LogDensityAt(model, varinfo), adtype, x)
+                prep = DI.prepare_gradient(
+                    LogDensityAt(model, getlogdensity, varinfo), adtype, x
+                )
             else
                 prep = DI.prepare_gradient(
-                    logdensity_at, adtype, x, DI.Constant(model), DI.Constant(varinfo)
+                    logdensity_at,
+                    adtype,
+                    x,
+                    DI.Constant(model),
+                    DI.Constant(getlogdensity),
+                    DI.Constant(varinfo),
                 )
             end
         end
-        return new{typeof(model),typeof(varinfo),typeof(adtype)}(
-            model, varinfo, adtype, prep
+        return new{typeof(model),typeof(getlogdensity),typeof(varinfo),typeof(adtype)}(
+            model, getlogdensity, varinfo, adtype, prep
         )
     end
 end
@@ -149,83 +161,112 @@ function LogDensityFunction(
     return if adtype === f.adtype
         f # Avoid recomputing prep if not needed
     else
-        LogDensityFunction(f.model, f.varinfo; adtype=adtype)
+        LogDensityFunction(f.model, f.getlogdensity, f.varinfo; adtype=adtype)
     end
 end
 
+"""
+    ldf_default_varinfo(model::Model, getlogdensity::Function)
+
+Create the default AbstractVarInfo that should be used for evaluating the log density.
+
+Only the accumulators necesessary for `getlogdensity` will be used.
+"""
+function ldf_default_varinfo(::Model, getlogdensity::Function)
+    msg = """
+        LogDensityFunction does not know what sort of VarInfo should be used when \
+        `getlogdensity` is $getlogdensity. Please specify a VarInfo explicitly.
+        """
+    return error(msg)
+end
+
+ldf_default_varinfo(model::Model, ::typeof(getlogjoint)) = VarInfo(model)
+
+function ldf_default_varinfo(model::Model, ::typeof(getlogprior))
+    return setaccs!!(VarInfo(model), (LogPriorAccumulator(),))
+end
+
+function ldf_default_varinfo(model::Model, ::typeof(getloglikelihood))
+    return setaccs!!(VarInfo(model), (LogLikelihoodAccumulator(),))
+end
+
 """
     logdensity_at(
         x::AbstractVector,
         model::Model,
+        getlogdensity::Function,
        varinfo::AbstractVarInfo,
    )
 
-Evaluate the log density of the given `model` at the given parameter values `x`,
-using the given `varinfo`. Note that the `varinfo` argument is provided only
-for its structure, in the sense that the parameters from the vector `x` are
-inserted into it, and its own parameters are discarded. It does, however,
-determine whether the log prior, likelihood, or joint is returned, based on
-which accumulators are set in it.
+Evaluate the log density of the given `model` at the given parameter values
+`x`, using the given `varinfo`. Note that the `varinfo` argument is provided
+only for its structure, in the sense that the parameters from the vector `x`
+are inserted into it, and its own parameters are discarded. `getlogdensity` is
+the function that extracts the log density from the evaluated varinfo.
 """
-function logdensity_at(x::AbstractVector, model::Model, varinfo::AbstractVarInfo)
+function logdensity_at(
+    x::AbstractVector, model::Model, getlogdensity::Function, varinfo::AbstractVarInfo
+)
     varinfo_new = unflatten(varinfo, x)
     varinfo_eval = last(evaluate!!(model, varinfo_new))
-    has_prior = hasacc(varinfo_eval, Val(:LogPrior))
-    has_likelihood = hasacc(varinfo_eval, Val(:LogLikelihood))
-    if has_prior && has_likelihood
-        return getlogjoint(varinfo_eval)
-    elseif has_prior
-        return getlogprior(varinfo_eval)
-    elseif has_likelihood
-        return getloglikelihood(varinfo_eval)
-    else
-        error("LogDensityFunction: varinfo tracks neither log prior nor log likelihood")
-    end
+    return getlogdensity(varinfo_eval)
 end
 
 """
-    LogDensityAt{M<:Model,V<:AbstractVarInfo}(
+    LogDensityAt{M<:Model,F<:Function,V<:AbstractVarInfo}(
        model::M
+        getlogdensity::F,
        varinfo::V
    )
 
 A callable struct that serves the same purpose as `x -> logdensity_at(x, model,
-varinfo)`.
+getlogdensity, varinfo)`.
 """
-struct LogDensityAt{M<:Model,V<:AbstractVarInfo}
+struct LogDensityAt{M<:Model,F<:Function,V<:AbstractVarInfo}
     model::M
+    getlogdensity::F
     varinfo::V
 end
-(ld::LogDensityAt)(x::AbstractVector) = logdensity_at(x, ld.model, ld.varinfo)
+function (ld::LogDensityAt)(x::AbstractVector)
+    return logdensity_at(x, ld.model, ld.getlogdensity, ld.varinfo)
+end
 
 ### LogDensityProblems interface
 
 function LogDensityProblems.capabilities(
-    ::Type{<:LogDensityFunction{M,V,Nothing}}
-) where {M,V}
+    ::Type{<:LogDensityFunction{M,F,V,Nothing}}
+) where {M,F,V}
     return LogDensityProblems.LogDensityOrder{0}()
 end
 function LogDensityProblems.capabilities(
-    ::Type{<:LogDensityFunction{M,V,AD}}
-) where {M,V,AD<:ADTypes.AbstractADType}
+    ::Type{<:LogDensityFunction{M,F,V,AD}}
+) where {M,F,V,AD<:ADTypes.AbstractADType}
     return LogDensityProblems.LogDensityOrder{1}()
 end
 function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector)
-    return logdensity_at(x, f.model, f.varinfo)
+    return logdensity_at(x, f.model, f.getlogdensity, f.varinfo)
 end
 function LogDensityProblems.logdensity_and_gradient(
-    f::LogDensityFunction{M,V,AD}, x::AbstractVector
-) where {M,V,AD<:ADTypes.AbstractADType}
+    f::LogDensityFunction{M,F,V,AD}, x::AbstractVector
+) where {M,F,V,AD<:ADTypes.AbstractADType}
     f.prep === nothing &&
         error("Gradient preparation not available; this should not happen")
     x = map(identity, x) # Concretise type
     # Make branching statically inferrable, i.e. type-stable (even if the two
     # branches happen to return different types)
     return if use_closure(f.adtype)
-        DI.value_and_gradient(LogDensityAt(f.model, f.varinfo), f.prep, f.adtype, x)
+        DI.value_and_gradient(
+            LogDensityAt(f.model, f.getlogdensity, f.varinfo), f.prep, f.adtype, x
+        )
     else
         DI.value_and_gradient(
-            logdensity_at, f.prep, f.adtype, x, DI.Constant(f.model), DI.Constant(f.varinfo)
+            logdensity_at,
+            f.prep,
+            f.adtype,
+            x,
+            DI.Constant(f.model),
+            DI.Constant(f.getlogdensity),
+            DI.Constant(f.varinfo),
         )
     end
 end
@@ -264,9 +305,9 @@ There are two ways of dealing with this:
 
 1. Construct a closure over the model, i.e. let g = Base.Fix1(logdensity, f)
 
-2. Use a constant context. This lets us pass a two-argument function to
-   DifferentiationInterface, as long as we also give it the 'inactive argument'
-   (i.e. the model) wrapped in `DI.Constant`.
+2. Use a constant DI.Context. This lets us pass a two-argument function to DI,
+   as long as we also give it the 'inactive argument' (i.e. the model) wrapped
+   in `DI.Constant`.
 
 The relative performance of the two approaches, however, depends on the AD
 backend used. Some benchmarks are provided here:
@@ -292,7 +333,7 @@ getmodel(f::DynamicPPL.LogDensityFunction) = f.model
 Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
 """
 function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
-    return LogDensityFunction(model, f.varinfo; adtype=f.adtype)
+    return LogDensityFunction(model, f.getlogdensity, f.varinfo; adtype=f.adtype)
 end
 
 """

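To make the new `ldf_default_varinfo` dispatch above concrete, here is a hedged sketch; it assumes `getlogprior` is exported as the updated doctest suggests, and `my_extractor` is a purely hypothetical user-defined function:

```julia
using DynamicPPL, Distributions, LogDensityProblems

@model demo() = x ~ Normal()
model = demo()

# getlogprior (and getloglikelihood) get a default VarInfo carrying only the
# matching accumulator, so no explicit varinfo is needed:
f_prior = LogDensityFunction(model, getlogprior)
LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0)  # true

# For an unrecognised extractor, ldf_default_varinfo throws, so a varinfo has
# to be supplied explicitly (my_extractor is a made-up stand-in):
my_extractor(vi) = getlogprior(vi)
f_custom = LogDensityFunction(model, my_extractor, VarInfo(model))
```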
src/simple_varinfo.jl

Lines changed: 6 additions & 2 deletions
@@ -619,7 +619,9 @@ function link!!(
     x = vi.values
     y, logjac = with_logabsdet_jacobian(b, x)
     vi_new = Accessors.@set(vi.values = y)
-    vi_new = acclogprior!!(vi_new, -logjac)
+    if hasacc(vi_new, Val(:LogPrior))
+        vi_new = acclogprior!!(vi_new, -logjac)
+    end
     return settrans!!(vi_new, t)
 end
 
@@ -632,7 +634,9 @@ function invlink!!(
     y = vi.values
     x, logjac = with_logabsdet_jacobian(b, y)
     vi_new = Accessors.@set(vi.values = x)
-    vi_new = acclogprior!!(vi_new, logjac)
+    if hasacc(vi_new, Val(:LogPrior))
+        vi_new = acclogprior!!(vi_new, logjac)
+    end
     return settrans!!(vi_new, NoTransformation())
 end
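The `hasacc` guard is what the commit message's "Allow missing LogPriorAccumulator when linking" refers to: if the varinfo does not track the log prior, the Jacobian term is simply skipped instead of erroring. A rough sketch of the case this permits; `setaccs!!` and `LogLikelihoodAccumulator` are internal names taken from elsewhere in this PR, so treat the snippet as an assumption rather than a tested example:

```julia
using DynamicPPL, Distributions

@model demo() = x ~ LogNormal()
model = demo()

# A SimpleVarInfo that tracks only the log-likelihood:
svi = DynamicPPL.setaccs!!(SimpleVarInfo(model), (DynamicPPL.LogLikelihoodAccumulator(),))

# Linking still transforms the values; with the guard above it no longer calls
# acclogprior!! on an accumulator that is not there.
svi_linked = DynamicPPL.link!!(svi, model)
```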

src/test_utils/ad.jl

Lines changed: 10 additions & 3 deletions
@@ -4,7 +4,7 @@ using ADTypes: AbstractADType, AutoForwardDiff
 using Chairmarks: @be
 import DifferentiationInterface as DI
 using DocStringExtensions
-using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, link
+using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, getlogjoint, link
 using LogDensityProblems: logdensity, logdensity_and_gradient
 using Random: AbstractRNG, default_rng
 using Statistics: median
@@ -88,6 +88,8 @@ $(TYPEDFIELDS)
 struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat,Ttol<:AbstractFloat}
     "The DynamicPPL model that was tested"
     model::Model
+    "The function used to extract the log density from the model"
+    getlogdensity::Function
     "The VarInfo that was used"
     varinfo::AbstractVarInfo
     "The values at which the model was evaluated"
@@ -222,6 +224,7 @@ function run_ad(
     benchmark::Bool=false,
     atol::AbstractFloat=100 * eps(),
     rtol::AbstractFloat=sqrt(eps()),
+    getlogdensity::Function=getlogjoint,
     rng::AbstractRNG=default_rng(),
     varinfo::AbstractVarInfo=link(VarInfo(rng, model), model),
     params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
@@ -241,7 +244,8 @@ function run_ad(
     # Calculate log-density and gradient with the backend of interest
     verbose && @info "Running AD on $(model.f) with $(adtype)\n"
     verbose && println(" params : $(params)")
-    ldf = LogDensityFunction(model, varinfo; adtype=adtype)
+    ldf = LogDensityFunction(model, getlogdensity, varinfo; adtype=adtype)
+
     value, grad = logdensity_and_gradient(ldf, params)
     # collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754
     grad = collect(grad)
@@ -257,7 +261,9 @@ function run_ad(
         value_true = test.value
         grad_true = test.grad
     elseif test isa WithBackend
-        ldf_reference = LogDensityFunction(model, varinfo; adtype=test.adtype)
+        ldf_reference = LogDensityFunction(
+            model, getlogdensity, varinfo; adtype=test.adtype
+        )
         value_true, grad_true = logdensity_and_gradient(ldf_reference, params)
         # collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754
         grad_true = collect(grad_true)
@@ -282,6 +288,7 @@ function run_ad(
 
     return ADResult(
         model,
+        getlogdensity,
         varinfo,
         params,
         adtype,
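`run_ad` now threads the chosen extractor through to the `LogDensityFunction` it builds and records it in the returned `ADResult`. A hedged usage sketch; the `DynamicPPL.TestUtils.AD` module path is inferred from the file location, and the positional `model, adtype` arguments from the signature fragments above:

```julia
using DynamicPPL, Distributions
using DynamicPPL.TestUtils.AD: run_ad
using ADTypes: AutoForwardDiff
import ForwardDiff  # backend package must be loaded for AutoForwardDiff

@model demo() = x ~ Normal()

# Check AD correctness against the log prior only, rather than the default
# log joint:
result = run_ad(demo(), AutoForwardDiff(); getlogdensity=getlogprior)
```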

0 commit comments
