-
-
Notifications
You must be signed in to change notification settings - Fork 10
Add converter from Turing using both Chains and Model #133
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
sethaxen
wants to merge
36
commits into
main
Choose a base branch
from
fromturing
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 3 commits
Commits
Show all changes
36 commits
Select commit
Hold shift + click to select a range
dd7abf8
Add Turing to extras
sethaxen dff0ed1
Add initial implementation of from_turing
sethaxen 55cbc78
Handle non-array eltype constraints
sethaxen d25e592
Apply suggestions from code review
sethaxen 31f902b
Repair predictive model code
sethaxen e1da410
Run formatter
sethaxen 454c23e
Constrain type of model
sethaxen 98511f4
Add model name to attributes
sethaxen 191cfd7
Support specifying groups to not be generated
sethaxen 2c6cc32
Constrain type of rng
sethaxen 134a6a1
Add docstring
sethaxen 9e3edb7
Document from_turing
sethaxen 3408476
Also generate constant_data
sethaxen 180e560
Make code more modular
sethaxen 85ddd69
Add Turing tests
sethaxen 4f50d7a
Force library to be Turing
sethaxen 1b58b6b
Overload setattribute! for InferenceData
sethaxen 12ca874
Add function to add inference library info
sethaxen a7bb79f
Globally use library utility
sethaxen 42c8823
Test library utility for Turing
sethaxen 842447f
Increment version number
sethaxen c9b4562
Repair Turing example
sethaxen e0d9ae3
Don't import Turing's exports
sethaxen a90df63
Return correct variable name
sethaxen ea273ef
Indent wrapped lines
sethaxen 6e93fa7
Update quickstart.md
sethaxen 92e8a25
Run formatter
sethaxen 1581eb0
Deep copy arguments
sethaxen b863095
Capture status in string
sethaxen 13ad03a
Better handle adding library info
sethaxen a188e27
Run formatter
sethaxen ee474ae
Add attribute and library tests
sethaxen 5ee652a
Extract observed_data from model
sethaxen 09c37da
Update example
sethaxen a05bfd2
Update quickstart
sethaxen bf474a1
Fix test
sethaxen File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
function from_turing( | ||
chns=nothing; | ||
model=nothing, | ||
rng=Random.default_rng(), | ||
nchains=ndraws = chns isa Turing.MCMCChains.Chains ? last(size(chns)) : 1, | ||
ndraws=chns isa Turing.MCMCChains.Chains ? first(size(chns)) : 1_000, | ||
library=Turing, | ||
observed_data=nothing, | ||
constant_data=nothing, | ||
posterior_predictive=nothing, | ||
prior=nothing, | ||
prior_predictive=nothing, | ||
log_likelihood=nothing, | ||
kwargs..., | ||
) | ||
groups = Dict{Symbol,Any}( | ||
:observed_data => observed_data, | ||
:constant_data => constant_data, | ||
:posterior_predictive => posterior_predictive, | ||
:prior => prior, | ||
:prior_predictive => prior_predictive, | ||
:log_likelihood => log_likelihood, | ||
) | ||
model === nothing && return from_mcmcchains(chns; library=library, groups..., kwargs...) | ||
if groups[:prior] === nothing | ||
groups[:prior] = reduce( | ||
Turing.chainscat, | ||
map( | ||
_ -> Turing.sample(rng, model, Turing.Prior(), ndraws; progress=false), | ||
1:nchains, | ||
sethaxen marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
), | ||
) | ||
sethaxen marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
end | ||
|
||
groups[:observed_data] === nothing && | ||
return from_mcmcchains(chns; library=library, groups..., kwargs...) | ||
|
||
observed_data = groups[:observed_data] | ||
data_var_names = Set( | ||
observed_data isa Dict ? Symbol.(keys(observed_data)) : propertynames(observed_data) | ||
) | ||
|
||
if groups[:constant_data] === nothing | ||
groups[:constant_data] = NamedTuple( | ||
filter(p -> first(p) ∉ data_var_names, pairs(model.args)) | ||
) | ||
end | ||
|
||
# Instantiate the predictive model | ||
args_pred = NamedTuple( | ||
k => k in data_var_names ? similar(v, Missing) : v for (k, v) in pairs(model.args) | ||
) | ||
model_predict = Turing.DynamicPPL.Model(model.name, model.f, args_pred, model.defaults) | ||
sethaxen marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
||
# and then sample! | ||
if groups[:prior_predictive] === nothing && groups[:prior] isa Turing.MCMCChains.Chains | ||
groups[:prior_predictive] = Turing.predict(rng, model_predict, groups[:prior]) | ||
end | ||
|
||
if chns isa Turing.MCMCChains.Chains | ||
if groups[:posterior_predictive] === nothing && chns isa Turing.MCMCChains.Chains | ||
groups[:posterior_predictive] = Turing.predict(rng, model_predict, chns) | ||
end | ||
|
||
if groups[:log_likelihood] === nothing && | ||
groups[:posterior_predictive] isa MCMCChains.Chains | ||
loglikelihoods = Turing.pointwise_loglikelihoods( | ||
model, Turing.MCMCChains.get_sections(chns, :parameters) | ||
) | ||
|
||
# Bundle loglikelihoods into a `Chains` object so we can reuse our own variable | ||
# name parsing | ||
pred_names = string.(keys(groups[:posterior_predictive])) | ||
loglikelihoods_vals = getindex.(Ref(loglikelihoods), pred_names) | ||
loglikelihoods_arr = permutedims(cat(loglikelihoods_vals...; dims=3), (1, 3, 2)) | ||
groups[:log_likelihood] = Turing.MCMCChains.Chains( | ||
loglikelihoods_arr, pred_names | ||
) | ||
end | ||
end | ||
|
||
return from_mcmcchains(chns; library=Turing, groups..., kwargs...) | ||
end |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.