Skip to content

Commit d503c3c

Browse files
committed
made filtering for errors only in the tilde pipeline optional
1 parent 5cd9009 commit d503c3c

File tree

3 files changed

+70
-21
lines changed

3 files changed

+70
-21
lines changed

ext/DynamicPPLJETExt.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,37 @@ module DynamicPPLJETExt
33
using DynamicPPL: DynamicPPL
44
using JET: JET
55

6+
"""
7+
is_tilde_instance(x)
8+
9+
Return `true` if `x` is a method instance of a tilde function, otherwise `false`.
10+
"""
11+
is_tilde_instance(x) = false
12+
is_tilde_instance(frame::JET.VirtualFrame) = is_tilde_instance(frame.linfo)
13+
is_tilde_instance(mi::Core.MethodInstance) = is_tilde_instance(mi.specTypes.parameters[1])
14+
is_tilde_instance(::Type{typeof(DynamicPPL.tilde_assume!!)}) = true
15+
is_tilde_instance(::Type{typeof(DynamicPPL.tilde_observe!!)}) = true
16+
is_tilde_instance(::Type{typeof(DynamicPPL.dot_tilde_assume!!)}) = true
17+
is_tilde_instance(::Type{typeof(DynamicPPL.dot_tilde_observe!!)}) = true
18+
19+
"""
20+
report_has_error_in_tilde(report)
21+
22+
Return `true` if the given error `report` contains a tilde function in its frames, otherwise `false`.
23+
24+
This is used to filter out reports that occur outside of the tilde pipeline, in an attempt to avoid
25+
warning the user about DynamicPPL doing something wrong when it is in fact an issue with the user's code.
26+
"""
27+
function report_has_error_in_tilde(report)
28+
frames = report.vst
29+
return any(is_tilde_instance, frames)
30+
end
31+
632
function DynamicPPL.determine_varinfo(
733
model::DynamicPPL.Model,
834
context::DynamicPPL.AbstractContext=DynamicPPL.DefaultContext();
935
verbose::Bool=false,
36+
only_tilde::Bool=true
1037
)
1138
# First we try with the typed varinfo.
1239
varinfo = DynamicPPL.typed_varinfo(model)
@@ -18,6 +45,9 @@ function DynamicPPL.determine_varinfo(
1845
)
1946
result_eval = JET.report_call(f_eval, argtypes_eval)
2047
reports_eval = JET.get_reports(result_eval)
48+
if only_tilde
49+
reports_eval = filter(report_has_error_in_tilde, reports_eval)
50+
end
2151
# If we get reports => we had issues so we use the untyped varinfo.
2252
issuccess &= length(reports_eval) == 0
2353
if issuccess
@@ -27,6 +57,9 @@ function DynamicPPL.determine_varinfo(
2757
)
2858
result_sample = JET.report_call(f_sample, argtypes_sample)
2959
reports_sample = JET.get_reports(result_sample)
60+
if only_tilde
61+
reports_sample = filter(report_has_error_in_tilde, reports_sample)
62+
end
3063
# If we get reports => we had issues so we use the untyped varinfo.
3164
issuccess &= length(reports_sample) == 0
3265
if !issuccess && verbose
@@ -55,3 +88,4 @@ function DynamicPPL.determine_varinfo(
5588
end
5689

5790
end
91+

test/ext/DynamicPPLJETExt.jl

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
using JET: JET
2-
31
@testset "DynamicPPLJETExt.jl" begin
42
@testset "determine_varinfo" begin
53
@model function demo1()
@@ -11,7 +9,7 @@ using JET: JET
119
end
1210
end
1311
model = demo1()
14-
@test DynamicPPL.determine_varinfo(model) isa DynamicPPL.UntypedVarInfo
12+
@test DynamicPPL.determine_varinfo(model; verbose=true) isa DynamicPPL.UntypedVarInfo
1513

1614
@model demo2() = x ~ Normal()
1715
@test DynamicPPL.determine_varinfo(demo2()) isa DynamicPPL.TypedVarInfo
@@ -26,7 +24,7 @@ using JET: JET
2624
z ~ Normal()
2725
end
2826
end
29-
@test DynamicPPL.determine_varinfo(demo3()) isa DynamicPPL.UntypedVarInfo
27+
@test DynamicPPL.determine_varinfo(demo3(); verbose=true) isa DynamicPPL.UntypedVarInfo
3028

3129
# Evaluation works (and it would even do so in practice), but sampling
3230
# fill fail due to storing `Cauchy{Float64}` in `Vector{Normal{Float64}}`.
@@ -38,24 +36,38 @@ using JET: JET
3836
y ~ Cauchy() # different distibution, but same transformation => should work
3937
end
4038
end
41-
@test DynamicPPL.determine_varinfo(demo4()) isa DynamicPPL.UntypedVarInfo
39+
@test DynamicPPL.determine_varinfo(demo4(); verbose=true) isa DynamicPPL.UntypedVarInfo
40+
41+
# In this model, the type error occurs in the user code rather than in DynamicPPL.
42+
@model function demo5()
43+
x ~ Normal()
44+
xs = Any[]
45+
push!(xs, x)
46+
# `sum(::Vector{Any})` can potentially error unless the dynamic manages to resolve the
47+
# correct `zero` method. As a result, this code will run, but JET will raise this is an issue.
48+
return sum(xs)
49+
end
50+
# Should pass if we're only checking the tilde statements.
51+
@test DynamicPPL.determine_varinfo(demo5(); verbose=true) isa DynamicPPL.TypedVarInfo
52+
# Should fail if we're including errors in the model body.
53+
@test DynamicPPL.determine_varinfo(demo5(); verbose=true, only_tilde=false) isa DynamicPPL.UntypedVarInfo
4254
end
4355

44-
# @testset "demo models" begin
45-
# @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
46-
# varinfo = DynamicPPL.DynamicPPL.determine_varinfo(model)
47-
# # They should all result in typed.
48-
# @test varinfo isa DynamicPPL.TypedVarInfo
49-
# # But let's also make sure that they're not lying.
50-
# f_eval, argtypes_eval = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
51-
# model, varinfo
52-
# )
53-
# JET.test_call(f_eval, argtypes_eval)
56+
@testset "demo models" begin
57+
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
58+
varinfo = DynamicPPL.DynamicPPL.determine_varinfo(model)
59+
# They should all result in typed.
60+
@test varinfo isa DynamicPPL.TypedVarInfo
61+
# But let's also make sure that they're not lying.
62+
f_eval, argtypes_eval = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
63+
model, varinfo
64+
)
65+
JET.test_call(f_eval, argtypes_eval)
5466

55-
# f_sample, argtypes_sample = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
56-
# model, varinfo, DynamicPPL.SamplingContext()
57-
# )
58-
# JET.test_call(f_sample, argtypes_sample)
59-
# end
60-
# end
67+
f_sample, argtypes_sample = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
68+
model, varinfo, DynamicPPL.SamplingContext()
69+
)
70+
JET.test_call(f_sample, argtypes_sample)
71+
end
72+
end
6173
end

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ using Test
2525
using Distributions
2626
using LinearAlgebra # Diagonal
2727

28+
using JET: JET
29+
2830
using Combinatorics: combinations
2931

3032
using DynamicPPL: getargs_dottilde, getargs_tilde, Selector
@@ -71,6 +73,7 @@ include("test_util.jl")
7173

7274
@testset "extensions" begin
7375
include("ext/DynamicPPLMCMCChainsExt.jl")
76+
include("ext/DynamicPPLJETExt.jl")
7477
end
7578

7679
@testset "ad" begin

0 commit comments

Comments
 (0)