Skip to content

Commit a21b21d

Browse files
authored
Enable the option to switch between differentiating a closure, and a function with constant arguments (#1172)
* Enable closure/non-closure case for LogDensityFunction * Add changelog * Update link to benchmarks
1 parent 408ddb1 commit a21b21d

File tree

4 files changed

+138
-31
lines changed

4 files changed

+138
-31
lines changed

HISTORY.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# DynamicPPL Changelog
22

3+
## 0.39.9
4+
5+
The internals of `LogDensityFunction` have been changed slightly so that you do not need to specify `function_annotation` when performing AD with Enzyme.jl.
6+
There are also some small performance improvements with other AD backends.
7+
38
## 0.39.8
49

510
Allow the `getlogdensity` argument of `LogDensityFunction` to accept callable structs as well as functions.
@@ -29,6 +34,8 @@ In particular, when a test fails, it also tells you the tolerances needed to mak
2934

3035
`returned(model, parameters...)` now accepts any arguments that can be wrapped in `InitFromParams` (previously it would only accept `NamedTuple`, `AbstractDict{<:VarName}`, or a chain).
3136

37+
There should also be some minor performance improvements (maybe 10%) on AD with ForwardDiff / Mooncake.
38+
3239
## 0.39.1
3340

3441
`LogDensityFunction` now allows you to call `logdensity_and_gradient(ldf, x)` with `AbstractVector`s `x` that are not plain Vectors (they will be converted internally before calculating the gradient).

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.39.8"
3+
version = "0.39.9"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/logdensityfunction.jl

Lines changed: 128 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -193,11 +193,18 @@ struct LogDensityFunction{
193193
else
194194
# Make backend-specific tweaks to the adtype
195195
adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo)
196-
DI.prepare_gradient(
197-
LogDensityAt{Tlink}(model, getlogdensity, all_iden_ranges, all_ranges),
198-
adtype,
199-
x,
200-
)
196+
args = (model, getlogdensity, all_iden_ranges, all_ranges)
197+
if _use_closure(adtype)
198+
DI.prepare_gradient(LogDensityAt{Tlink}(args...), adtype, x)
199+
else
200+
DI.prepare_gradient(
201+
logdensity_at,
202+
adtype,
203+
x,
204+
DI.Constant(Val{Tlink}()),
205+
map(DI.Constant, args)...,
206+
)
207+
end
201208
end
202209
return new{
203210
Tlink,
@@ -237,6 +244,47 @@ end
237244
ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),))
238245
ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),))
239246

247+
"""
248+
logdensity_at(
249+
params::AbstractVector{<:Real},
250+
::Val{Tlink},
251+
model::Model,
252+
getlogdensity::Any,
253+
iden_varname_ranges::NamedTuple,
254+
varname_ranges::Dict{VarName,RangeAndLinked},
255+
) where {Tlink}
256+
257+
Calculate the log density at the given `params`, using the provided
258+
information extracted from a `LogDensityFunction`.
259+
"""
260+
function logdensity_at(
261+
params::AbstractVector{<:Real},
262+
::Val{Tlink},
263+
model::Model,
264+
getlogdensity::Any,
265+
iden_varname_ranges::NamedTuple,
266+
varname_ranges::Dict{VarName,RangeAndLinked},
267+
) where {Tlink}
268+
strategy = InitFromParams(
269+
VectorWithRanges{Tlink}(iden_varname_ranges, varname_ranges, params), nothing
270+
)
271+
accs = ldf_accs(getlogdensity)
272+
_, vi = DynamicPPL.init!!(model, OnlyAccsVarInfo(accs), strategy)
273+
return getlogdensity(vi)
274+
end
275+
276+
"""
277+
LogDensityAt{Tlink}(
278+
model::Model,
279+
getlogdensity::Any,
280+
iden_varname_ranges::NamedTuple,
281+
varname_ranges::Dict{VarName,RangeAndLinked},
282+
) where {Tlink}
283+
284+
A callable struct that behaves in the same way as `logdensity_at`, but stores the model and
285+
other information internally. Having two separate functions/structs allows for better
286+
performance with AD backends.
287+
"""
240288
struct LogDensityAt{Tlink,M<:Model,F,N<:NamedTuple}
241289
model::M
242290
getlogdensity::F
@@ -253,36 +301,57 @@ struct LogDensityAt{Tlink,M<:Model,F,N<:NamedTuple}
253301
end
254302
end
255303
function (f::LogDensityAt{Tlink})(params::AbstractVector{<:Real}) where {Tlink}
256-
strategy = InitFromParams(
257-
VectorWithRanges{Tlink}(f.iden_varname_ranges, f.varname_ranges, params), nothing
304+
return logdensity_at(
305+
params,
306+
Val{Tlink}(),
307+
f.model,
308+
f.getlogdensity,
309+
f.iden_varname_ranges,
310+
f.varname_ranges,
258311
)
259-
accs = ldf_accs(f.getlogdensity)
260-
_, vi = DynamicPPL.init!!(f.model, OnlyAccsVarInfo(accs), strategy)
261-
return f.getlogdensity(vi)
262312
end
263313

264314
function LogDensityProblems.logdensity(
265315
ldf::LogDensityFunction{Tlink}, params::AbstractVector{<:Real}
266316
) where {Tlink}
267-
return LogDensityAt{Tlink}(
268-
ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges
269-
)(
270-
params
317+
return logdensity_at(
318+
params,
319+
Val{Tlink}(),
320+
ldf.model,
321+
ldf._getlogdensity,
322+
ldf._iden_varname_ranges,
323+
ldf._varname_ranges,
271324
)
272325
end
273326

274327
function LogDensityProblems.logdensity_and_gradient(
275328
ldf::LogDensityFunction{Tlink}, params::AbstractVector{<:Real}
276329
) where {Tlink}
330+
# `params` has to be converted to the same vector type that was used for AD preparation,
331+
# otherwise the preparation will not be valid.
277332
params = convert(_get_input_vector_type(ldf), params)
278-
return DI.value_and_gradient(
279-
LogDensityAt{Tlink}(
280-
ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges
281-
),
282-
ldf._adprep,
283-
ldf.adtype,
284-
params,
285-
)
333+
return if _use_closure(ldf.adtype)
334+
DI.value_and_gradient(
335+
LogDensityAt{Tlink}(
336+
ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges
337+
),
338+
ldf._adprep,
339+
ldf.adtype,
340+
params,
341+
)
342+
else
343+
DI.value_and_gradient(
344+
logdensity_at,
345+
ldf._adprep,
346+
ldf.adtype,
347+
params,
348+
DI.Constant(Val{Tlink}()),
349+
DI.Constant(ldf.model),
350+
DI.Constant(ldf._getlogdensity),
351+
DI.Constant(ldf._iden_varname_ranges),
352+
DI.Constant(ldf._varname_ranges),
353+
)
354+
end
286355
end
287356

288357
function LogDensityProblems.capabilities(
@@ -316,6 +385,43 @@ By default, this just returns the input unchanged.
316385
"""
317386
tweak_adtype(adtype::ADTypes.AbstractADType, ::Model, ::AbstractVarInfo) = adtype
318387

388+
"""
389+
_use_closure(adtype::ADTypes.AbstractADType)
390+
391+
In LogDensityProblems, we want to calculate the derivative of `logdensity(f, x)` with
392+
respect to `x`, where `f` is the model (in our case `LogDensityFunction` or its arguments) and is
393+
a constant. However, DifferentiationInterface generally expects a single-argument function
394+
g(x) to differentiate.
395+
396+
There are two ways of dealing with this:
397+
398+
1. Construct a closure over the model, i.e. let g = Base.Fix1(logdensity, f)
399+
400+
2. Use a constant DI.Context. This lets us pass a two-argument function to DI, as long as we
401+
also give it the 'inactive argument' (i.e. the model) wrapped in `DI.Constant`.
402+
403+
The relative performance of the two approaches, however, depends on the AD backend used.
404+
Some benchmarks are provided here: https://github.com/TuringLang/DynamicPPL.jl/pull/1172
405+
406+
This function is used to determine whether a given AD backend should use a closure or a
407+
constant. If `_use_closure(adtype)` returns `true`, then the closure approach will be used.
408+
By default, this function returns `false`, i.e. the constant approach will be used.
409+
"""
410+
# For these AD backends both closure and no closure work, but it is just faster to not use a
411+
# closure (see link in the docstring).
412+
_use_closure(::ADTypes.AutoForwardDiff) = false
413+
_use_closure(::ADTypes.AutoMooncake) = false
414+
_use_closure(::ADTypes.AutoMooncakeForward) = false
415+
# For ReverseDiff, with the compiled tape, you _must_ use a closure because otherwise with
416+
# DI.Constant arguments the tape will always be recompiled upon each call to
417+
# value_and_gradient. For non-compiled ReverseDiff, it is faster to not use a closure.
418+
_use_closure(::ADTypes.AutoReverseDiff{compile}) where {compile} = !compile
419+
# For AutoEnzyme, the constant (non-closure) approach allows us to avoid setting function_annotation
420+
_use_closure(::ADTypes.AutoEnzyme) = false
421+
# Since for most backends it's faster to not use a closure, we set that as the default
422+
# for unknown AD backends
423+
_use_closure(::ADTypes.AbstractADType) = false
424+
319425
######################################################
320426
# Helper functions to extract ranges and link status #
321427
######################################################

test/integration/enzyme/main.jl

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,8 @@ import Enzyme: set_runtime_activity, Forward, Reverse, Const
66
using ForwardDiff: ForwardDiff # run_ad uses FD for correctness test
77

88
ADTYPES = (
9-
(
10-
"EnzymeForward",
11-
AutoEnzyme(; mode=set_runtime_activity(Forward), function_annotation=Const),
12-
),
13-
(
14-
"EnzymeReverse",
15-
AutoEnzyme(; mode=set_runtime_activity(Reverse), function_annotation=Const),
16-
),
9+
("EnzymeForward", AutoEnzyme(; mode=set_runtime_activity(Forward))),
10+
("EnzymeReverse", AutoEnzyme(; mode=set_runtime_activity(Reverse))),
1711
)
1812

1913
@testset "$ad_key" for (ad_key, ad_type) in ADTYPES

0 commit comments

Comments
 (0)