Skip to content

Commit bc04355

Browse files
committed
Fix JETExt properly
1 parent 70bb2c4 commit bc04355

File tree

3 files changed

+23
-32
lines changed

3 files changed

+23
-32
lines changed

ext/DynamicPPLJETExt.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ using JET: JET
66
function DynamicPPL.Experimental.is_suitable_varinfo(
77
model::DynamicPPL.Model, varinfo::DynamicPPL.AbstractVarInfo; only_ddpl::Bool=true
88
)
9-
# Let's make sure that both evaluation and sampling doesn't result in type errors.
109
f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(model, varinfo)
1110
# If specified, we only check errors originating somewhere in the DynamicPPL.jl.
1211
# This way we don't just fall back to untyped if the user's code is the issue.

src/context_implementations.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,6 @@ end
2828
function tilde_assume(rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, vi)
2929
return assume(rng, sampler, right, vn, vi)
3030
end
31-
function tilde_assume(rng::Random.AbstractRNG, ::InitContext, sampler, right, vn, vi)
32-
@warn(
33-
"Encountered SamplingContext->InitContext. This method will be removed in the next PR.",
34-
)
35-
# just pretend the `InitContext` isn't there for now.
36-
return assume(rng, sampler, right, vn, vi)
37-
end
3831
function tilde_assume(::DefaultContext, sampler, right, vn, vi)
3932
# same as above but no rng
4033
return assume(Random.default_rng(), sampler, right, vn, vi)

test/ext/DynamicPPLJETExt.jl

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
DynamicPPL.UntypedVarInfo
3131

3232
# Evaluation works (and it would even do so in practice), but sampling
33-
# fill fail due to storing `Cauchy{Float64}` in `Vector{Normal{Float64}}`.
33+
# will fail due to storing `Cauchy{Float64}` in `Vector{Normal{Float64}}`.
3434
@model function demo4()
3535
x ~ Bernoulli()
3636
if x
@@ -40,11 +40,6 @@
4040
end
4141
end
4242
@test DynamicPPL.Experimental.determine_suitable_varinfo(demo4()) isa
43-
DynamicPPL.NTVarInfo
44-
init_model = DynamicPPL.contextualize(
45-
demo4(), DynamicPPL.InitContext(DynamicPPL.InitFromPrior())
46-
)
47-
@test DynamicPPL.Experimental.determine_suitable_varinfo(init_model) isa
4843
DynamicPPL.UntypedVarInfo
4944

5045
# In this model, the type error occurs in the user code rather than in DynamicPPL.
@@ -67,33 +62,37 @@
6762

6863
@testset "demo models" begin
6964
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
70-
sampling_model = contextualize(model, SamplingContext(model.context))
7165
# Use debug logging below.
7266
varinfo = DynamicPPL.Experimental.determine_suitable_varinfo(model)
73-
# Check that the inferred varinfo is indeed suitable for evaluation and sampling
74-
f_eval, argtypes_eval = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
75-
model, varinfo
76-
)
77-
JET.test_call(f_eval, argtypes_eval)
78-
79-
f_sample, argtypes_sample = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
80-
sampling_model, varinfo
81-
)
82-
JET.test_call(f_sample, argtypes_sample)
8367
# For our demo models, they should all result in typed.
8468
is_typed = varinfo isa DynamicPPL.NTVarInfo
8569
@test is_typed
86-
# If the test failed, check why it didn't infer a typed varinfo
70+
# If the test failed, check what the type stability problem was for
71+
# the typed varinfo. This is mostly useful for debugging from test
72+
# logs.
8773
if !is_typed
74+
@info "Model `$(model.f)` is not type stable with typed varinfo."
8875
typed_vi = DynamicPPL.typed_varinfo(model)
89-
f_eval, argtypes_eval = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
90-
model, typed_vi
76+
77+
@info "Evaluating with DefaultContext:"
78+
model = DynamicPPL.contextualize(
79+
model,
80+
DynamicPPL.setleafcontext(model.context, DynamicPPL.DefaultContext()),
81+
)
82+
f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
83+
model, varinfo
84+
)
85+
JET.test_call(f, argtypes)
86+
87+
@info "Initialising with InitContext:"
88+
model = DynamicPPL.contextualize(
89+
model,
90+
DynamicPPL.setleafcontext(model.context, DynamicPPL.InitContext()),
9191
)
92-
JET.test_call(f_eval, argtypes_eval)
93-
f_sample, argtypes_sample = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
94-
sampling_model, typed_vi
92+
f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
93+
model, varinfo
9594
)
96-
JET.test_call(f_sample, argtypes_sample)
95+
JET.test_call(f, argtypes)
9796
end
9897
end
9998
end

0 commit comments

Comments
 (0)