Skip to content

Commit 145f471

Browse files
torfjeldegithub-actions[bot]mhauruTor Fjelde
authored
Using JET.jl to determine if typed varinfo is okay (#728)
* fixed calls to `to_linked_internal_transform` * fixed incorrect call to `acclogp_assume!!` * added `determine_varinfo` and an implementation using JET for this * made filtering for errors only in the tilde pipeline optional * formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fixed incorrect comment * added test for the branch we were currently imssing * formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * renamed `determine_varinfo` to `determine_suitable_varinfo` with fallback to current behavior + `supports_varinfo` to `is_suitable_varinfo` * removed now-redundant init used with Requires.jl, since this is no longer needed on Julia 1.10 and onwards + added error hint for when JET.jl has not been loaded * `determine_suitable_varinfo` now only performs checks using the provided context, but uses `SamplingContext` by default (as this should be a stricter check than just evaluation) * formatting * updated error hint * added def of `untyped_varinfo` which takes just `model` and `context` * fixed incorrect call to `untyped_varinfo` in `_determine_varinfo_jet` * explicitly call `typed_varinfo` when we want such a thing rather than the ambiguous `VarINfo` * `typed_varinfo` and `untyped_varinfo` handles wrapping passed context in sampling context now so no need to handle this explicitly elsewhere * use `determine_suitable_varinfo` in `LogDensityFunction` when not constructed * formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * formatting * fixed a bug in `DynamicPPLJETExt.is_tilde_instance` * updated docs * Update docs/src/internals/varinfo.md Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * added back def of `untyped_varinfo` that shouldn't have been removed + fixed call in docs * minor codestyle improvement * temporary hack to debug what's happening * more debugging * use the `target_modules` kwarg in `report_call` instead of manually filtering the frames * formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * more debugging * more debugging * more debugging: try with new bijectors.jl * formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * removed the hacky debugging stuff used for the CI * removed now-redudant filtering methods since we use JET's own filters * bump Bijectors.jl compat entry to 0.15.1 in test so JET.jl tests pass * moved the JET.jl-dependent experimental `determine_varinfo` into a separate `Experimental` module, as discussed * forgot to add the experimenta.jl file in previous commit * reverted changes to `default_varinfo` and `LogDensityFunction` * added a bunch of docs for introduced and existing methods Added docs for `determine_suitable_varinfo` and existing methods that should be documented, e.g. `untyped_varinfo`, `typed_varinfo`, and `default_varinfo` * added doctests to `determine_suitable_varinfo` * added JET.jl as a dep to docs * fixed referencing in docs * fixed docstring * fixed doctest * Update Project.toml * applied suggestions from @mhauru Co-authored-by: Markus Hauru <[email protected]> * fixed doctests * finally fixed doctests * removed unnecessary `typed_varinfo` and `untyped_varinfo` methods * added filter to ignore source of warnings in doctest --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Markus Hauru <[email protected]> Co-authored-by: Tor Fjelde <[email protected]>
1 parent 0548ddf commit 145f471

File tree

12 files changed

+329
-38
lines changed

12 files changed

+329
-38
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2929
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3030
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
3131
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
32+
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
3233
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
3334
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
3435
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
@@ -37,6 +38,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
3738
DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
3839
DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
3940
DynamicPPLForwardDiffExt = ["ForwardDiff"]
41+
DynamicPPLJETExt = ["JET"]
4042
DynamicPPLMCMCChainsExt = ["MCMCChains"]
4143
DynamicPPLMooncakeExt = ["Mooncake"]
4244
DynamicPPLZygoteRulesExt = ["ZygoteRules"]
@@ -55,6 +57,7 @@ Distributions = "0.25"
5557
DocStringExtensions = "0.9"
5658
EnzymeCore = "0.6 - 0.8"
5759
ForwardDiff = "0.10"
60+
JET = "0.9"
5861
LinearAlgebra = "1.6"
5962
LogDensityProblems = "2"
6063
LogDensityProblemsAD = "1.7.0"

docs/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
66
DocumenterMermaid = "a078cd44-4d9c-4618-b545-3ab9d77f9177"
77
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
88
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
9+
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
910
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
1011
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
1112
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
@@ -18,6 +19,7 @@ Documenter = "1"
1819
DocumenterMermaid = "0.1"
1920
FillArrays = "0.13, 1"
2021
ForwardDiff = "0.10"
22+
JET = "0.9"
2123
LogDensityProblems = "2"
2224
MCMCChains = "5, 6"
2325
StableRNGs = "1"

docs/src/api.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,13 @@ AbstractVarInfo
265265

266266
But exactly how a [`AbstractVarInfo`](@ref) stores this information can vary.
267267

268+
For constructing the "default" typed and untyped varinfo types used in DynamicPPL (see [the section on varinfo design](@ref "Design of `VarInfo`") for more on this), we have the following two methods:
269+
270+
```@docs
271+
DynamicPPL.untyped_varinfo
272+
DynamicPPL.typed_varinfo
273+
```
274+
268275
#### `VarInfo`
269276

270277
```@docs
@@ -425,6 +432,19 @@ DynamicPPL.loadstate
425432
DynamicPPL.initialsampler
426433
```
427434

435+
Finally, to specify which varinfo type a [`Sampler`](@ref) should use for a given [`Model`](@ref), this is specified by [`DynamicPPL.default_varinfo`](@ref) and can thus be overloaded for each `model`-`sampler` combination. This can be useful in cases where one has explicit knowledge that one type of varinfo will be more performant for the given `model` and `sampler`.
436+
437+
```@docs
438+
DynamicPPL.default_varinfo
439+
```
440+
441+
There is also the _experimental_ [`DynamicPPL.Experimental.determine_suitable_varinfo`](@ref), which uses static checking via [JET.jl](https://github.com/aviatesk/JET.jl) to determine whether one should use [`DynamicPPL.typed_varinfo`](@ref) or [`DynamicPPL.untyped_varinfo`](@ref), depending on which supports the model:
442+
443+
```@docs
444+
DynamicPPL.Experimental.determine_suitable_varinfo
445+
DynamicPPL.Experimental.is_suitable_varinfo
446+
```
447+
428448
### [Model-Internal Functions](@id model_internal)
429449

430450
```@docs

docs/src/internals/varinfo.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,7 @@ For example, with the model above we have
7979

8080
```@example varinfo-design
8181
# Type-unstable `VarInfo`
82-
varinfo_untyped = DynamicPPL.untyped_varinfo(
83-
demo(), SampleFromPrior(), DefaultContext(), DynamicPPL.Metadata()
84-
)
82+
varinfo_untyped = DynamicPPL.untyped_varinfo(demo())
8583
typeof(varinfo_untyped.metadata)
8684
```
8785

ext/DynamicPPLJETExt.jl

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
module DynamicPPLJETExt
2+
3+
using DynamicPPL: DynamicPPL
4+
using JET: JET
5+
6+
function DynamicPPL.Experimental.is_suitable_varinfo(
7+
model::DynamicPPL.Model,
8+
context::DynamicPPL.AbstractContext,
9+
varinfo::DynamicPPL.AbstractVarInfo;
10+
only_ddpl::Bool=true,
11+
)
12+
# Let's make sure that both evaluation and sampling doesn't result in type errors.
13+
f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
14+
model, varinfo, context
15+
)
16+
# If specified, we only check errors originating somewhere in the DynamicPPL.jl.
17+
# This way we don't just fall back to untyped if the user's code is the issue.
18+
result = if only_ddpl
19+
JET.report_call(f, argtypes; target_modules=(JET.AnyFrameModule(DynamicPPL),))
20+
else
21+
JET.report_call(f, argtypes)
22+
end
23+
return length(JET.get_reports(result)) == 0, result
24+
end
25+
26+
function DynamicPPL.Experimental._determine_varinfo_jet(
27+
model::DynamicPPL.Model, context::DynamicPPL.AbstractContext; only_ddpl::Bool=true
28+
)
29+
# First we try with the typed varinfo.
30+
varinfo = DynamicPPL.typed_varinfo(model, context)
31+
32+
# Let's make sure that both evaluation and sampling doesn't result in type errors.
33+
issuccess, result = DynamicPPL.Experimental.is_suitable_varinfo(
34+
model, context, varinfo; only_ddpl
35+
)
36+
37+
if !issuccess
38+
# Useful information for debugging.
39+
@debug "Evaluaton with typed varinfo failed with the following issues:"
40+
@debug result
41+
end
42+
43+
# If we didn't fail anywhere, we return the type stable one.
44+
return if issuccess
45+
varinfo
46+
else
47+
# Warn the user that we can't use the type stable one.
48+
@warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo."
49+
DynamicPPL.untyped_varinfo(model, context)
50+
end
51+
end
52+
53+
end

src/DynamicPPL.jl

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -199,32 +199,35 @@ include("values_as_in_model.jl")
199199
include("debug_utils.jl")
200200
using .DebugUtils
201201

202+
include("experimental.jl")
202203
include("deprecated.jl")
203204

204205
if !isdefined(Base, :get_extension)
205206
using Requires
206207
end
207208

208-
@static if !isdefined(Base, :get_extension)
209+
# Better error message if users forget to load JET
210+
if isdefined(Base.Experimental, :register_error_hint)
209211
function __init__()
210-
@require ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" include(
211-
"../ext/DynamicPPLChainRulesCoreExt.jl"
212-
)
213-
@require EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" include(
214-
"../ext/DynamicPPLEnzymeCoreExt.jl"
215-
)
216-
@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include(
217-
"../ext/DynamicPPLForwardDiffExt.jl"
218-
)
219-
@require MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" include(
220-
"../ext/DynamicPPLMCMCChainsExt.jl"
221-
)
222-
@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include(
223-
"../ext/DynamicPPLReverseDiffExt.jl"
224-
)
225-
@require ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" include(
226-
"../ext/DynamicPPLZygoteRulesExt.jl"
227-
)
212+
Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _
213+
requires_jet =
214+
exc.f === DynamicPPL.Experimental._determine_varinfo_jet &&
215+
length(argtypes) >= 2 &&
216+
argtypes[1] <: Model &&
217+
argtypes[2] <: AbstractContext
218+
requires_jet |=
219+
exc.f === DynamicPPL.Experimental.is_suitable_varinfo &&
220+
length(argtypes) >= 3 &&
221+
argtypes[1] <: Model &&
222+
argtypes[2] <: AbstractContext &&
223+
argtypes[3] <: AbstractVarInfo
224+
if requires_jet
225+
print(
226+
io,
227+
"\n$(exc.f) requires JET.jl to be loaded. Please run `using JET` before calling $(exc.f).",
228+
)
229+
end
230+
end
228231
end
229232
end
230233

src/experimental.jl

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
module Experimental
2+
3+
using DynamicPPL: DynamicPPL
4+
5+
# This file only defines the names of the functions, and their docstrings. The actual implementations are in `ext/DynamicPPLJETExt.jl`, since we don't want to depend on JET.jl other than as a weak dependency.
6+
"""
7+
is_suitable_varinfo(model::Model, context::AbstractContext, varinfo::AbstractVarInfo; kwargs...)
8+
9+
Check if the `model` supports evaluation using the provided `context` and `varinfo`.
10+
11+
!!! warning
12+
Loading JET.jl is required before calling this function.
13+
14+
# Arguments
15+
- `model`: The model to verify the support for.
16+
- `context`: The context to use for the model evaluation.
17+
- `varinfo`: The varinfo to verify the support for.
18+
19+
# Keyword Arguments
20+
- `only_ddpl`: If `true`, only consider error reports occuring in the tilde pipeline. Default: `true`.
21+
22+
# Returns
23+
- `issuccess`: `true` if the model supports the varinfo, otherwise `false`.
24+
- `report`: The result of `report_call` from JET.jl.
25+
"""
26+
function is_suitable_varinfo end
27+
28+
# Internal hook for JET.jl to overload.
29+
function _determine_varinfo_jet end
30+
31+
"""
32+
determine_suitable_varinfo(model[, context]; only_ddpl::Bool=true)
33+
34+
Return a suitable varinfo for the given `model`.
35+
36+
See also: [`DynamicPPL.Experimental.is_suitable_varinfo`](@ref).
37+
38+
!!! warning
39+
For full functionality, this requires JET.jl to be loaded.
40+
If JET.jl is not loaded, this function will assume the model is compatible with typed varinfo.
41+
42+
# Arguments
43+
- `model`: The model for which to determine the varinfo.
44+
- `context`: The context to use for the model evaluation. Default: `SamplingContext()`.
45+
46+
# Keyword Arguments
47+
- `only_ddpl`: If `true`, only consider error reports within DynamicPPL.jl.
48+
49+
# Examples
50+
51+
```jldoctest
52+
julia> using DynamicPPL.Experimental: determine_suitable_varinfo
53+
54+
julia> using JET: JET # needs to be loaded for full functionality
55+
56+
julia> @model function model_with_random_support()
57+
x ~ Bernoulli()
58+
if x
59+
y ~ Normal()
60+
else
61+
z ~ Normal()
62+
end
63+
end
64+
model_with_random_support (generic function with 2 methods)
65+
66+
julia> model = model_with_random_support();
67+
68+
julia> # Typed varinfo cannot handle this random support model properly
69+
# as using a single execution of the model will not see all random variables.
70+
# Hence, this this model requires untyped varinfo.
71+
vi = determine_suitable_varinfo(model);
72+
┌ Warning: Model seems incompatible with typed varinfo. Falling back to untyped varinfo.
73+
└ @ DynamicPPLJETExt ~/.julia/dev/DynamicPPL.jl/ext/DynamicPPLJETExt.jl:48
74+
75+
julia> vi isa typeof(DynamicPPL.untyped_varinfo(model))
76+
true
77+
78+
julia> # In contrast, a simple model with no random support can be handled by typed varinfo.
79+
@model model_with_static_support() = x ~ Normal()
80+
model_with_static_support (generic function with 2 methods)
81+
82+
julia> vi = determine_suitable_varinfo(model_with_static_support());
83+
84+
julia> vi isa typeof(DynamicPPL.typed_varinfo(model_with_static_support()))
85+
true
86+
```
87+
"""
88+
function determine_suitable_varinfo(
89+
model::DynamicPPL.Model,
90+
context::DynamicPPL.AbstractContext=DynamicPPL.SamplingContext();
91+
only_ddpl::Bool=true,
92+
)
93+
# If JET.jl has been loaded, and thus `determine_varinfo` has been defined, we use that.
94+
return if Base.get_extension(DynamicPPL, :DynamicPPLJETExt) !== nothing
95+
_determine_varinfo_jet(model, context; only_ddpl)
96+
else
97+
# Warn the user.
98+
@warn "JET.jl is not loaded. Assumes the model is compatible with typed varinfo."
99+
# Otherwise, we use the, possibly incorrect, default typed varinfo (to stay backwards compat).
100+
DynamicPPL.typed_varinfo(model, context)
101+
end
102+
end
103+
104+
end

src/sampler.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,20 @@ function AbstractMCMC.step(
6767
return vi, nothing
6868
end
6969

70+
"""
71+
default_varinfo(rng, model, sampler[, context])
72+
73+
Return a default varinfo object for the given `model` and `sampler`.
74+
75+
# Arguments
76+
- `rng::Random.AbstractRNG`: Random number generator.
77+
- `model::Model`: Model for which we want to create a varinfo object.
78+
- `sampler::AbstractSampler`: Sampler which will make use of the varinfo object.
79+
- `context::AbstractContext`: Context in which the model is evaluated.
80+
81+
# Returns
82+
- `AbstractVarInfo`: Default varinfo object for the given `model` and `sampler`.
83+
"""
7084
function default_varinfo(rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler)
7185
return default_varinfo(rng, model, sampler, DefaultContext())
7286
end
@@ -126,7 +140,7 @@ By default, `data` is returned.
126140
loadstate(data) = data
127141

128142
"""
129-
default_chaintype(sampler)
143+
default_chain_type(sampler)
130144
131145
Default type of the chain of posterior samples from `sampler`.
132146
"""

src/varinfo.jl

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -164,30 +164,36 @@ function has_varnamedvector(vi::VarInfo)
164164
end
165165

166166
"""
167-
untyped_varinfo([rng, ]model[, sampler, context])
167+
untyped_varinfo(model[, context, metadata])
168168
169-
Return an untyped `VarInfo` instance for the model `model`.
169+
Return an untyped varinfo object for the given `model` and `context`.
170+
171+
# Arguments
172+
- `model::Model`: The model for which to create the varinfo object.
173+
- `context::AbstractContext`: The context in which to evaluate the model. Default: `SamplingContext()`.
174+
- `metadata::Union{Metadata,VarNamedVector}`: The metadata to use for the varinfo object.
175+
Default: `Metadata()`.
170176
"""
171177
function untyped_varinfo(
172-
rng::Random.AbstractRNG,
173178
model::Model,
174-
sampler::AbstractSampler=SampleFromPrior(),
175-
context::AbstractContext=DefaultContext(),
179+
context::AbstractContext=SamplingContext(),
176180
metadata::Union{Metadata,VarNamedVector}=Metadata(),
177181
)
178182
varinfo = VarInfo(metadata)
179-
return last(evaluate!!(model, varinfo, SamplingContext(rng, sampler, context)))
180-
end
181-
function untyped_varinfo(
182-
model::Model, args::Union{AbstractSampler,AbstractContext,Metadata,VarNamedVector}...
183-
)
184-
return untyped_varinfo(Random.default_rng(), model, args...)
183+
return last(
184+
evaluate!!(model, varinfo, hassampler(context) ? context : SamplingContext(context))
185+
)
185186
end
186187

187188
"""
188-
typed_varinfo([rng, ]model[, sampler, context])
189+
typed_varinfo(model[, context, metadata])
190+
191+
Return a typed varinfo object for the given `model`, `sampler` and `context`.
192+
193+
This simply calls [`DynamicPPL.untyped_varinfo`](@ref) and converts the resulting
194+
varinfo object to a typed varinfo object.
189195
190-
Return a typed `VarInfo` instance for the model `model`.
196+
See also: [`DynamicPPL.untyped_varinfo`](@ref)
191197
"""
192198
typed_varinfo(args...) = TypedVarInfo(untyped_varinfo(args...))
193199

@@ -198,7 +204,7 @@ function VarInfo(
198204
context::AbstractContext=DefaultContext(),
199205
metadata::Union{Metadata,VarNamedVector}=Metadata(),
200206
)
201-
return typed_varinfo(rng, model, sampler, context, metadata)
207+
return typed_varinfo(model, SamplingContext(rng, sampler, context), metadata)
202208
end
203209
VarInfo(model::Model, args...) = VarInfo(Random.default_rng(), model, args...)
204210

0 commit comments

Comments
 (0)