-
Notifications
You must be signed in to change notification settings - Fork 36
InitContext
, part 4 - Use init!!
to replace evaluate_and_sample!!
, predict
, returned
, and initialize_values
#984
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: breaking
Are you sure you want to change the base?
Changes from 14 commits
485a525
7a05ec5
b00e284
5ed975c
2706239
84e5e55
7f188b9
f7ac1b1
2041927
d9292ad
70bb2c4
bc04355
2cfc297
891b4b3
3bb7ade
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,7 +28,7 @@ end | |
|
||
function _check_varname_indexing(c::MCMCChains.Chains) | ||
return DynamicPPL.supports_varname_indexing(c) || | ||
error("Chains do not support indexing using `VarName`s.") | ||
error("This `Chains` object does not support indexing using `VarName`s.") | ||
end | ||
|
||
function DynamicPPL.getindex_varname( | ||
|
@@ -42,6 +42,15 @@ function DynamicPPL.varnames(c::MCMCChains.Chains) | |
return keys(c.info.varname_to_symbol) | ||
end | ||
|
||
function chain_sample_to_varname_dict(c::MCMCChains.Chains, sample_idx, chain_idx) | ||
_check_varname_indexing(c) | ||
d = Dict{DynamicPPL.VarName,Any}() | ||
for vn in DynamicPPL.varnames(c) | ||
d[vn] = DynamicPPL.getindex_varname(c, sample_idx, vn, chain_idx) | ||
end | ||
return d | ||
end | ||
Comment on lines
+45
to
+52
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note that, if the chain does not store varnames inside its I don't think this is a huge problem right now because every chain obtained via Turing's So this is only a problem if one manually constructs a chain and tries to call However, it's obviously ugly. The only good way around this is to rework MCMCChains.jl :( (See here for the implementation of the corresponding functionality in FlexiChains.) |
||
|
||
""" | ||
predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false) | ||
|
||
|
@@ -114,9 +123,15 @@ function DynamicPPL.predict( | |
|
||
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) | ||
predictive_samples = map(iters) do (sample_idx, chain_idx) | ||
DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx) | ||
varinfo = last(DynamicPPL.evaluate_and_sample!!(rng, model, varinfo)) | ||
|
||
# Extract values from the chain | ||
values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx) | ||
# Resample any variables that are not present in `values_dict` | ||
_, varinfo = DynamicPPL.init!!( | ||
rng, | ||
model, | ||
varinfo, | ||
DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()), | ||
) | ||
vals = DynamicPPL.values_as_in_model(model, false, varinfo) | ||
varname_vals = mapreduce( | ||
collect, | ||
|
@@ -248,13 +263,16 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha | |
varinfo = DynamicPPL.VarInfo(model) | ||
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) | ||
return map(iters) do (sample_idx, chain_idx) | ||
# TODO: Use `fix` once we've addressed https://github.com/TuringLang/DynamicPPL.jl/issues/702. | ||
# Update the varinfo with the current sample and make variables not present in `chain` | ||
# to be sampled. | ||
DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx) | ||
# NOTE: Some of the varialbes can be a view into the `varinfo`, so we need to | ||
# `deepcopy` the `varinfo` before passing it to the `model`. | ||
model(deepcopy(varinfo)) | ||
# Extract values from the chain | ||
values_dict = chain_sample_to_varname_dict(chain, sample_idx, chain_idx) | ||
# Resample any variables that are not present in `values_dict`, and | ||
# return the model's retval. | ||
retval, _ = DynamicPPL.init!!( | ||
model, | ||
varinfo, | ||
DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()), | ||
) | ||
retval | ||
end | ||
end | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -850,7 +850,7 @@ end | |
# ^ Weird Documenter.jl bug means that we have to write the two above separately | ||
# as it can only detect the `function`-less syntax. | ||
function (model::Model)(rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInfo()) | ||
return first(evaluate_and_sample!!(rng, model, varinfo)) | ||
return first(init!!(rng, model, varinfo)) | ||
end | ||
|
||
""" | ||
|
@@ -863,32 +863,6 @@ function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) | |
return Threads.nthreads() > 1 | ||
end | ||
|
||
""" | ||
evaluate_and_sample!!([rng::Random.AbstractRNG, ]model::Model, varinfo[, sampler]) | ||
|
||
Evaluate the `model` with the given `varinfo`, but perform sampling during the | ||
evaluation using the given `sampler` by wrapping the model's context in a | ||
`SamplingContext`. | ||
|
||
If `sampler` is not provided, defaults to [`SampleFromPrior`](@ref). | ||
|
||
Returns a tuple of the model's return value, plus the updated `varinfo` object. | ||
""" | ||
function evaluate_and_sample!!( | ||
rng::Random.AbstractRNG, | ||
model::Model, | ||
varinfo::AbstractVarInfo, | ||
sampler::AbstractSampler=SampleFromPrior(), | ||
) | ||
sampling_model = contextualize(model, SamplingContext(rng, sampler, model.context)) | ||
return evaluate!!(sampling_model, varinfo) | ||
end | ||
function evaluate_and_sample!!( | ||
model::Model, varinfo::AbstractVarInfo, sampler::AbstractSampler=SampleFromPrior() | ||
) | ||
return evaluate_and_sample!!(Random.default_rng(), model, varinfo, sampler) | ||
end | ||
|
||
""" | ||
init!!( | ||
[rng::Random.AbstractRNG,] | ||
|
@@ -897,10 +871,10 @@ end | |
[init_strategy::AbstractInitStrategy=InitFromPrior()] | ||
) | ||
|
||
Evaluate the `model` and replace the values of the model's random variables in | ||
the given `varinfo` with new values using a specified initialisation strategy. | ||
If the values in `varinfo` are not already present, they will be added using | ||
that same strategy. | ||
Evaluate the `model` and replace the values of the model's random variables | ||
in the given `varinfo` with new values, using a specified initialisation strategy. | ||
If the values in `varinfo` are not set, they will be added. | ||
using a specified initialisation strategy. | ||
|
||
If `init_strategy` is not provided, defaults to InitFromPrior(). | ||
|
||
|
@@ -1051,11 +1025,7 @@ Base.nameof(model::Model{<:Function}) = nameof(model.f) | |
Generate a sample of type `T` from the prior distribution of the `model`. | ||
""" | ||
function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T} | ||
x = last( | ||
evaluate_and_sample!!( | ||
rng, model, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}()) | ||
), | ||
) | ||
x = last(init!!(rng, model, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}()))) | ||
return values_as(x, T) | ||
end | ||
|
||
|
@@ -1227,25 +1197,8 @@ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractC | |
end | ||
end | ||
|
||
""" | ||
predict([rng::Random.AbstractRNG,] model::Model, chain::AbstractVector{<:AbstractVarInfo}) | ||
|
||
Generate samples from the posterior predictive distribution by evaluating `model` at each set | ||
of parameter values provided in `chain`. The number of posterior predictive samples matches | ||
the length of `chain`. The returned `AbstractVarInfo`s will contain both the posterior parameter values | ||
and the predicted values. | ||
""" | ||
function predict( | ||
rng::Random.AbstractRNG, model::Model, chain::AbstractArray{<:AbstractVarInfo} | ||
) | ||
varinfo = DynamicPPL.VarInfo(model) | ||
return map(chain) do params_varinfo | ||
vi = deepcopy(varinfo) | ||
DynamicPPL.setval_and_resample!(vi, values_as(params_varinfo, NamedTuple)) | ||
model(rng, vi) | ||
return vi | ||
end | ||
end | ||
# Implemented & documented in DynamicPPLMCMCChainsExt | ||
function predict end | ||
Comment on lines
-1230
to
+1201
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was discussed at one of the meetings and we decided we didn't care enough about the |
||
|
||
""" | ||
returned(model::Model, parameters::NamedTuple) | ||
|
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm a bit confused by the comments in this function because as far as I can tell it only ever tested sampling, not both sampling and evaluation. (That was also true going further back e.g. in v0.36)
This PR thus also changes the implementation of this function to test both evaluation and sampling (i.e. initialisation) and if either fails, it will return the untyped varinfo.
Sorry I had to make this change in this PR. There were a few unholy tests where one would end up evaluating a model with a
SamplingContext{<:InitContext}
, which would error unless I introduced special code to handle it, and I didn't really want to do that. JETExt was one of those unholy scenarios.