@@ -106,8 +106,6 @@ struct LogDensityFunction{
106106 adtype::AD
107107 "(internal use only) gradient preparation object for the model"
108108 prep::Union{Nothing,DI.GradientPrep}
109- "(internal use only) the closure used for the gradient preparation"
110- closure::Union{Nothing,Function}
111109
112110 function LogDensityFunction(
113111 model::Model,
@@ -116,7 +114,6 @@ struct LogDensityFunction{
116114 adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
117115 )
118116 if adtype === nothing
119- closure = nothing
120117 prep = nothing
121118 else
122119 # Make backend-specific tweaks to the adtype
@@ -127,16 +124,8 @@ struct LogDensityFunction{
127124 # Get a set of dummy params to use for prep
128125 x = map(identity, varinfo[:])
129126 if use_closure(adtype)
130- # The closure itself has to be stored inside the
131- # LogDensityFunction to ensure that the signature of the
132- # function being differentiated is the same as that used for
133- # preparation. See
134- # https://github.com/TuringLang/DynamicPPL.jl/pull/922 for an
135- # explanation.
136- closure = x -> logdensity_at(x, model, varinfo, context)
137- prep = DI.prepare_gradient(closure, adtype, x)
127+ prep = DI.prepare_gradient(LogDensityAt(model, varinfo, context), adtype, x)
138128 else
139- closure = nothing
140129 prep = DI.prepare_gradient(
141130 logdensity_at,
142131 adtype,
@@ -148,7 +137,7 @@ struct LogDensityFunction{
148137 end
149138 end
150139 return new{typeof(model),typeof(varinfo),typeof(context),typeof(adtype)}(
151- model, varinfo, context, adtype, prep, closure
140+ model, varinfo, context, adtype, prep
152141 )
153142 end
154143end
@@ -193,6 +182,27 @@ function logdensity_at(
193182 return getlogp(last(evaluate!!(model, varinfo_new, context)))
194183end
195184
185+ """
186+ LogDensityAt(
187+ x::AbstractVector,
188+ model::Model,
189+ varinfo::AbstractVarInfo,
190+ context::AbstractContext
191+ )
192+
193+ A callable struct that serves the same purpose as `x -> logdensity_at(x, model,
194+ varinfo, context)`.
195+ """
196+ struct LogDensityAt
197+ model::Model
198+ varinfo::AbstractVarInfo
199+ context::AbstractContext
200+ end
201+ function (ld::LogDensityAt)(x::AbstractVector)
202+ varinfo_new = unflatten(ld.varinfo, x)
203+ return getlogp(last(evaluate!!(ld.model, varinfo_new, ld.context)))
204+ end
205+
196206 ### LogDensityProblems interface
197207
198208 function LogDensityProblems.capabilities(
@@ -217,8 +227,9 @@ function LogDensityProblems.logdensity_and_gradient(
217227 # Make branching statically inferrable, i.e. type-stable (even if the two
218228 # branches happen to return different types)
219229 return if use_closure(f.adtype)
220- f.closure === nothing && error("Closure not available; this should not happen")
221- DI.value_and_gradient(f.closure, f.prep, f.adtype, x)
230+ DI.value_and_gradient(
231+ LogDensityAt(f.model, f.varinfo, f.context), f.prep, f.adtype, x
232+ )
222233 else
223234 DI.value_and_gradient(
224235 logdensity_at,
0 commit comments