@@ -106,6 +106,8 @@ struct LogDensityFunction{
106
106
adtype:: AD
107
107
" (internal use only) gradient preparation object for the model"
108
108
prep:: Union{Nothing,DI.GradientPrep}
109
+ " (internal use only) the closure used for the gradient preparation"
110
+ closure:: Union{Nothing,Function}
109
111
110
112
function LogDensityFunction (
111
113
model:: Model ,
@@ -124,10 +126,16 @@ struct LogDensityFunction{
124
126
# Get a set of dummy params to use for prep
125
127
x = map (identity, varinfo[:])
126
128
if use_closure (adtype)
127
- prep = DI. prepare_gradient (
128
- x -> logdensity_at (x, model, varinfo, context), adtype, x
129
- )
129
+ # The closure itself has to be stored inside the
130
+ # LogDensityFunction to ensure that the signature of the
131
+ # function being differentiated is the same as that used for
132
+ # preparation. See
133
+ # https://github.com/TuringLang/DynamicPPL.jl/pull/922 for an
134
+ # explanation.
135
+ closure = x -> logdensity_at (x, model, varinfo, context)
136
+ prep = DI. prepare_gradient (closure, adtype, x)
130
137
else
138
+ closure = nothing
131
139
prep = DI. prepare_gradient (
132
140
logdensity_at,
133
141
adtype,
@@ -139,7 +147,7 @@ struct LogDensityFunction{
139
147
end
140
148
end
141
149
return new {typeof(model),typeof(varinfo),typeof(context),typeof(adtype)} (
142
- model, varinfo, context, adtype, prep
150
+ model, varinfo, context, adtype, prep, closure
143
151
)
144
152
end
145
153
end
@@ -208,9 +216,8 @@ function LogDensityProblems.logdensity_and_gradient(
208
216
# Make branching statically inferrable, i.e. type-stable (even if the two
209
217
# branches happen to return different types)
210
218
return if use_closure (f. adtype)
211
- DI. value_and_gradient (
212
- x -> logdensity_at (x, f. model, f. varinfo, f. context), f. prep, f. adtype, x
213
- )
219
+ f. closure === nothing && error (" Closure not available; this should not happen" )
220
+ DI. value_and_gradient (f. closure, f. prep, f. adtype, x)
214
221
else
215
222
DI. value_and_gradient (
216
223
logdensity_at,
0 commit comments