Skip to content

Commit c04612a

Browse files
committed
Add extra accs argument for LogDensityFunction
1 parent 47f6b7e commit c04612a

File tree

5 files changed

+72
-10
lines changed

5 files changed

+72
-10
lines changed

HISTORY.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# DynamicPPL Changelog
22

3+
## 0.39.11
4+
5+
Allow passing `accs::Union{NTuple{N,AbstractAccumulator},AccumulatorTuple}` into the `LogDensityFunction` constructor to specify custom accumulators to use when evaluating the model.
6+
Previously, this was hard-coded.
7+
38
## 0.39.10
49

510
Rename the internal functions `matchingvalue` and `get_matching_type` to `convert_model_argument` and `promote_model_type_argument` respectively.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.39.10"
3+
version = "0.39.11"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/accumulators.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ end
157157

158158
AccumulatorTuple(accs::Vararg{AbstractAccumulator}) = AccumulatorTuple(accs)
159159
AccumulatorTuple(nt::NamedTuple) = AccumulatorTuple(tuple(nt...))
160+
AccumulatorTuple(at::AccumulatorTuple) = at
160161

161162
# When showing with text/plain, leave out information about the wrapper AccumulatorTuple.
162163
Base.show(io::IO, mime::MIME"text/plain", at::AccumulatorTuple) = show(io, mime, at.nt)

src/logdensityfunction.jl

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ using Random: Random
3333
DynamicPPL.LogDensityFunction(
3434
model::Model,
3535
getlogdensity::Any=getlogjoint_internal,
36-
varinfo::AbstractVarInfo=VarInfo(model);
36+
varinfo::AbstractVarInfo=VarInfo(model)
37+
accs::Union{NTuple{<:Any,AbstractAccumulator},AccumulatorTuple}=DynamicPPL.ldf_accs(getlogdensity);
3738
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
3839
)
3940
@@ -72,6 +73,12 @@ If you provide one of these functions, a `VarInfo` will be automatically created
7273
you provide a different function, you have to manually create a VarInfo and pass it as the
7374
third argument.
7475
76+
`accs` allows you to specify an `AccumulatorTuple` or a tuple of `AbstractAccumulator`s
77+
which will be used _when evaluating the log density_`. (Note that the accumulators from the
78+
`VarInfo` argument are discarded.) By default, this uses an internal function,
79+
`DynamicPPL.ldf_accs`, which attempts to choose an appropriate set of accumulators based on
80+
which kind of log-density is being calculated.
81+
7582
If the `adtype` keyword argument is provided, then this struct will also store the adtype
7683
along with other information for efficient calculation of the gradient of the log density.
7784
Note that preparing a `LogDensityFunction` with an AD type `AutoBackend()` requires the AD
@@ -152,6 +159,7 @@ struct LogDensityFunction{
152159
ADP<:Union{Nothing,DI.GradientPrep},
153160
# type of the vector passed to logdensity functions
154161
X<:AbstractVector,
162+
AC<:AccumulatorTuple,
155163
}
156164
model::M
157165
adtype::AD
@@ -160,11 +168,19 @@ struct LogDensityFunction{
160168
_varname_ranges::Dict{VarName,RangeAndLinked}
161169
_adprep::ADP
162170
_dim::Int
171+
_accs::AC
163172

164173
function LogDensityFunction(
165174
model::Model,
166175
getlogdensity::Any=getlogjoint_internal,
167-
varinfo::AbstractVarInfo=VarInfo(model);
176+
# TODO(penelopeysm): It is a bit redundant to pass a VarInfo, as well as the
177+
# accumulators, into here. The truth is that the VarInfo is used ONLY for generating
178+
# the ranges and link status, so arguably we should only pass in a metadata; or when
179+
# VNT is done, we should pass in only a VNT.
180+
varinfo::AbstractVarInfo=VarInfo(model),
181+
accs::Union{NTuple{<:Any,AbstractAccumulator},AccumulatorTuple}=ldf_accs(
182+
getlogdensity
183+
);
168184
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
169185
)
170186
# Figure out which variable corresponds to which index, and
@@ -187,13 +203,15 @@ struct LogDensityFunction{
187203
end
188204
x = [val for val in varinfo[:]]
189205
dim = length(x)
206+
# convert to AccumulatorTuple if needed
207+
accs = AccumulatorTuple(accs)
190208
# Do AD prep if needed
191209
prep = if adtype === nothing
192210
nothing
193211
else
194212
# Make backend-specific tweaks to the adtype
195213
adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo)
196-
args = (model, getlogdensity, all_iden_ranges, all_ranges)
214+
args = (model, getlogdensity, all_iden_ranges, all_ranges, accs)
197215
if _use_closure(adtype)
198216
DI.prepare_gradient(LogDensityAt{Tlink}(args...), adtype, x)
199217
else
@@ -214,8 +232,9 @@ struct LogDensityFunction{
214232
typeof(all_iden_ranges),
215233
typeof(prep),
216234
typeof(x),
235+
typeof(accs),
217236
}(
218-
model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep, dim
237+
model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep, dim, accs
219238
)
220239
end
221240
end
@@ -264,11 +283,11 @@ function logdensity_at(
264283
getlogdensity::Any,
265284
iden_varname_ranges::NamedTuple,
266285
varname_ranges::Dict{VarName,RangeAndLinked},
286+
accs::AccumulatorTuple,
267287
) where {Tlink}
268288
strategy = InitFromParams(
269289
VectorWithRanges{Tlink}(iden_varname_ranges, varname_ranges, params), nothing
270290
)
271-
accs = ldf_accs(getlogdensity)
272291
_, vi = DynamicPPL.init!!(model, OnlyAccsVarInfo(accs), strategy)
273292
return getlogdensity(vi)
274293
end
@@ -279,25 +298,30 @@ end
279298
getlogdensity::Any,
280299
iden_varname_ranges::NamedTuple,
281300
varname_ranges::Dict{VarName,RangeAndLinked},
301+
accs::AccumulatorTuple,
282302
) where {Tlink}
283303
284304
A callable struct that behaves in the same way as `logdensity_at`, but stores the model and
285305
other information internally. Having two separate functions/structs allows for better
286306
performance with AD backends.
287307
"""
288-
struct LogDensityAt{Tlink,M<:Model,F,N<:NamedTuple}
308+
struct LogDensityAt{Tlink,M<:Model,F,N<:NamedTuple,A<:AccumulatorTuple}
289309
model::M
290310
getlogdensity::F
291311
iden_varname_ranges::N
292312
varname_ranges::Dict{VarName,RangeAndLinked}
313+
accs::A
293314

294315
function LogDensityAt{Tlink}(
295316
model::M,
296317
getlogdensity::F,
297318
iden_varname_ranges::N,
298319
varname_ranges::Dict{VarName,RangeAndLinked},
299-
) where {Tlink,M,F,N}
300-
return new{Tlink,M,F,N}(model, getlogdensity, iden_varname_ranges, varname_ranges)
320+
accs::A,
321+
) where {Tlink,M,F,N,A}
322+
return new{Tlink,M,F,N,A}(
323+
model, getlogdensity, iden_varname_ranges, varname_ranges, accs
324+
)
301325
end
302326
end
303327
function (f::LogDensityAt{Tlink})(params::AbstractVector{<:Real}) where {Tlink}
@@ -308,6 +332,7 @@ function (f::LogDensityAt{Tlink})(params::AbstractVector{<:Real}) where {Tlink}
308332
f.getlogdensity,
309333
f.iden_varname_ranges,
310334
f.varname_ranges,
335+
f.accs,
311336
)
312337
end
313338

@@ -321,6 +346,7 @@ function LogDensityProblems.logdensity(
321346
ldf._getlogdensity,
322347
ldf._iden_varname_ranges,
323348
ldf._varname_ranges,
349+
ldf._accs,
324350
)
325351
end
326352

@@ -333,7 +359,11 @@ function LogDensityProblems.logdensity_and_gradient(
333359
return if _use_closure(ldf.adtype)
334360
DI.value_and_gradient(
335361
LogDensityAt{Tlink}(
336-
ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges
362+
ldf.model,
363+
ldf._getlogdensity,
364+
ldf._iden_varname_ranges,
365+
ldf._varname_ranges,
366+
ldf._accs,
337367
),
338368
ldf._adprep,
339369
ldf.adtype,
@@ -350,6 +380,7 @@ function LogDensityProblems.logdensity_and_gradient(
350380
DI.Constant(ldf._getlogdensity),
351381
DI.Constant(ldf._iden_varname_ranges),
352382
DI.Constant(ldf._varname_ranges),
383+
DI.Constant(ldf._accs),
353384
)
354385
end
355386
end

test/logdensityfunction.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,31 @@ end
123123
x = vi[:]
124124
@test LogDensityProblems.logdensity(ldf, x) == sll.scale * logpdf(Normal(x[1]), 1.0)
125125
end
126+
127+
@testset "Custom accumulators" begin
128+
# Define an accumulator that always throws an error to test that custom
129+
# accumulators can be used with LogDensityFunction
130+
struct ErrorAccumulatorException <: Exception end
131+
struct ErrorAccumulator <: DynamicPPL.AbstractAccumulator end
132+
DynamicPPL.accumulator_name(::ErrorAccumulator) = :ERROR
133+
DynamicPPL.accumulate_assume!!(
134+
::ErrorAccumulator, ::Any, ::Any, ::VarName, ::Distribution
135+
) = throw(ErrorAccumulatorException())
136+
DynamicPPL.accumulate_observe!!(
137+
::ErrorAccumulator, ::Distribution, ::Any, ::Union{VarName,Nothing}
138+
) = throw(ErrorAccumulatorException())
139+
DynamicPPL.reset(ea::ErrorAccumulator) = ea
140+
Base.copy(ea::ErrorAccumulator) = ea
141+
# Construct an LDF
142+
@model function demo_error()
143+
return x ~ Normal()
144+
end
145+
model = demo_error()
146+
ldf = LogDensityFunction(
147+
model, getlogjoint, VarInfo(model), AccumulatorTuple(ErrorAccumulator())
148+
)
149+
@test_throws ErrorAccumulatorException LogDensityProblems.logdensity(ldf, [0.0])
150+
end
126151
end
127152

128153
@testset "LogDensityFunction: Type stability" begin

0 commit comments

Comments
 (0)