-
Notifications
You must be signed in to change notification settings - Fork 36
Using JET.jl to determine if typed varinfo is okay #728
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 59 commits
361c45e
545cfab
abd432f
5cd9009
d503c3c
acb2cb0
902641f
d93006b
64ff18a
90c2df0
a94dbd5
67723d6
3d8ad44
c06b080
d1a5bab
5370e55
dd408ee
c253e9b
891b46a
686ed9f
d7d785a
dda56ec
46ea18c
c20ede3
0b3c36e
f76658a
690b017
97258f3
95bb3a9
155ce66
4998d08
5c27677
99d4df7
3b9a9eb
99fb153
3588597
7a302e5
040cb54
889c370
c98fe49
123b644
37fabb0
33e5b98
7ddec2c
b6b4bff
e07ecdb
8ba8f82
9ec1556
599488b
fa155a4
8496968
fd82871
bb87ba0
62c5cd1
55dc91e
ae51778
a692ec3
d5eb404
17b6ec9
bfa88b2
82578cf
3aad34f
da3eefe
325c5f9
4a17e82
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
module DynamicPPLJETExt | ||
|
||
using DynamicPPL: DynamicPPL | ||
using JET: JET | ||
|
||
function DynamicPPL.Experimental.is_suitable_varinfo( | ||
model::DynamicPPL.Model, | ||
context::DynamicPPL.AbstractContext, | ||
varinfo::DynamicPPL.AbstractVarInfo; | ||
only_ddpl::Bool=true, | ||
) | ||
# Let's make sure that both evaluation and sampling doesn't result in type errors. | ||
f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( | ||
model, varinfo, context | ||
) | ||
# If specified, we only check errors originating somewhere in the DynamicPPL.jl. | ||
# This way we don't just fall back to untyped if the user's code is the issue. | ||
result = if only_ddpl | ||
JET.report_call(f, argtypes; target_modules=(JET.AnyFrameModule(DynamicPPL),)) | ||
else | ||
JET.report_call(f, argtypes) | ||
end | ||
return length(JET.get_reports(result)) == 0, result | ||
end | ||
|
||
function DynamicPPL.Experimental._determine_varinfo_jet( | ||
model::DynamicPPL.Model, context::DynamicPPL.AbstractContext; only_ddpl::Bool=true | ||
) | ||
# First we try with the typed varinfo. | ||
varinfo = DynamicPPL.typed_varinfo(model, context) | ||
issuccess = true | ||
|
||
# Let's make sure that both evaluation and sampling doesn't result in type errors. | ||
issuccess, result = DynamicPPL.Experimental.is_suitable_varinfo( | ||
model, context, varinfo; only_ddpl | ||
) | ||
|
||
if !issuccess | ||
# Useful information for debugging. | ||
@debug "Evaluaton with typed varinfo failed with the following issues:" | ||
@debug result | ||
end | ||
|
||
# If we didn't fail anywhere, we return the type stable one. | ||
return if issuccess | ||
varinfo | ||
else | ||
# Warn the user that we can't use the type stable one. | ||
@warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo." | ||
DynamicPPL.untyped_varinfo(model, context) | ||
end | ||
end | ||
|
||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -199,32 +199,35 @@ include("values_as_in_model.jl") | |
include("debug_utils.jl") | ||
using .DebugUtils | ||
|
||
include("experimental.jl") | ||
include("deprecated.jl") | ||
|
||
if !isdefined(Base, :get_extension) | ||
using Requires | ||
end | ||
|
||
@static if !isdefined(Base, :get_extension) | ||
# Better error message if users forget to load the AD package | ||
torfjelde marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
if isdefined(Base.Experimental, :register_error_hint) | ||
function __init__() | ||
@require ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" include( | ||
"../ext/DynamicPPLChainRulesCoreExt.jl" | ||
) | ||
@require EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" include( | ||
"../ext/DynamicPPLEnzymeCoreExt.jl" | ||
) | ||
@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include( | ||
"../ext/DynamicPPLForwardDiffExt.jl" | ||
) | ||
@require MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" include( | ||
"../ext/DynamicPPLMCMCChainsExt.jl" | ||
) | ||
@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include( | ||
"../ext/DynamicPPLReverseDiffExt.jl" | ||
) | ||
@require ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" include( | ||
"../ext/DynamicPPLZygoteRulesExt.jl" | ||
) | ||
Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _ | ||
requires_jet = | ||
exc.f === DynamicPPL.Experimental._determine_varinfo_jet && | ||
length(argtypes) >= 2 && | ||
argtypes[1] <: Model && | ||
argtypes[2] <: AbstractContext | ||
requires_jet |= | ||
exc.f === DynamicPPL.Experimental.is_suitable_varinfo && | ||
length(argtypes) >= 3 && | ||
argtypes[1] <: Model && | ||
argtypes[2] <: AbstractContext && | ||
argtypes[3] <: AbstractVarInfo | ||
if requires_jet | ||
print( | ||
io, | ||
"\n$(exc.f) requires JET.jl to be loaded. Please run `using JET` before calling $(exc.f).", | ||
) | ||
end | ||
end | ||
Comment on lines
+212
to
+230
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could there be some way to test this? I do see that it's tricky. I'm a bit uncomfortable having this in without any testing. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah was thinking the same. We could put in a test strictly before loading JET.jl ofc. It's a bit messy, but seems like the best way 😕 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does wrapping the tests in separate modules save us? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nah. AFAIK extensions trigger if the package is loaded at any point, e.g. even if a dep loads it There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's also a thing where it doesn't seem like we can nicely get the resulting error message (the error hint is not in the msg of the error or something). So I think we just leave this for now 😕 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, it does seem nasty to test for. Have you tried locally that it does what you expect? |
||
end | ||
end | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
module Experimental | ||
|
||
using DynamicPPL: DynamicPPL | ||
|
||
torfjelde marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
is_suitable_varinfo(model::Model, context::AbstractContext, varinfo::AbstractVarInfo; kwargs...) | ||
Check if the `model` supports evaluation using the provided `context` and `varinfo`. | ||
!!! warning | ||
Loading JET.jl is required before calling this function. | ||
# Arguments | ||
- `model`: The model to to verify the support for. | ||
torfjelde marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
- `context`: The context to use for the model evaluation. | ||
- `varinfo`: The varinfo to verify the support for. | ||
# Keyword Arguments | ||
- `only_ddpl`: If `true`, only consider error reports occuring in the tilde pipeline. Default: `true`. | ||
mhauru marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# Returns | ||
- `issuccess`: `true` if the model supports the varinfo, otherwise `false`. | ||
- `report`: The result of `report_call` from JET.jl. | ||
""" | ||
function is_suitable_varinfo end | ||
|
||
# Internal hook for JET.jl to overload. | ||
function _determine_varinfo_jet end | ||
|
||
""" | ||
determine_suitable_varinfo(model[, context]; verbose::Bool=false, only_ddpl::Bool=true) | ||
torfjelde marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
Return a suitable varinfo for the given `model`. | ||
See also: [`DynamicPPL.Experimental.is_suitable_varinfo`](@ref). | ||
!!! warning | ||
For full functionality, this requires JET.jl to be loaded. | ||
If JET.jl is not loaded, this function will assume the model is compatible with typed varinfo. | ||
# Arguments | ||
- `model`: The model for which to determine the varinfo. | ||
- `context`: The context to use for the model evaluation. Default: `SamplingContext()`. | ||
# Keyword Arguments | ||
mhauru marked this conversation as resolved.
Show resolved
Hide resolved
|
||
- `only_ddpl`: If `true`, only consider error reports within DynamicPPL.jl. | ||
# Examples | ||
```jldoctest | ||
julia> using DynamicPPL.Experimental: determine_suitable_varinfo | ||
julia> using JET: JET # needs to be loaded for full functionality | ||
julia> @model function model_with_random_support() | ||
x ~ Bernoulli() | ||
if x | ||
y ~ Normal() | ||
else | ||
z ~ Normal() | ||
end | ||
end | ||
model_with_random_support (generic function with 2 methods) | ||
julia> model = model_with_random_support(); | ||
julia> # Typed varinfo cannot handle this random support model properly | ||
# as using a single execution of the model will not see all random variables. | ||
# Hence, this this model requires untyped varinfo. | ||
varinfo = determine_suitable_varinfo(model); | ||
┌ Warning: Model seems incompatible with typed varinfo. Falling back to untyped varinfo. | ||
└ @ DynamicPPLJETExt ~/Projects/public/DynamicPPL.jl/ext/DynamicPPLJETExt.jl:49 | ||
julia> varinfo isa typeof(DynamicPPL.untyped_varinfo(model)) | ||
true | ||
julia> # In contrast, a simple model with no random support can be handled by typed varinfo. | ||
@model model_with_static_support() = x ~ Normal() | ||
model_with_static_support (generic function with 2 method) | ||
julia> varinfo = determine_suitable_varinfo(model_with_static_support()); | ||
julia> varinfo isa typeof(DynamicPPL.typed_varinfo(model_with_static_support())) | ||
true | ||
``` | ||
""" | ||
function determine_suitable_varinfo( | ||
model::DynamicPPL.Model, | ||
context::DynamicPPL.AbstractContext=DynamicPPL.SamplingContext(); | ||
only_ddpl::Bool=true, | ||
) | ||
# If JET.jl has been loaded, and thus `determine_varinfo` has been defined, we use that. | ||
return if Base.get_extension(DynamicPPL, :DynamicPPLJETExt) !== nothing | ||
_determine_varinfo_jet(model, context; only_ddpl) | ||
else | ||
# Warn the user. | ||
@warn "JET.jl is not loaded. Assumes the model is compatible with typed varinfo." | ||
# Otherwise, we use the, possibly incorrect, default typed varinfo (to stay backwards compat). | ||
DynamicPPL.typed_varinfo(model, context) | ||
end | ||
end | ||
|
||
end |
Uh oh!
There was an error while loading. Please reload this page.