Skip to content

Commit d7d785a

Browse files
committed
typed_varinfo and untyped_varinfo handles wrapping passed context
in sampling context now so no need to handle this explicitly elsewhere
1 parent 686ed9f commit d7d785a

File tree

1 file changed

+2
-8
lines changed

1 file changed

+2
-8
lines changed

ext/DynamicPPLJETExt.jl

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,8 @@ end
5151
function DynamicPPL._determine_varinfo_jet(
5252
model::DynamicPPL.Model, context::DynamicPPL.AbstractContext; only_tilde::Bool=true
5353
)
54-
# We need a sampling context in the stack to initialize the varinfo.
55-
sampling_context = if DynamicPPL.hassampler(context)
56-
context
57-
else
58-
DynamicPPL.typed_varinfo(model, DynamicPPL.SamplingContext(context))
59-
end
6054
# First we try with the typed varinfo.
61-
varinfo = DynamicPPL.typed_varinfo(model, sampling_context)
55+
varinfo = DynamicPPL.typed_varinfo(model, context)
6256
issuccess = true
6357

6458
# Let's make sure that both evaluation and sampling doesn't result in type errors.
@@ -78,7 +72,7 @@ function DynamicPPL._determine_varinfo_jet(
7872
else
7973
# Warn the user that we can't use the type stable one.
8074
@warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo."
81-
DynamicPPL.untyped_varinfo(model, sampling_context)
75+
DynamicPPL.untyped_varinfo(model, context)
8276
end
8377
end
8478

0 commit comments

Comments
 (0)