Skip to content

Commit d1a5bab

Browse files
committed
determine_suitable_varinfo now only performs checks using the
provided context, but uses `SamplingContext` by default (as this should be a stricter check than just evaluation)
1 parent c06b080 commit d1a5bab

File tree

3 files changed

+18
-57
lines changed

3 files changed

+18
-57
lines changed

ext/DynamicPPLJETExt.jl

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -50,37 +50,29 @@ end
5050

5151
function DynamicPPL._determine_varinfo_jet(
5252
model::DynamicPPL.Model,
53-
context::DynamicPPL.AbstractContext=DynamicPPL.DefaultContext();
54-
verbose::Bool=false,
53+
context::DynamicPPL.AbstractContext;
5554
only_tilde::Bool=true,
5655
)
5756
# First we try with the typed varinfo.
58-
varinfo = DynamicPPL.typed_varinfo(model)
57+
varinfo = if DynamicPPL.hassampler(context)
58+
# Don't need to add sampling context for this to work.
59+
DynamicPPL.typed_varinfo(model, context)
60+
else
61+
# Need a sampling context to initialize the varinfo.
62+
DynamicPPL.typed_varinfo(model, DynamicPPL.SamplingContext(context))
63+
end
5964
issuccess = true
6065

6166
# Let's make sure that both evaluation and sampling doesn't result in type errors.
62-
issuccess, reports_eval = DynamicPPL.is_suitable_varinfo(
67+
issuccess, reports = DynamicPPL.is_suitable_varinfo(
6368
model, context, varinfo; only_tilde
6469
)
6570

66-
if issuccess
67-
# Evaluation succeeded, let's try sampling.
68-
issuccess_sample, reports_sample = DynamicPPL.is_suitable_varinfo(
69-
model, DynamicPPL.SamplingContext(context), varinfo; only_tilde
70-
)
71-
issuccess &= issuccess_sample
72-
if !issuccess && verbose
73-
# Show the user the issues.
74-
@warn "Sampling with typed varinfo failed with the following issues:"
75-
for report in reports_sample
76-
@warn report
77-
end
78-
end
79-
elseif verbose
80-
# Show the user the issues.
81-
@warn "Evaluaton with typed varinfo failed with the following issues:"
82-
for report in reports_eval
83-
@warn report
71+
if !issuccess
72+
# Useful information for debugging.
73+
@debug "Evaluaton with typed varinfo failed with the following issues:"
74+
for report in reports
75+
@debug report
8476
end
8577
end
8678

src/model_utils.jl

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -246,24 +246,19 @@ See also: [`DynamicPPL.is_suitable_varinfo`](@ref).
246246
247247
# Arguments
248248
- `model`: The model for which to determine the varinfo.
249-
- `context`: The context to use for the evaluation and sampling. Default: `DefaultContext()`.
249+
- `context`: The context to use for the model evaluation. Default: `SamplingContext()`.
250250
251251
# Keyword Arguments
252-
- `verbose`: If `true`, the user will be warned about the issues that caused the fallback to untyped varinfo.
253252
- `only_tilde`: If `true`, only consider error reports occuring in the tilde pipeline. Default: `true`.
254-
255-
# Keyword Arguments
256-
- `verbose`: If `true`, the user will be warned about the issues that caused the fallback to untyped varinfo.
257253
"""
258254
function determine_suitable_varinfo(
259255
model::Model,
260-
context::AbstractContext=DefaultContext();
261-
verbose::Bool=false,
256+
context::AbstractContext=SamplingContext();
262257
only_tilde::Bool=true,
263258
)
264259
# If JET.jl has been loaded, and thus `determine_varinfo` has been defined, we use that.
265260
if Base.get_extension(DynamicPPL, :DynamicPPLJETExt) !== nothing
266-
return _determine_varinfo_jet(model, context; only_tilde, verbose)
261+
return _determine_varinfo_jet(model, context; only_tilde)
267262
else
268263
@warn "JET.jl is not loaded. Assumes the model is compatible with typed varinfo."
269264
end

src/sampler.jl

Lines changed: 1 addition & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -77,33 +77,7 @@ function default_varinfo(
7777
context::AbstractContext,
7878
)
7979
init_sampler = initialsampler(sampler)
80-
varinfo = VarInfo(rng, model, init_sampler, context)
81-
82-
# If JET.jl has been loaded => use static checking to see if we can actually use the typed varinfo.
83-
if Base.get_extension(DynamicPPL, :DynamicPPLJETExt) !== nothing
84-
# Check evaluation.
85-
issuccess = first(is_suitable_varinfo(model, context, varinfo))
86-
if issuccess
87-
# Check the initial sampler.
88-
issuccess &= first(
89-
is_suitable_varinfo(model, SamplingContext(init_sampler, context), varinfo)
90-
)
91-
end
92-
93-
if issuccess
94-
# Check the actual sampler.
95-
issuccess &= first(
96-
is_suitable_varinfo(model, SamplingContext(sampler, context), varinfo)
97-
)
98-
end
99-
100-
if !issuccess
101-
@warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo."
102-
# TODO: Use a constructor which takes the rng and the sampler too.
103-
varinfo = untyped_varinfo(model)
104-
end
105-
end
106-
80+
varinfo = determine_suitable_varinfo(model, SamplingContext(rng, init_sampler, context))
10781
return varinfo
10882
end
10983

0 commit comments

Comments
 (0)