Skip to content

Commit 99d53c6

Browse files
committed
Add extra accs argument for LogDensityFunction
1 parent 47f6b7e commit 99d53c6

File tree

5 files changed

+62
-9
lines changed

5 files changed

+62
-9
lines changed

HISTORY.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
# DynamicPPL Changelog
22

3+
## 0.39.11
4+
5+
Allow passing `accs::AccumulatorTuple` into the `LogDensityFunction` constructor to specify custom accumulators to use when evaluating the model.
6+
Previously, this was hard-coded.
7+
8+
Also exports `DynamicPPL.AccumulatorTuple`.
9+
310
## 0.39.10
411

512
Rename the internal functions `matchingvalue` and `get_matching_type` to `convert_model_argument` and `promote_model_type_argument` respectively.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.39.10"
3+
version = "0.39.11"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/DynamicPPL.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ export AbstractVarInfo,
5151
LogLikelihoodAccumulator,
5252
LogPriorAccumulator,
5353
LogJacobianAccumulator,
54+
AccumulatorTuple,
5455
push!!,
5556
empty!!,
5657
subset,

src/logdensityfunction.jl

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ struct LogDensityFunction{
152152
ADP<:Union{Nothing,DI.GradientPrep},
153153
# type of the vector passed to logdensity functions
154154
X<:AbstractVector,
155+
AC<:AccumulatorTuple,
155156
}
156157
model::M
157158
adtype::AD
@@ -160,11 +161,17 @@ struct LogDensityFunction{
160161
_varname_ranges::Dict{VarName,RangeAndLinked}
161162
_adprep::ADP
162163
_dim::Int
164+
_accs::AC
163165

164166
function LogDensityFunction(
165167
model::Model,
166168
getlogdensity::Any=getlogjoint_internal,
167-
varinfo::AbstractVarInfo=VarInfo(model);
169+
# TODO(penelopeysm): It is a bit redundant to pass a VarInfo, as well as the
170+
# accumulators, into here. The truth is that the VarInfo is used ONLY for generating
171+
# the ranges and link status, so arguably we should only pass in a metadata; or when
172+
# VNT is done, we should pass in only a VNT.
173+
varinfo::AbstractVarInfo=VarInfo(model),
174+
accs::AccumulatorTuple=ldf_accs(getlogdensity);
168175
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
169176
)
170177
# Figure out which variable corresponds to which index, and
@@ -193,7 +200,7 @@ struct LogDensityFunction{
193200
else
194201
# Make backend-specific tweaks to the adtype
195202
adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo)
196-
args = (model, getlogdensity, all_iden_ranges, all_ranges)
203+
args = (model, getlogdensity, all_iden_ranges, all_ranges, accs)
197204
if _use_closure(adtype)
198205
DI.prepare_gradient(LogDensityAt{Tlink}(args...), adtype, x)
199206
else
@@ -214,8 +221,9 @@ struct LogDensityFunction{
214221
typeof(all_iden_ranges),
215222
typeof(prep),
216223
typeof(x),
224+
typeof(accs),
217225
}(
218-
model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep, dim
226+
model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep, dim, accs
219227
)
220228
end
221229
end
@@ -264,11 +272,11 @@ function logdensity_at(
264272
getlogdensity::Any,
265273
iden_varname_ranges::NamedTuple,
266274
varname_ranges::Dict{VarName,RangeAndLinked},
275+
accs::AccumulatorTuple,
267276
) where {Tlink}
268277
strategy = InitFromParams(
269278
VectorWithRanges{Tlink}(iden_varname_ranges, varname_ranges, params), nothing
270279
)
271-
accs = ldf_accs(getlogdensity)
272280
_, vi = DynamicPPL.init!!(model, OnlyAccsVarInfo(accs), strategy)
273281
return getlogdensity(vi)
274282
end
@@ -279,25 +287,30 @@ end
279287
getlogdensity::Any,
280288
iden_varname_ranges::NamedTuple,
281289
varname_ranges::Dict{VarName,RangeAndLinked},
290+
accs::AccumulatorTuple,
282291
) where {Tlink}
283292
284293
A callable struct that behaves in the same way as `logdensity_at`, but stores the model and
285294
other information internally. Having two separate functions/structs allows for better
286295
performance with AD backends.
287296
"""
288-
struct LogDensityAt{Tlink,M<:Model,F,N<:NamedTuple}
297+
struct LogDensityAt{Tlink,M<:Model,F,N<:NamedTuple,A<:AccumulatorTuple}
289298
model::M
290299
getlogdensity::F
291300
iden_varname_ranges::N
292301
varname_ranges::Dict{VarName,RangeAndLinked}
302+
accs::A
293303

294304
function LogDensityAt{Tlink}(
295305
model::M,
296306
getlogdensity::F,
297307
iden_varname_ranges::N,
298308
varname_ranges::Dict{VarName,RangeAndLinked},
299-
) where {Tlink,M,F,N}
300-
return new{Tlink,M,F,N}(model, getlogdensity, iden_varname_ranges, varname_ranges)
309+
accs::A,
310+
) where {Tlink,M,F,N,A}
311+
return new{Tlink,M,F,N,A}(
312+
model, getlogdensity, iden_varname_ranges, varname_ranges, accs
313+
)
301314
end
302315
end
303316
function (f::LogDensityAt{Tlink})(params::AbstractVector{<:Real}) where {Tlink}
@@ -308,6 +321,7 @@ function (f::LogDensityAt{Tlink})(params::AbstractVector{<:Real}) where {Tlink}
308321
f.getlogdensity,
309322
f.iden_varname_ranges,
310323
f.varname_ranges,
324+
f.accs,
311325
)
312326
end
313327

@@ -321,6 +335,7 @@ function LogDensityProblems.logdensity(
321335
ldf._getlogdensity,
322336
ldf._iden_varname_ranges,
323337
ldf._varname_ranges,
338+
ldf._accs,
324339
)
325340
end
326341

@@ -333,7 +348,11 @@ function LogDensityProblems.logdensity_and_gradient(
333348
return if _use_closure(ldf.adtype)
334349
DI.value_and_gradient(
335350
LogDensityAt{Tlink}(
336-
ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges
351+
ldf.model,
352+
ldf._getlogdensity,
353+
ldf._iden_varname_ranges,
354+
ldf._varname_ranges,
355+
ldf._accs,
337356
),
338357
ldf._adprep,
339358
ldf.adtype,
@@ -350,6 +369,7 @@ function LogDensityProblems.logdensity_and_gradient(
350369
DI.Constant(ldf._getlogdensity),
351370
DI.Constant(ldf._iden_varname_ranges),
352371
DI.Constant(ldf._varname_ranges),
372+
DI.Constant(ldf._accs),
353373
)
354374
end
355375
end

test/logdensityfunction.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,31 @@ end
123123
x = vi[:]
124124
@test LogDensityProblems.logdensity(ldf, x) == sll.scale * logpdf(Normal(x[1]), 1.0)
125125
end
126+
127+
@testset "Custom accumulators" begin
128+
# Define an accumulator that always throws an error to test that custom
129+
# accumulators can be used with LogDensityFunction
130+
struct ErrorAccumulatorException <: Exception end
131+
struct ErrorAccumulator <: DynamicPPL.AbstractAccumulator end
132+
DynamicPPL.accumulator_name(::ErrorAccumulator) = :ERROR
133+
DynamicPPL.accumulate_assume!!(
134+
::ErrorAccumulator, ::Any, ::Any, ::VarName, ::Distribution
135+
) = throw(ErrorAccumulatorException())
136+
DynamicPPL.accumulate_observe!!(
137+
::ErrorAccumulator, ::Distribution, ::Any, ::Union{VarName,Nothing}
138+
) = throw(ErrorAccumulatorException())
139+
DynamicPPL.reset(ea::ErrorAccumulator) = ea
140+
Base.copy(ea::ErrorAccumulator) = ea
141+
# Construct an LDF
142+
@model function demo_error()
143+
return x ~ Normal()
144+
end
145+
model = demo_error()
146+
ldf = LogDensityFunction(
147+
model, getlogjoint, VarInfo(model), AccumulatorTuple(ErrorAccumulator())
148+
)
149+
@test_throws ErrorAccumulatorException LogDensityProblems.logdensity(ldf, [0.0])
150+
end
126151
end
127152

128153
@testset "LogDensityFunction: Type stability" begin

0 commit comments

Comments
 (0)