@@ -124,9 +124,7 @@ struct LogDensityFunction{
124
124
# Get a set of dummy params to use for prep
125
125
x = map (identity, varinfo[:])
126
126
if use_closure (adtype)
127
- prep = DI. prepare_gradient (
128
- x -> logdensity_at (x, model, varinfo, context), adtype, x
129
- )
127
+ prep = DI. prepare_gradient (LogDensityAt (model, varinfo, context), adtype, x)
130
128
else
131
129
prep = DI. prepare_gradient (
132
130
logdensity_at,
@@ -184,6 +182,26 @@ function logdensity_at(
184
182
return getlogp (last (evaluate!! (model, varinfo_new, context)))
185
183
end
186
184
185
+ """
186
+ LogDensityAt{M<:Model,V<:AbstractVarInfo,C<:AbstractContext}(
187
+ model::M
188
+ varinfo::V
189
+ context::C
190
+ )
191
+
192
+ A callable struct that serves the same purpose as `x -> logdensity_at(x, model,
193
+ varinfo, context)`.
194
+ """
195
+ struct LogDensityAt{M<: Model ,V<: AbstractVarInfo ,C<: AbstractContext }
196
+ model:: M
197
+ varinfo:: V
198
+ context:: C
199
+ end
200
+ function (ld:: LogDensityAt )(x:: AbstractVector )
201
+ varinfo_new = unflatten (ld. varinfo, x)
202
+ return getlogp (last (evaluate!! (ld. model, varinfo_new, ld. context)))
203
+ end
204
+
187
205
# ## LogDensityProblems interface
188
206
189
207
function LogDensityProblems. capabilities (
@@ -209,7 +227,7 @@ function LogDensityProblems.logdensity_and_gradient(
209
227
# branches happen to return different types)
210
228
return if use_closure (f. adtype)
211
229
DI. value_and_gradient (
212
- x -> logdensity_at (x, f. model, f. varinfo, f. context), f. prep, f. adtype, x
230
+ LogDensityAt ( f. model, f. varinfo, f. context), f. prep, f. adtype, x
213
231
)
214
232
else
215
233
DI. value_and_gradient (
0 commit comments