Skip to content

Commit 6f670c4

Browse files
committed
Use LogDensityAt callable struct instead of closure
1 parent c7d89bb commit 6f670c4

File tree

1 file changed

+26
-15
lines changed

1 file changed

+26
-15
lines changed

src/logdensityfunction.jl

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,6 @@ struct LogDensityFunction{
106106
adtype::AD
107107
"(internal use only) gradient preparation object for the model"
108108
prep::Union{Nothing,DI.GradientPrep}
109-
"(internal use only) the closure used for the gradient preparation"
110-
closure::Union{Nothing,Function}
111109

112110
function LogDensityFunction(
113111
model::Model,
@@ -116,7 +114,6 @@ struct LogDensityFunction{
116114
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
117115
)
118116
if adtype === nothing
119-
closure = nothing
120117
prep = nothing
121118
else
122119
# Make backend-specific tweaks to the adtype
@@ -127,16 +124,8 @@ struct LogDensityFunction{
127124
# Get a set of dummy params to use for prep
128125
x = map(identity, varinfo[:])
129126
if use_closure(adtype)
130-
# The closure itself has to be stored inside the
131-
# LogDensityFunction to ensure that the signature of the
132-
# function being differentiated is the same as that used for
133-
# preparation. See
134-
# https://github.com/TuringLang/DynamicPPL.jl/pull/922 for an
135-
# explanation.
136-
closure = x -> logdensity_at(x, model, varinfo, context)
137-
prep = DI.prepare_gradient(closure, adtype, x)
127+
prep = DI.prepare_gradient(LogDensityAt(model, varinfo, context), adtype, x)
138128
else
139-
closure = nothing
140129
prep = DI.prepare_gradient(
141130
logdensity_at,
142131
adtype,
@@ -148,7 +137,7 @@ struct LogDensityFunction{
148137
end
149138
end
150139
return new{typeof(model),typeof(varinfo),typeof(context),typeof(adtype)}(
151-
model, varinfo, context, adtype, prep, closure
140+
model, varinfo, context, adtype, prep
152141
)
153142
end
154143
end
@@ -193,6 +182,27 @@ function logdensity_at(
193182
return getlogp(last(evaluate!!(model, varinfo_new, context)))
194183
end
195184

185+
"""
186+
LogDensityAt(
187+
x::AbstractVector,
188+
model::Model,
189+
varinfo::AbstractVarInfo,
190+
context::AbstractContext
191+
)
192+
193+
A callable struct that serves the same purpose as `x -> logdensity_at(x, model,
194+
varinfo, context)`.
195+
"""
196+
struct LogDensityAt
197+
model::Model
198+
varinfo::AbstractVarInfo
199+
context::AbstractContext
200+
end
201+
function (ld::LogDensityAt)(x::AbstractVector)
202+
varinfo_new = unflatten(ld.varinfo, x)
203+
return getlogp(last(evaluate!!(ld.model, varinfo_new, ld.context)))
204+
end
205+
196206
### LogDensityProblems interface
197207

198208
function LogDensityProblems.capabilities(
@@ -217,8 +227,9 @@ function LogDensityProblems.logdensity_and_gradient(
217227
# Make branching statically inferrable, i.e. type-stable (even if the two
218228
# branches happen to return different types)
219229
return if use_closure(f.adtype)
220-
f.closure === nothing && error("Closure not available; this should not happen")
221-
DI.value_and_gradient(f.closure, f.prep, f.adtype, x)
230+
DI.value_and_gradient(
231+
LogDensityAt(f.model, f.varinfo, f.context), f.prep, f.adtype, x
232+
)
222233
else
223234
DI.value_and_gradient(
224235
logdensity_at,

0 commit comments

Comments
 (0)