@@ -33,7 +33,8 @@ using Random: Random
3333 DynamicPPL.LogDensityFunction(
3434 model::Model,
3535 getlogdensity::Any=getlogjoint_internal,
36- varinfo::AbstractVarInfo=VarInfo(model);
36+ varinfo::AbstractVarInfo=VarInfo(model)
37+ accs::Union{NTuple{<:Any,AbstractAccumulator},AccumulatorTuple}=DynamicPPL.ldf_accs(getlogdensity);
3738 adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
3839 )
3940
@@ -72,6 +73,12 @@ If you provide one of these functions, a `VarInfo` will be automatically created
7273you provide a different function, you have to manually create a VarInfo and pass it as the
7374third argument.
7475
76+ `accs` allows you to specify an `AccumulatorTuple` or a tuple of `AbstractAccumulator`s
77+ which will be used _when evaluating the log density_`. (Note that the accumulators from the
78+ `VarInfo` argument are discarded.) By default, this uses an internal function,
79+ `DynamicPPL.ldf_accs`, which attempts to choose an appropriate set of accumulators based on
80+ which kind of log-density is being calculated.
81+
7582If the `adtype` keyword argument is provided, then this struct will also store the adtype
7683along with other information for efficient calculation of the gradient of the log density.
7784Note that preparing a `LogDensityFunction` with an AD type `AutoBackend()` requires the AD
@@ -152,6 +159,7 @@ struct LogDensityFunction{
152159 ADP<: Union{Nothing,DI.GradientPrep} ,
153160 # type of the vector passed to logdensity functions
154161 X<: AbstractVector ,
162+ AC<: AccumulatorTuple ,
155163}
156164 model:: M
157165 adtype:: AD
@@ -160,11 +168,19 @@ struct LogDensityFunction{
160168 _varname_ranges:: Dict{VarName,RangeAndLinked}
161169 _adprep:: ADP
162170 _dim:: Int
171+ _accs:: AC
163172
164173 function LogDensityFunction(
165174 model:: Model ,
166175 getlogdensity:: Any = getlogjoint_internal,
167- varinfo:: AbstractVarInfo = VarInfo(model);
176+ # TODO (penelopeysm): It is a bit redundant to pass a VarInfo, as well as the
177+ # accumulators, into here. The truth is that the VarInfo is used ONLY for generating
178+ # the ranges and link status, so arguably we should only pass in a metadata; or when
179+ # VNT is done, we should pass in only a VNT.
180+ varinfo:: AbstractVarInfo = VarInfo(model),
181+ accs:: Union{NTuple{<:Any,AbstractAccumulator},AccumulatorTuple} = ldf_accs(
182+ getlogdensity
183+ );
168184 adtype:: Union{ADTypes.AbstractADType,Nothing} = nothing ,
169185 )
170186 # Figure out which variable corresponds to which index, and
@@ -187,13 +203,15 @@ struct LogDensityFunction{
187203 end
188204 x = [val for val in varinfo[:]]
189205 dim = length(x)
206+ # convert to AccumulatorTuple if needed
207+ accs = AccumulatorTuple(accs)
190208 # Do AD prep if needed
191209 prep = if adtype === nothing
192210 nothing
193211 else
194212 # Make backend-specific tweaks to the adtype
195213 adtype = DynamicPPL. tweak_adtype(adtype, model, varinfo)
196- args = (model, getlogdensity, all_iden_ranges, all_ranges)
214+ args = (model, getlogdensity, all_iden_ranges, all_ranges, accs )
197215 if _use_closure(adtype)
198216 DI. prepare_gradient(LogDensityAt{Tlink}(args... ), adtype, x)
199217 else
@@ -214,8 +232,9 @@ struct LogDensityFunction{
214232 typeof(all_iden_ranges),
215233 typeof(prep),
216234 typeof(x),
235+ typeof(accs),
217236 }(
218- model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep, dim
237+ model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep, dim, accs
219238 )
220239 end
221240end
@@ -264,11 +283,11 @@ function logdensity_at(
264283 getlogdensity:: Any ,
265284 iden_varname_ranges:: NamedTuple ,
266285 varname_ranges:: Dict{VarName,RangeAndLinked} ,
286+ accs:: AccumulatorTuple ,
267287) where {Tlink}
268288 strategy = InitFromParams(
269289 VectorWithRanges{Tlink}(iden_varname_ranges, varname_ranges, params), nothing
270290 )
271- accs = ldf_accs(getlogdensity)
272291 _, vi = DynamicPPL. init!!(model, OnlyAccsVarInfo(accs), strategy)
273292 return getlogdensity(vi)
274293end
@@ -279,25 +298,30 @@ end
279298 getlogdensity::Any,
280299 iden_varname_ranges::NamedTuple,
281300 varname_ranges::Dict{VarName,RangeAndLinked},
301+ accs::AccumulatorTuple,
282302 ) where {Tlink}
283303
284304A callable struct that behaves in the same way as `logdensity_at`, but stores the model and
285305other information internally. Having two separate functions/structs allows for better
286306performance with AD backends.
287307"""
288- struct LogDensityAt{Tlink,M<: Model ,F,N<: NamedTuple }
308+ struct LogDensityAt{Tlink,M<: Model ,F,N<: NamedTuple ,A <: AccumulatorTuple }
289309 model:: M
290310 getlogdensity:: F
291311 iden_varname_ranges:: N
292312 varname_ranges:: Dict{VarName,RangeAndLinked}
313+ accs:: A
293314
294315 function LogDensityAt{Tlink}(
295316 model:: M ,
296317 getlogdensity:: F ,
297318 iden_varname_ranges:: N ,
298319 varname_ranges:: Dict{VarName,RangeAndLinked} ,
299- ) where {Tlink,M,F,N}
300- return new{Tlink,M,F,N}(model, getlogdensity, iden_varname_ranges, varname_ranges)
320+ accs:: A ,
321+ ) where {Tlink,M,F,N,A}
322+ return new{Tlink,M,F,N,A}(
323+ model, getlogdensity, iden_varname_ranges, varname_ranges, accs
324+ )
301325 end
302326end
303327function (f:: LogDensityAt{Tlink} )(params:: AbstractVector{<:Real} ) where {Tlink}
@@ -308,6 +332,7 @@ function (f::LogDensityAt{Tlink})(params::AbstractVector{<:Real}) where {Tlink}
308332 f. getlogdensity,
309333 f. iden_varname_ranges,
310334 f. varname_ranges,
335+ f. accs,
311336 )
312337end
313338
@@ -321,6 +346,7 @@ function LogDensityProblems.logdensity(
321346 ldf. _getlogdensity,
322347 ldf. _iden_varname_ranges,
323348 ldf. _varname_ranges,
349+ ldf. _accs,
324350 )
325351end
326352
@@ -333,7 +359,11 @@ function LogDensityProblems.logdensity_and_gradient(
333359 return if _use_closure(ldf. adtype)
334360 DI. value_and_gradient(
335361 LogDensityAt{Tlink}(
336- ldf. model, ldf. _getlogdensity, ldf. _iden_varname_ranges, ldf. _varname_ranges
362+ ldf. model,
363+ ldf. _getlogdensity,
364+ ldf. _iden_varname_ranges,
365+ ldf. _varname_ranges,
366+ ldf. _accs,
337367 ),
338368 ldf. _adprep,
339369 ldf. adtype,
@@ -350,6 +380,7 @@ function LogDensityProblems.logdensity_and_gradient(
350380 DI. Constant(ldf. _getlogdensity),
351381 DI. Constant(ldf. _iden_varname_ranges),
352382 DI. Constant(ldf. _varname_ranges),
383+ DI. Constant(ldf. _accs),
353384 )
354385 end
355386end
0 commit comments