Skip to content

Commit 0f1de5c

Browse files
committed
Fix strictness failure with DifferentiationInterface 0.7
1 parent 3510612 commit 0f1de5c

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

src/logdensityfunction.jl

Lines changed: 14 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -106,6 +106,8 @@ struct LogDensityFunction{
106106
adtype::AD
107107
"(internal use only) gradient preparation object for the model"
108108
prep::Union{Nothing,DI.GradientPrep}
109+
"(internal use only) the closure used for the gradient preparation"
110+
closure::Union{Nothing,Function}
109111

110112
function LogDensityFunction(
111113
model::Model,
@@ -124,10 +126,16 @@ struct LogDensityFunction{
124126
# Get a set of dummy params to use for prep
125127
x = map(identity, varinfo[:])
126128
if use_closure(adtype)
127-
prep = DI.prepare_gradient(
128-
x -> logdensity_at(x, model, varinfo, context), adtype, x
129-
)
129+
# The closure itself has to be stored inside the
130+
# LogDensityFunction to ensure that the signature of the
131+
# function being differentiated is the same as that used for
132+
# preparation. See
133+
# https://github.com/TuringLang/DynamicPPL.jl/pull/922 for an
134+
# explanation.
135+
closure = x -> logdensity_at(x, model, varinfo, context)
136+
prep = DI.prepare_gradient(closure, adtype, x)
130137
else
138+
closure = nothing
131139
prep = DI.prepare_gradient(
132140
logdensity_at,
133141
adtype,
@@ -139,7 +147,7 @@ struct LogDensityFunction{
139147
end
140148
end
141149
return new{typeof(model),typeof(varinfo),typeof(context),typeof(adtype)}(
142-
model, varinfo, context, adtype, prep
150+
model, varinfo, context, adtype, prep, closure
143151
)
144152
end
145153
end
@@ -208,9 +216,8 @@ function LogDensityProblems.logdensity_and_gradient(
208216
# Make branching statically inferrable, i.e. type-stable (even if the two
209217
# branches happen to return different types)
210218
return if use_closure(f.adtype)
211-
DI.value_and_gradient(
212-
x -> logdensity_at(x, f.model, f.varinfo, f.context), f.prep, f.adtype, x
213-
)
219+
f.closure === nothing && error("Closure not available; this should not happen")
220+
DI.value_and_gradient(f.closure, f.prep, f.adtype, x)
214221
else
215222
DI.value_and_gradient(
216223
logdensity_at,

0 commit comments

Comments (0)