@@ -152,6 +152,7 @@ struct LogDensityFunction{
152152 ADP<: Union{Nothing,DI.GradientPrep} ,
153153 # type of the vector passed to logdensity functions
154154 X<: AbstractVector ,
155+ AC<: AccumulatorTuple ,
155156}
156157 model:: M
157158 adtype:: AD
@@ -160,11 +161,17 @@ struct LogDensityFunction{
160161 _varname_ranges:: Dict{VarName,RangeAndLinked}
161162 _adprep:: ADP
162163 _dim:: Int
164+ _accs:: AC
163165
164166 function LogDensityFunction(
165167 model:: Model ,
166168 getlogdensity:: Any = getlogjoint_internal,
167- varinfo:: AbstractVarInfo = VarInfo(model);
169+ # TODO (penelopeysm): It is a bit redundant to pass a VarInfo, as well as the
170+ # accumulators, into here. The truth is that the VarInfo is used ONLY for generating
171+ # the ranges and link status, so arguably we should only pass in a metadata; or when
172+ # VNT is done, we should pass in only a VNT.
173+ varinfo:: AbstractVarInfo = VarInfo(model),
174+ accs:: AccumulatorTuple = ldf_accs(getlogdensity);
168175 adtype:: Union{ADTypes.AbstractADType,Nothing} = nothing ,
169176 )
170177 # Figure out which variable corresponds to which index, and
@@ -193,7 +200,7 @@ struct LogDensityFunction{
193200 else
194201 # Make backend-specific tweaks to the adtype
195202 adtype = DynamicPPL. tweak_adtype(adtype, model, varinfo)
196- args = (model, getlogdensity, all_iden_ranges, all_ranges)
203+ args = (model, getlogdensity, all_iden_ranges, all_ranges, accs )
197204 if _use_closure(adtype)
198205 DI. prepare_gradient(LogDensityAt{Tlink}(args... ), adtype, x)
199206 else
@@ -214,8 +221,9 @@ struct LogDensityFunction{
214221 typeof(all_iden_ranges),
215222 typeof(prep),
216223 typeof(x),
224+ typeof(accs),
217225 }(
218- model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep, dim
226+ model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep, dim, accs
219227 )
220228 end
221229end
@@ -264,11 +272,11 @@ function logdensity_at(
264272 getlogdensity:: Any ,
265273 iden_varname_ranges:: NamedTuple ,
266274 varname_ranges:: Dict{VarName,RangeAndLinked} ,
275+ accs:: AccumulatorTuple ,
267276) where {Tlink}
268277 strategy = InitFromParams(
269278 VectorWithRanges{Tlink}(iden_varname_ranges, varname_ranges, params), nothing
270279 )
271- accs = ldf_accs(getlogdensity)
272280 _, vi = DynamicPPL. init!!(model, OnlyAccsVarInfo(accs), strategy)
273281 return getlogdensity(vi)
274282end
@@ -279,25 +287,30 @@ end
279287 getlogdensity::Any,
280288 iden_varname_ranges::NamedTuple,
281289 varname_ranges::Dict{VarName,RangeAndLinked},
290+ accs::AccumulatorTuple,
282291 ) where {Tlink}
283292
284293A callable struct that behaves in the same way as `logdensity_at`, but stores the model and
285294other information internally. Having two separate functions/structs allows for better
286295performance with AD backends.
287296"""
288- struct LogDensityAt{Tlink,M<: Model ,F,N<: NamedTuple }
297+ struct LogDensityAt{Tlink,M<: Model ,F,N<: NamedTuple ,A <: AccumulatorTuple }
289298 model:: M
290299 getlogdensity:: F
291300 iden_varname_ranges:: N
292301 varname_ranges:: Dict{VarName,RangeAndLinked}
302+ accs:: A
293303
294304 function LogDensityAt{Tlink}(
295305 model:: M ,
296306 getlogdensity:: F ,
297307 iden_varname_ranges:: N ,
298308 varname_ranges:: Dict{VarName,RangeAndLinked} ,
299- ) where {Tlink,M,F,N}
300- return new{Tlink,M,F,N}(model, getlogdensity, iden_varname_ranges, varname_ranges)
309+ accs:: A ,
310+ ) where {Tlink,M,F,N,A}
311+ return new{Tlink,M,F,N,A}(
312+ model, getlogdensity, iden_varname_ranges, varname_ranges, accs
313+ )
301314 end
302315end
303316function (f:: LogDensityAt{Tlink} )(params:: AbstractVector{<:Real} ) where {Tlink}
@@ -308,6 +321,7 @@ function (f::LogDensityAt{Tlink})(params::AbstractVector{<:Real}) where {Tlink}
308321 f. getlogdensity,
309322 f. iden_varname_ranges,
310323 f. varname_ranges,
324+ f. accs,
311325 )
312326end
313327
@@ -321,6 +335,7 @@ function LogDensityProblems.logdensity(
321335 ldf. _getlogdensity,
322336 ldf. _iden_varname_ranges,
323337 ldf. _varname_ranges,
338+ ldf. _accs,
324339 )
325340end
326341
@@ -333,7 +348,11 @@ function LogDensityProblems.logdensity_and_gradient(
333348 return if _use_closure(ldf. adtype)
334349 DI. value_and_gradient(
335350 LogDensityAt{Tlink}(
336- ldf. model, ldf. _getlogdensity, ldf. _iden_varname_ranges, ldf. _varname_ranges
351+ ldf. model,
352+ ldf. _getlogdensity,
353+ ldf. _iden_varname_ranges,
354+ ldf. _varname_ranges,
355+ ldf. _accs,
337356 ),
338357 ldf. _adprep,
339358 ldf. adtype,
@@ -350,6 +369,7 @@ function LogDensityProblems.logdensity_and_gradient(
350369 DI. Constant(ldf. _getlogdensity),
351370 DI. Constant(ldf. _iden_varname_ranges),
352371 DI. Constant(ldf. _varname_ranges),
372+ DI. Constant(ldf. _accs),
353373 )
354374 end
355375end
0 commit comments