@@ -193,11 +193,18 @@ struct LogDensityFunction{
193193 else
194194 # Make backend-specific tweaks to the adtype
195195 adtype = DynamicPPL. tweak_adtype(adtype, model, varinfo)
196- DI. prepare_gradient(
197- LogDensityAt{Tlink}(model, getlogdensity, all_iden_ranges, all_ranges),
198- adtype,
199- x,
200- )
196+ args = (model, getlogdensity, all_iden_ranges, all_ranges)
197+ if _use_closure(adtype)
198+ DI. prepare_gradient(LogDensityAt{Tlink}(args... ), adtype, x)
199+ else
200+ DI. prepare_gradient(
201+ logdensity_at,
202+ adtype,
203+ x,
204+ DI. Constant(Val{Tlink}()),
205+ map(DI. Constant, args). .. ,
206+ )
207+ end
201208 end
202209 return new{
203210 Tlink,
# Map each log-density getter to the accumulators it needs during model evaluation.
function ldf_accs(::typeof(getlogprior))
    return AccumulatorTuple((LogPriorAccumulator(),))
end
function ldf_accs(::typeof(getloglikelihood))
    return AccumulatorTuple((LogLikelihoodAccumulator(),))
end
239246
"""
    logdensity_at(
        params::AbstractVector{<:Real},
        ::Val{Tlink},
        model::Model,
        getlogdensity::Any,
        iden_varname_ranges::NamedTuple,
        varname_ranges::Dict{VarName,RangeAndLinked},
    ) where {Tlink}

Calculate the log density at the given `params`, using the provided
information extracted from a `LogDensityFunction`.
"""
function logdensity_at(
    params::AbstractVector{<:Real},
    ::Val{Tlink},
    model::Model,
    getlogdensity::Any,
    iden_varname_ranges::NamedTuple,
    varname_ranges::Dict{VarName,RangeAndLinked},
) where {Tlink}
    # Bundle the flat parameter vector with its range/link metadata so that
    # initialisation can look up each variable's slice of `params`.
    ranged_params = VectorWithRanges{Tlink}(iden_varname_ranges, varname_ranges, params)
    # Evaluate the model, accumulating only what `getlogdensity` requires.
    empty_vi = OnlyAccsVarInfo(ldf_accs(getlogdensity))
    _, filled_vi = DynamicPPL.init!!(model, empty_vi, InitFromParams(ranged_params, nothing))
    return getlogdensity(filled_vi)
end
275+
276+ """
277+ LogDensityAt{Tlink}(
278+ model::Model,
279+ getlogdensity::Any,
280+ iden_varname_ranges::NamedTuple,
281+ varname_ranges::Dict{VarName,RangeAndLinked},
282+ ) where {Tlink}
283+
284+ A callable struct that behaves in the same way as `logdensity_at`, but stores the model and
285+ other information internally. Having two separate functions/structs allows for better
286+ performance with AD backends.
287+ """
240288struct LogDensityAt{Tlink,M<: Model ,F,N<: NamedTuple }
241289 model:: M
242290 getlogdensity:: F
@@ -253,36 +301,57 @@ struct LogDensityAt{Tlink,M<:Model,F,N<:NamedTuple}
253301 end
254302end
function (f::LogDensityAt{Tlink})(params::AbstractVector{<:Real}) where {Tlink}
    # Delegate to the plain-function implementation, unpacking the stored fields.
    stored = (f.model, f.getlogdensity, f.iden_varname_ranges, f.varname_ranges)
    return logdensity_at(params, Val{Tlink}(), stored...)
end
263313
function LogDensityProblems.logdensity(
    ldf::LogDensityFunction{Tlink}, params::AbstractVector{<:Real}
) where {Tlink}
    # Forward to `logdensity_at` using the precomputed range information that
    # the `LogDensityFunction` carries.
    fields = (ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges)
    return logdensity_at(params, Val{Tlink}(), fields...)
end
273326
function LogDensityProblems.logdensity_and_gradient(
    ldf::LogDensityFunction{Tlink}, params::AbstractVector{<:Real}
) where {Tlink}
    # `params` has to be converted to the same vector type that was used for AD
    # preparation, otherwise the preparation will not be valid.
    params = convert(_get_input_vector_type(ldf), params)
    fields = (ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges)
    return if _use_closure(ldf.adtype)
        # Closure approach: model and range information are captured in the callable.
        DI.value_and_gradient(
            LogDensityAt{Tlink}(fields...), ldf._adprep, ldf.adtype, params
        )
    else
        # Constant-context approach: pass the inactive arguments via `DI.Constant`.
        DI.value_and_gradient(
            logdensity_at,
            ldf._adprep,
            ldf.adtype,
            params,
            DI.Constant(Val{Tlink}()),
            map(DI.Constant, fields)...,
        )
    end
end
287356
288357function LogDensityProblems. capabilities(
@@ -316,6 +385,43 @@ By default, this just returns the input unchanged.
316385"""
317386tweak_adtype(adtype:: ADTypes.AbstractADType , :: Model , :: AbstractVarInfo ) = adtype
318387
"""
    _use_closure(adtype::ADTypes.AbstractADType)

In LogDensityProblems, we want to calculate the derivative of `logdensity(f, x)` with
respect to x, where f is the model (in our case LogDensityFunction or its arguments) and is
a constant. However, DifferentiationInterface generally expects a single-argument function
g(x) to differentiate.

There are two ways of dealing with this:

1. Construct a closure over the model, i.e. let g = Base.Fix1(logdensity, f)

2. Use a constant DI.Context. This lets us pass a two-argument function to DI, as long as we
   also give it the 'inactive argument' (i.e. the model) wrapped in `DI.Constant`.

The relative performance of the two approaches, however, depends on the AD backend used.
Some benchmarks are provided here: https://github.com/TuringLang/DynamicPPL.jl/pull/1172

This function is used to determine whether a given AD backend should use a closure or a
constant. If `_use_closure(adtype)` returns `true`, then the closure approach will be used.
By default, this function returns `false`, i.e. the constant approach will be used.
"""
# For these AD backends both closure and no closure work, but it is just faster to not use a
# closure (see link in the docstring).
_use_closure(::ADTypes.AutoForwardDiff) = false
_use_closure(::ADTypes.AutoMooncake) = false
_use_closure(::ADTypes.AutoMooncakeForward) = false
# For ReverseDiff, with the compiled tape, you _must_ use a closure because otherwise with
# DI.Constant arguments the tape will always be recompiled upon each call to
# value_and_gradient. For non-compiled ReverseDiff, it is faster to not use a closure.
_use_closure(::ADTypes.AutoReverseDiff{compile}) where {compile} = !compile
# For AutoEnzyme it allows us to avoid setting function_annotation.
_use_closure(::ADTypes.AutoEnzyme) = false
# Since for most backends it's faster to not use a closure, we set that as the default
# for unknown AD backends.
_use_closure(::ADTypes.AbstractADType) = false
424+
319425# #####################################################
320426# Helper functions to extract ranges and link status #
321427# #####################################################
0 commit comments