Skip to content

Commit e07ecdb

Browse files
committed
moved the JET.jl-dependent experimental determine_varinfo into a
separate `Experimental` module, as discussed
1 parent b6b4bff commit e07ecdb

File tree

4 files changed

+23
-70
lines changed

4 files changed

+23
-70
lines changed

ext/DynamicPPLJETExt.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module DynamicPPLJETExt
33
using DynamicPPL: DynamicPPL
44
using JET: JET
55

6-
function DynamicPPL.is_suitable_varinfo(
6+
function DynamicPPL.Experimental.is_suitable_varinfo(
77
model::DynamicPPL.Model,
88
context::DynamicPPL.AbstractContext,
99
varinfo::DynamicPPL.AbstractVarInfo;
@@ -23,15 +23,17 @@ function DynamicPPL.is_suitable_varinfo(
2323
return length(JET.get_reports(result)) == 0, result
2424
end
2525

26-
function DynamicPPL._determine_varinfo_jet(
26+
function DynamicPPL.Experimental._determine_varinfo_jet(
2727
model::DynamicPPL.Model, context::DynamicPPL.AbstractContext; only_ddpl::Bool=true
2828
)
2929
# First we try with the typed varinfo.
3030
varinfo = DynamicPPL.typed_varinfo(model, context)
3131
issuccess = true
3232

3333
# Let's make sure that both evaluation and sampling doesn't result in type errors.
34-
issuccess, result = DynamicPPL.is_suitable_varinfo(model, context, varinfo; only_ddpl)
34+
issuccess, result = DynamicPPL.Experimental.is_suitable_varinfo(
35+
model, context, varinfo; only_ddpl
36+
)
3537

3638
if !issuccess
3739
# Useful information for debugging.

src/DynamicPPL.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,17 +196,19 @@ include("values_as_in_model.jl")
196196
include("debug_utils.jl")
197197
using .DebugUtils
198198

199+
include("experimental.jl")
200+
199201
# Better error message if users forget to load the AD package
200202
if isdefined(Base.Experimental, :register_error_hint)
201203
function __init__()
202204
Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _
203205
requires_jet =
204-
exc.f === _determine_varinfo_jet &&
206+
exc.f === DynamicPPL.Experimental._determine_varinfo_jet &&
205207
length(argtypes) >= 2 &&
206208
argtypes[1] <: Model &&
207209
argtypes[2] <: AbstractContext
208210
requires_jet |=
209-
exc.f === is_suitable_varinfo &&
211+
exc.f === DynamicPPL.Experimental.is_suitable_varinfo &&
210212
length(argtypes) >= 3 &&
211213
argtypes[1] <: Model &&
212214
argtypes[2] <: AbstractContext &&

src/model_utils.jl

Lines changed: 0 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -207,60 +207,3 @@ function value_iterator_from_chain(vi::AbstractVarInfo, chain)
207207
values_from_chain!(vi, chain, chain_idx, iteration_idx, OrderedDict{VarName,Any}())
208208
end
209209
end
210-
211-
"""
212-
is_suitable_varinfo(model::Model, context::AbstractContext, varinfo::AbstractVarInfo; kwargs...)
213-
214-
Check if the `model` supports evaluation using the provided `context` and `varinfo`.
215-
216-
!!! warning
217-
Loading JET.jl is required before calling this function.
218-
219-
# Arguments
220-
- `model`: The model to to verify the support for.
221-
- `context`: The context to use for the model evaluation.
222-
- `varinfo`: The varinfo to verify the support for.
223-
224-
# Keyword Arguments
225-
- `only_ddpl`: If `true`, only consider error reports occuring in the tilde pipeline. Default: `true`.
226-
227-
# Returns
228-
- `issuccess`: `true` if the model supports the varinfo, otherwise `false`.
229-
- `report`: The result of `report_call` from JET.jl.
230-
"""
231-
function is_suitable_varinfo end
232-
233-
# Internal hook for JET.jl to overload.
234-
function _determine_varinfo_jet end
235-
236-
"""
237-
determine_suitable_varinfo(model[, context]; verbose::Bool=false, only_ddpl::Bool=true)
238-
239-
Return a suitable varinfo for the given `model`.
240-
241-
See also: [`DynamicPPL.is_suitable_varinfo`](@ref).
242-
243-
!!! warning
244-
For full functionality, this requires JET.jl to be loaded.
245-
If JET.jl is not loaded, this function will assume the model is compatible with typed varinfo.
246-
247-
# Arguments
248-
- `model`: The model for which to determine the varinfo.
249-
- `context`: The context to use for the model evaluation. Default: `SamplingContext()`.
250-
251-
# Keyword Arguments
252-
- `only_ddpl`: If `true`, only consider error reports within DynamicPPL.jl.
253-
"""
254-
function determine_suitable_varinfo(
255-
model::Model, context::AbstractContext=SamplingContext(); only_ddpl::Bool=true
256-
)
257-
# If JET.jl has been loaded, and thus `determine_varinfo` has been defined, we use that.
258-
return if Base.get_extension(DynamicPPL, :DynamicPPLJETExt) !== nothing
259-
_determine_varinfo_jet(model, context; only_ddpl)
260-
else
261-
# Warn the user.
262-
@warn "JET.jl is not loaded. Assumes the model is compatible with typed varinfo."
263-
# Otherwise, we use the, possibly incorrect, default typed varinfo (to stay backwards compat).
264-
typed_varinfo(model, context)
265-
end
266-
end

test/ext/DynamicPPLJETExt.jl

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99
end
1010
end
1111
model = demo1()
12-
@test DynamicPPL.determine_suitable_varinfo(model) isa DynamicPPL.UntypedVarInfo
12+
@test DynamicPPL.Experimental.determine_suitable_varinfo(model) isa
13+
DynamicPPL.UntypedVarInfo
1314

1415
@model demo2() = x ~ Normal()
15-
@test DynamicPPL.determine_suitable_varinfo(demo2()) isa DynamicPPL.TypedVarInfo
16+
@test DynamicPPL.Experimental.determine_suitable_varinfo(demo2()) isa
17+
DynamicPPL.TypedVarInfo
1618

1719
@model function demo3()
1820
# Just making sure that nothing strange happens when type inference fails.
@@ -24,7 +26,8 @@
2426
z ~ Normal()
2527
end
2628
end
27-
@test DynamicPPL.determine_suitable_varinfo(demo3()) isa DynamicPPL.UntypedVarInfo
29+
@test DynamicPPL.Experimental.determine_suitable_varinfo(demo3()) isa
30+
DynamicPPL.UntypedVarInfo
2831

2932
# Evaluation works (and it would even do so in practice), but sampling
3033
# fill fail due to storing `Cauchy{Float64}` in `Vector{Normal{Float64}}`.
@@ -36,7 +39,8 @@
3639
y ~ Cauchy() # different distibution, but same transformation
3740
end
3841
end
39-
@test DynamicPPL.determine_suitable_varinfo(demo4()) isa DynamicPPL.UntypedVarInfo
42+
@test DynamicPPL.Experimental.determine_suitable_varinfo(demo4()) isa
43+
DynamicPPL.UntypedVarInfo
4044

4145
# In this model, the type error occurs in the user code rather than in DynamicPPL.
4246
@model function demo5()
@@ -48,16 +52,18 @@
4852
return sum(xs)
4953
end
5054
# Should pass if we're only checking the tilde statements.
51-
@test DynamicPPL.determine_suitable_varinfo(demo5()) isa DynamicPPL.TypedVarInfo
55+
@test DynamicPPL.Experimental.determine_suitable_varinfo(demo5()) isa
56+
DynamicPPL.TypedVarInfo
5257
# Should fail if we're including errors in the model body.
53-
@test DynamicPPL.determine_suitable_varinfo(demo5(); only_ddpl=false) isa
54-
DynamicPPL.UntypedVarInfo
58+
@test DynamicPPL.Experimental.determine_suitable_varinfo(
59+
demo5(); only_ddpl=false
60+
) isa DynamicPPL.UntypedVarInfo
5561
end
5662

5763
@testset "demo models" begin
5864
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
5965
# Use debug logging below.
60-
varinfo = DynamicPPL.DynamicPPL.determine_suitable_varinfo(model)
66+
varinfo = DynamicPPL.Experimental.determine_suitable_varinfo(model)
6167
# They should all result in typed.
6268
@test varinfo isa DynamicPPL.TypedVarInfo
6369
# But let's also make sure that they're not lying.

0 commit comments

Comments
 (0)