# InitContext, part 4 - Use `init!!` to replace `evaluate_and_sample!!`, `predict`, `returned`, and `initialize_values` #984
Base branch: `breaking`
Conversation
Benchmark report posted for commit 3bb7ade.
DynamicPPL.jl documentation for PR #984 is available at:
Codecov Report

```
@@            Coverage Diff             @@
##           breaking     #984      +/-  ##
============================================
- Coverage     82.53%   80.69%    -1.85%
============================================
  Files            39       39
  Lines          4008     3947       -61
============================================
- Hits           3308     3185      -123
- Misses          700      762       +62
```
```diff
-"""
-    predict([rng::Random.AbstractRNG,] model::Model, chain::AbstractVector{<:AbstractVarInfo})
-
-Generate samples from the posterior predictive distribution by evaluating `model` at each set
-of parameter values provided in `chain`. The number of posterior predictive samples matches
-the length of `chain`. The returned `AbstractVarInfo`s will contain both the posterior parameter values
-and the predicted values.
-"""
-function predict(
-    rng::Random.AbstractRNG, model::Model, chain::AbstractArray{<:AbstractVarInfo}
-)
-    varinfo = DynamicPPL.VarInfo(model)
-    return map(chain) do params_varinfo
-        vi = deepcopy(varinfo)
-        DynamicPPL.setval_and_resample!(vi, values_as(params_varinfo, NamedTuple))
-        model(rng, vi)
-        return vi
-    end
-end
+# Implemented & documented in DynamicPPLMCMCChainsExt
+function predict end
```
This was discussed at one of the meetings, and we decided we didn't care enough about the `predict` method on vectors of varinfos. It's currently bugged: `varinfo` is always unlinked, but `params_varinfo` might be linked, and if it is, the results will be wrong because a linked value is set into an unlinked varinfo. See #983.
```julia
function chain_sample_to_varname_dict(c::MCMCChains.Chains, sample_idx, chain_idx)
    _check_varname_indexing(c)
    d = Dict{DynamicPPL.VarName,Any}()
    for vn in DynamicPPL.varnames(c)
        d[vn] = DynamicPPL.getindex_varname(c, sample_idx, vn, chain_idx)
    end
    return d
end
```
Note that `chain_sample_to_varname_dict` will fail if the chain does not store varnames inside its `info` field. I don't think this is a huge problem right now, because every chain obtained via Turing's `sample()` will contain varnames. So this is only a problem if one manually constructs a chain and tries to call `predict` on it, which I think is a highly unlikely workflow (and I'm happy to wait for people to complain if it fails). There are a few places in DynamicPPL's test suite where this does actually happen; I fixed them all by manually adding the varname dictionary.

However, it's obviously ugly. The only good way around this is to rework MCMCChains.jl :( (See here for the implementation of the corresponding functionality in FlexiChains.)
```julia
function DynamicPPL.Experimental._determine_varinfo_jet(
    model::DynamicPPL.Model; only_ddpl::Bool=true
)
```
I'm a bit confused by the comments in this function, because as far as I can tell it only ever tested sampling, not both sampling and evaluation. (That was also true going further back, e.g. in v0.36.) This PR thus also changes the implementation of this function to test both evaluation and sampling (i.e. initialisation); if either fails, it will return the untyped varinfo.

Sorry I had to make this change in this PR. There were a few unholy tests where one would end up evaluating a model with a `SamplingContext{<:InitContext}`, which would error unless I introduced special code to handle it, and I didn't really want to do that. JETExt was one of those unholy scenarios.
```diff
-    DynamicPPL.evaluate_and_sample!!(rng, model, vi, sampler)
-    return vi, nothing
+    strategy = sampler isa SampleFromPrior ? InitFromPrior() : InitFromUniform()
+    _, new_vi = DynamicPPL.init!!(rng, model, vi, strategy)
+    return new_vi, nothing
```
This is a bit weird, but it's really just to tide us over until we delete `SampleFromUniform`/`SampleFromPrior` properly.
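As a sketch of the correspondence (using the names introduced in this PR; the model is illustrative), the old `evaluate_and_sample!!` call becomes an `init!!` call with an explicit strategy:

```julia
using DynamicPPL, Distributions, Random

# Minimal illustrative model, not from this PR.
@model demo() = x ~ Normal()

vi = VarInfo()
# Old (removed): DynamicPPL.evaluate_and_sample!!(rng, model, vi, SampleFromPrior())
# New: `init!!` returns (retval, varinfo); SampleFromPrior maps to
# InitFromPrior, SampleFromUniform to InitFromUniform.
_, vi = DynamicPPL.init!!(Random.default_rng(), demo(), vi, InitFromPrior())
```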
```julia
Define the initialisation strategy used for generating initial values when
sampling with `sampler`. Defaults to `InitFromPrior()`, but can be overridden.
"""
init_strategy(::Sampler) = InitFromPrior()
```
Eventually the aim would be to use `::AbstractSampler`, but that will have to wait for cleanup in Turing. DynamicPPL itself doesn't use `Sampler` at all, and if you only look at DPPL it looks like a meaningless empty wrapper, but Turing relies on these methods a fair bit.
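A hedged sketch of how a downstream package might override this default; the algorithm type `MyHMC` is hypothetical, while `init_strategy` and `InitFromUniform` are the names from this PR.

```julia
using DynamicPPL

# Hypothetical HMC-like algorithm wrapped in DynamicPPL's `Sampler`; it opts
# into uniform initialisation instead of the `InitFromPrior()` default.
struct MyHMC end

DynamicPPL.init_strategy(::Sampler{MyHMC}) = InitFromUniform()
```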
```julia
@testset "rng" begin
    model = GDEMO_DEFAULT

    for sampler in (SampleFromPrior(), SampleFromUniform())
        for i in 1:10
            Random.seed!(100 + i)
            vi = VarInfo()
            DynamicPPL.evaluate_and_sample!!(Random.default_rng(), model, vi, sampler)
            vals = vi[:]

            Random.seed!(100 + i)
            vi = VarInfo()
            DynamicPPL.evaluate_and_sample!!(Random.default_rng(), model, vi, sampler)
            @test vi[:] == vals
        end
    end
end
```
I think this is adequately tested in the InitContext tests (`test_rng_respected`).
```julia
varinfo_untyped = DynamicPPL.VarInfo()
model_with_spl = contextualize(model, SamplingContext(context))
```
I also had to rework these tests because of the `SamplingContext{<:InitContext}` case.
```julia
## `typed_varinfo`
vi = DynamicPPL.typed_varinfo(model)
vi = DynamicPPL.settrans!!(vi, true, vn)
test_linked_varinfo(model, vi)
```
This was a duplicate of the four lines above.
```diff
@@ -1012,45 +959,6 @@ end
     @test merge(vi_double, vi_single)[vn] == 1.0
 end

-@testset "sampling from linked varinfo" begin
```
These tests are also covered in InitContext now (`test_link_status_respected`)!
Part 1: Adding `hasvalue` and `getvalue` to AbstractPPL
Part 2: Removing `hasvalue` and `getvalue` from DynamicPPL
Part 3: Introducing `InitContext` and `init!!`

This is part 4/N of #967.

In Part 3 we introduced `InitContext`. This PR makes use of the functionality in there to replace a bunch of code that no longer needs to exist:

1. `setval_and_resample!` followed by model evaluation: this process was used for `predict` and `returned` to manually store certain values in the VarInfo, which would then be used in the subsequent model evaluation. We can now do this in a single step using `InitFromParams`.
2. `initialize_values!!`: very similar to the above. It would manually set values inside the varinfo, and then trigger an extra model evaluation to update the logp field. Again, this is directly replaced with `InitFromParams`.
3. `evaluate_and_sample!!`: direct one-to-one replacement with `init!!`.

There is one fairly major API change associated with point (2): the `initial_params` kwarg to Turing's `sample` must now be an `AbstractInitStrategy`. It's still optional (it defaults to `init_strategy(spl)`, which is usually `InitFromPrior`, except for the HMC family, which uses `InitFromUniform`). However, there are two implications:

- `initial_params` cannot be a vector of parameters anymore. It must be `InitFromParams(::NamedTuple)` or `InitFromParams(::AbstractDict{VarName})`.
- Since `InitFromParams` expects values in unlinked space, `initial_params` must always be specified in unlinked space. Previously, `initial_params` had to be specified in a way that matched the linking status of the underlying varinfo.

I consider both of these to be a major win for clarity. (One might argue that vectors are more convenient, but IMO anything that lets you extract a vector will also let you extract a NamedTuple or Dict, maybe with a bit more typing at worst.)
Closes #774
Closes #797
Closes #983
Closes TuringLang/Turing.jl#2476
Closes TuringLang/Turing.jl#1775