From dd7abf86924c5948d0cab49813003af77eef1d7c Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 19 May 2021 20:55:29 +0200 Subject: [PATCH 01/36] Add Turing to extras To add a compat entry --- Project.toml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 5740e60c9..77ab4f190 100644 --- a/Project.toml +++ b/Project.toml @@ -11,6 +11,7 @@ NamedTupleTools = "d9ec5142-1e00-5aa0-9d6a-321866360f50" PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee" REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" @@ -25,6 +26,7 @@ PyCall = "1.91.2" PyPlot = "2.8.2" Requires = "0.5.2, 1.0" StatsBase = "0.32, 0.33" +Turing = "0.15" julia = "^1" [extras] @@ -33,6 +35,7 @@ MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [targets] -test = ["CmdStan", "MCMCChains", "MonteCarloMeasurements", "Random", "Test"] +test = ["CmdStan", "MCMCChains", "MonteCarloMeasurements", "Random", "Test", "Turing"] From dff0ed1730a177b4bc2f604e292c89bc7471643b Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 19 May 2021 20:56:42 +0200 Subject: [PATCH 02/36] Add initial implementation of from_turing --- src/ArviZ.jl | 6 ++++ src/turing.jl | 83 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 89 insertions(+) create mode 100644 src/turing.jl diff --git a/src/ArviZ.jl b/src/ArviZ.jl index eb6926ff3..2ce2d1084 100644 --- a/src/ArviZ.jl +++ b/src/ArviZ.jl @@ -2,6 +2,7 @@ __precompile__() module ArviZ using Base: @__doc__ +using Random using Requires using REPL using NamedTupleTools @@ -76,6 +77,7 @@ export InferenceData, from_dict, from_cmdstan, from_mcmcchains, + from_turing, concat, concat! @@ -109,6 +111,10 @@ function __init__() import .MCMCChains: Chains, sections include("mcmcchains.jl") end + @require Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" begin + import .Turing: Turing + include("turing.jl") + end return nothing end diff --git a/src/turing.jl b/src/turing.jl new file mode 100644 index 000000000..d253ae85f --- /dev/null +++ b/src/turing.jl @@ -0,0 +1,83 @@ +function from_turing( + chns=nothing; + model=nothing, + rng=Random.default_rng(), + nchains=ndraws = chns isa Turing.MCMCChains.Chains ? last(size(chns)) : 1, + ndraws=chns isa Turing.MCMCChains.Chains ? first(size(chns)) : 1_000, + library=Turing, + observed_data=nothing, + constant_data=nothing, + posterior_predictive=nothing, + prior=nothing, + prior_predictive=nothing, + log_likelihood=nothing, + kwargs..., +) + groups = Dict{Symbol,Any}( + :observed_data => observed_data, + :constant_data => constant_data, + :posterior_predictive => posterior_predictive, + :prior => prior, + :prior_predictive => prior_predictive, + :log_likelihood => log_likelihood, + ) + model === nothing && return from_mcmcchains(chns; library=library, groups..., kwargs...) + if groups[:prior] === nothing + groups[:prior] = reduce( + Turing.chainscat, + map( + _ -> Turing.sample(rng, model, Turing.Prior(), ndraws; progress=false), + 1:nchains, + ), + ) + end + + groups[:observed_data] === nothing && + return from_mcmcchains(chns; library=library, groups..., kwargs...) + + observed_data = groups[:observed_data] + data_var_names = Set( + observed_data isa Dict ? Symbol.(keys(observed_data)) : propertynames(observed_data) + ) + + if groups[:constant_data] === nothing + groups[:constant_data] = NamedTuple( + filter(p -> first(p) ∉ data_var_names, pairs(model.args)) + ) + end + + # Instantiate the predictive model + args_pred = NamedTuple( + k => k in data_var_names ? similar(v, Missing) : v for (k, v) in pairs(model.args) + ) + model_predict = Turing.DynamicPPL.Model(model.name, model.f, args_pred, model.defaults) + + # and then sample! + if groups[:prior_predictive] === nothing && groups[:prior] isa Turing.MCMCChains.Chains + groups[:prior_predictive] = Turing.predict(rng, model_predict, groups[:prior]) + end + + if chns isa Turing.MCMCChains.Chains + if groups[:posterior_predictive] === nothing && chns isa Turing.MCMCChains.Chains + groups[:posterior_predictive] = Turing.predict(rng, model_predict, chns) + end + + if groups[:log_likelihood] === nothing && + groups[:posterior_predictive] isa MCMCChains.Chains + loglikelihoods = Turing.pointwise_loglikelihoods( + model, Turing.MCMCChains.get_sections(chns, :parameters) + ) + + # Bundle loglikelihoods into a `Chains` object so we can reuse our own variable + # name parsing + pred_names = string.(keys(groups[:posterior_predictive])) + loglikelihoods_vals = getindex.(Ref(loglikelihoods), pred_names) + loglikelihoods_arr = permutedims(cat(loglikelihoods_vals...; dims=3), (1, 3, 2)) + groups[:log_likelihood] = Turing.MCMCChains.Chains( + loglikelihoods_arr, pred_names + ) + end + end + + return from_mcmcchains(chns; library=Turing, groups..., kwargs...) +end From 55cbc78777872a5db38514c8c470b2121526828c Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 19 May 2021 20:57:17 +0200 Subject: [PATCH 03/36] Handle non-array eltype constraints --- src/utils.jl | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index d907d5c13..486541970 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -213,12 +213,17 @@ _asstringkeydict(x::Dict{String}) = x enforce_stat_eltypes(stats) = convert_to_eltypes(stats, sample_stats_eltypes) +_convert_to_eltype(v::AbstractArray, T) = convert(Array{T}, v) +_convert_to_eltype(v, T) = convert(T, v) function convert_to_eltypes(data::Dict, data_eltypes) - return Dict(k => convert(Array{get(data_eltypes, k, eltype(v))}, v) for (k, v) in data) + return Dict( + k => _convert_to_eltype(v, get(data_eltypes, k, eltype(v))) for (k, v) in data + ) end function convert_to_eltypes(data::NamedTuple, data_eltypes) return NamedTuple( - k => convert(Array{get(data_eltypes, k, eltype(v))}, v) for (k, v) in pairs(data) + k => _convert_to_eltype(v, get(data_eltypes, k, eltype(v))) for + (k, v) in pairs(data) ) end From d25e5927693bf9ed21114fbc841548eed08e3047 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 20 May 2021 13:11:50 +0200 Subject: [PATCH 04/36] Apply suggestions from code review Co-authored-by: Tor Erlend Fjelde --- src/turing.jl | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/src/turing.jl b/src/turing.jl index d253ae85f..16f7bfc19 100644 --- a/src/turing.jl +++ b/src/turing.jl @@ -23,13 +23,7 @@ function from_turing( ) model === nothing && return from_mcmcchains(chns; library=library, groups..., kwargs...) if groups[:prior] === nothing - groups[:prior] = reduce( - Turing.chainscat, - map( - _ -> Turing.sample(rng, model, Turing.Prior(), ndraws; progress=false), - 1:nchains, - ), - ) + groups[:prior] = Turing.sample(rng, model, Turing.Prior(), Turing.MCMCThreads(), ndraws, nchains; progress=false) end groups[:observed_data] === nothing && @@ -47,10 +41,7 @@ function from_turing( end # Instantiate the predictive model - args_pred = NamedTuple( - k => k in data_var_names ? similar(v, Missing) : v for (k, v) in pairs(model.args) - ) - model_predict = Turing.DynamicPPL.Model(model.name, model.f, args_pred, model.defaults) +model_predict = Turing.DynamicPPL.Model{data_var_names}(model.name, model.f, args_pred, model.defaults) # and then sample! if groups[:prior_predictive] === nothing && groups[:prior] isa Turing.MCMCChains.Chains From 31f902b6b6147fdf2a4c3ba3d7522c033a128ae6 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 20 May 2021 13:29:07 +0200 Subject: [PATCH 05/36] Repair predictive model code --- src/turing.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/turing.jl b/src/turing.jl index 16f7bfc19..55c3a78c8 100644 --- a/src/turing.jl +++ b/src/turing.jl @@ -30,18 +30,21 @@ function from_turing( return from_mcmcchains(chns; library=library, groups..., kwargs...) observed_data = groups[:observed_data] - data_var_names = Set( + observed_data_keys = Set( observed_data isa Dict ? Symbol.(keys(observed_data)) : propertynames(observed_data) ) if groups[:constant_data] === nothing groups[:constant_data] = NamedTuple( - filter(p -> first(p) ∉ data_var_names, pairs(model.args)) + filter(p -> first(p) ∉ observed_data_keys, pairs(model.args)) ) end # Instantiate the predictive model -model_predict = Turing.DynamicPPL.Model{data_var_names}(model.name, model.f, args_pred, model.defaults) + data_var_names = filter(in(observed_data_keys), keys(model.args)) + model_predict = Turing.DynamicPPL.Model{data_var_names}( + model.name, model.f, model.args, model.defaults + ) # and then sample! if groups[:prior_predictive] === nothing && groups[:prior] isa Turing.MCMCChains.Chains From e1da410528ed0e15164509106c851b83a3b4bf9a Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 20 May 2021 13:29:19 +0200 Subject: [PATCH 06/36] Run formatter --- src/turing.jl | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/turing.jl b/src/turing.jl index 55c3a78c8..e2e139ff8 100644 --- a/src/turing.jl +++ b/src/turing.jl @@ -23,11 +23,20 @@ function from_turing( ) model === nothing && return from_mcmcchains(chns; library=library, groups..., kwargs...) if groups[:prior] === nothing - groups[:prior] = Turing.sample(rng, model, Turing.Prior(), Turing.MCMCThreads(), ndraws, nchains; progress=false) + groups[:prior] = Turing.sample( + rng, + model, + Turing.Prior(), + Turing.MCMCThreads(), + ndraws, + nchains; + progress=false, + ) end - groups[:observed_data] === nothing && + if groups[:observed_data] === nothing return from_mcmcchains(chns; library=library, groups..., kwargs...) + end observed_data = groups[:observed_data] observed_data_keys = Set( From 454c23e5d34b6394f39bbbe4d07db3adfd3f3996 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 20 May 2021 13:29:34 +0200 Subject: [PATCH 07/36] Constrain type of model --- src/turing.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/turing.jl b/src/turing.jl index e2e139ff8..20e3037eb 100644 --- a/src/turing.jl +++ b/src/turing.jl @@ -1,6 +1,6 @@ function from_turing( chns=nothing; - model=nothing, + model::Union{Nothing,Turing.DynamicPPL.Model}=nothing, rng=Random.default_rng(), nchains=ndraws = chns isa Turing.MCMCChains.Chains ? last(size(chns)) : 1, ndraws=chns isa Turing.MCMCChains.Chains ? first(size(chns)) : 1_000, From 98511f4abf759aa666ab4c9ae4cccdc2d9307c8a Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 20 May 2021 14:43:06 +0200 Subject: [PATCH 08/36] Add model name to attributes --- src/turing.jl | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/turing.jl b/src/turing.jl index 20e3037eb..ff6675556 100644 --- a/src/turing.jl +++ b/src/turing.jl @@ -82,5 +82,13 @@ function from_turing( end end - return from_mcmcchains(chns; library=Turing, groups..., kwargs...) + idata = from_mcmcchains(chns; library=library, groups..., kwargs...) + + # add model name to generated InferenceData groups + for name in groupnames(idata) + name in (:observed_data,) && continue + ds = getproperty(idata, name) + setattribute!(ds, :model_name, nameof(model)) + end + return idata end From 191cfd7736c544ce407c526d6052bbffd2f554a9 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 20 May 2021 14:43:45 +0200 Subject: [PATCH 09/36] Support specifying groups to not be generated --- src/turing.jl | 54 ++++++++++++++++++++++++++++++++------------------- 1 file changed, 34 insertions(+), 20 deletions(-) diff --git a/src/turing.jl b/src/turing.jl index ff6675556..599225c9b 100644 --- a/src/turing.jl +++ b/src/turing.jl @@ -5,12 +5,12 @@ function from_turing( nchains=ndraws = chns isa Turing.MCMCChains.Chains ? last(size(chns)) : 1, ndraws=chns isa Turing.MCMCChains.Chains ? first(size(chns)) : 1_000, library=Turing, - observed_data=nothing, - constant_data=nothing, - posterior_predictive=nothing, - prior=nothing, - prior_predictive=nothing, - log_likelihood=nothing, + observed_data=true, + constant_data=true, + posterior_predictive=true, + prior=true, + prior_predictive=true, + log_likelihood=true, kwargs..., ) groups = Dict{Symbol,Any}( @@ -21,8 +21,15 @@ function from_turing( :prior_predictive => prior_predictive, :log_likelihood => log_likelihood, ) + groups_to_generate = Set(k for (k, v) in groups if v === true) + for (name, value) in groups + if value isa Bool + groups[name] = nothing + end + end + model === nothing && return from_mcmcchains(chns; library=library, groups..., kwargs...) - if groups[:prior] === nothing + if :prior in groups_to_generate groups[:prior] = Turing.sample( rng, model, @@ -43,7 +50,7 @@ function from_turing( observed_data isa Dict ? Symbol.(keys(observed_data)) : propertynames(observed_data) ) - if groups[:constant_data] === nothing + if :constant_data in groups_to_generate groups[:constant_data] = NamedTuple( filter(p -> first(p) ∉ observed_data_keys, pairs(model.args)) ) @@ -56,29 +63,36 @@ function from_turing( ) # and then sample! - if groups[:prior_predictive] === nothing && groups[:prior] isa Turing.MCMCChains.Chains - groups[:prior_predictive] = Turing.predict(rng, model_predict, groups[:prior]) + if :prior_predictive in groups_to_generate + if groups[:prior] isa Turing.MCMCChains.Chains + groups[:prior_predictive] = Turing.predict(rng, model_predict, groups[:prior]) + elseif groups[:prior] !== nothing + @warn "Could not generate group :prior_predictive because group :prior was not an MCMCChains.Chains." + end end - if chns isa Turing.MCMCChains.Chains - if groups[:posterior_predictive] === nothing && chns isa Turing.MCMCChains.Chains + if :posterior_predictive in groups_to_generate + if chns isa Turing.MCMCChains.Chains groups[:posterior_predictive] = Turing.predict(rng, model_predict, chns) + elseif chns !== nothing + @warn "Could not generate group :posterior_predictive because group :posterior was not an MCMCChains.Chains." end + end - if groups[:log_likelihood] === nothing && - groups[:posterior_predictive] isa MCMCChains.Chains - loglikelihoods = Turing.pointwise_loglikelihoods( - model, Turing.MCMCChains.get_sections(chns, :parameters) - ) - + if :log_likelihood in groups_to_generate + if chns isa Turing.MCMCChains.Chains + chains_only_params = Turing.MCMCChains.get_sections(chns, :parameters) + loglikelihoods = Turing.pointwise_loglikelihoods(model, chains_only_params) + pred_names = sort(collect(keys(loglikelihoods)); by=split_locname) + loglikelihoods_vals = getindex.(Ref(loglikelihoods), pred_names) # Bundle loglikelihoods into a `Chains` object so we can reuse our own variable # name parsing - pred_names = string.(keys(groups[:posterior_predictive])) - loglikelihoods_vals = getindex.(Ref(loglikelihoods), pred_names) loglikelihoods_arr = permutedims(cat(loglikelihoods_vals...; dims=3), (1, 3, 2)) groups[:log_likelihood] = Turing.MCMCChains.Chains( loglikelihoods_arr, pred_names ) + elseif chns !== nothing + @warn "Could not generate log_likelihood because posterior must be an MCMCChains.Chains." end end From 2c6cc324c22cfc31766b302f7f3dce614dfe2ef7 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 20 May 2021 14:44:38 +0200 Subject: [PATCH 10/36] Constrain type of rng --- src/turing.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/turing.jl b/src/turing.jl index 599225c9b..eeaa199c5 100644 --- a/src/turing.jl +++ b/src/turing.jl @@ -1,7 +1,7 @@ function from_turing( chns=nothing; model::Union{Nothing,Turing.DynamicPPL.Model}=nothing, - rng=Random.default_rng(), + rng::AbstractRNG=Random.default_rng(), nchains=ndraws = chns isa Turing.MCMCChains.Chains ? last(size(chns)) : 1, ndraws=chns isa Turing.MCMCChains.Chains ? first(size(chns)) : 1_000, library=Turing, From 134a6a12c6e2aca0d59767b2148327224ce19a43 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 20 May 2021 15:07:58 +0200 Subject: [PATCH 11/36] Add docstring --- src/turing.jl | 60 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/src/turing.jl b/src/turing.jl index eeaa199c5..0e618cc69 100644 --- a/src/turing.jl +++ b/src/turing.jl @@ -1,3 +1,63 @@ +""" + from_turing([posterior::Chains]; kwargs...) -> InferenceData + +Convert data from Turing into an [`InferenceData`](@ref). + +This permits passing a Turing `Model` and a random number generator to +`model` and `rng` keywords to automatically generate groups. By default, +if `posterior`, `observed_data`, and `model` are provided, then the +`prior`, `prior_predictive`, `posterior_predictive`, and `log_likelihood` +groups are automatically generated. To avoid generating a group, provide +group data or set it to `false`. + +# Arguments + +- `posterior::Chains`: Draws from the posterior + +# Keywords + +- `model::Turing.DynamicPPL.Model`: A Turing model conditioned on observed and +constant data. `constant_data` must be provided for the model to be used. +- `rng::AbstractRNG=Random.default_rng()`: a random number generator used for +sampling from the prior, posterior predictive and prior predictive +distributions. +- `nchains::Int`: Number of chains for prior samples, defaulting to the number +of chains in the posterior, if provided, else 1. +- `ndraws::Int`: Number of draws per chain for prior samples, defaulting to the +number of draws per chain in the posterior, if provided, else 1,000. +- `kwargs`: For remaining keywords, see [`from_mcmcchains`](@ref). + +# Examples + +```jldoctest +julia> using Turing, Random + +julia> rng = Random.seed!(42); + +julia> @model function demo(xs, y) + s ~ InverseGamma(2, 3) + m ~ Normal(0, √s) + for i in eachindex(xs) + xs[i] ~ Normal(m, √s) + end + y ~ Normal(m, √s) + end; + +julia> observed_data = (xs=[0.87, 0.08, 0.53], y=-0.85); + +julia> model = demo(observed_data...); + +julia> chn = sample(rng, model, NUTS(), 1_000; progress=false); + +julia> from_turing(chn; model=model, rng=rng, observed_data=observed_data, prior=false) +InferenceData with groups: + > posterior + > posterior_predictive + > log_likelihood + > sample_stats + > observed_data +``` +""" function from_turing( chns=nothing; model::Union{Nothing,Turing.DynamicPPL.Model}=nothing, From 9e3edb7a0b9e473e657cb5f017bfc55057d810a9 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 20 May 2021 15:10:17 +0200 Subject: [PATCH 12/36] Document from_turing --- docs/make.jl | 2 +- docs/src/api.md | 1 + src/mcmcchains.jl | 2 ++ 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/make.jl b/docs/make.jl index 5a01f02cb..4189e6cbf 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,5 +1,5 @@ using Documenter, ArviZ -using MCMCChains: MCMCChains # make `from_mcmcchains` available for API docs +using Turing # make `from_mcmcchains` and `from_turing` available for API docs makedocs(; modules=[ArviZ], diff --git a/docs/src/api.md b/docs/src/api.md index e882101f7..42766a299 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -71,6 +71,7 @@ | [`from_namedtuple`](@ref) | Convert `NamedTuple` data into an `InferenceData`. | | [`from_dict`](@ref) | Convert `Dict` data into an `InferenceData`. | | [`from_cmdstan`](@ref) | Convert CmdStan data into an `InferenceData`. | +| [`from_turing`](@ref) | Convert data from Turing into an `InferenceData`. | | [`from_mcmcchains`](@ref) | Convert `MCMCChains` data into an `InferenceData`. | | [`concat`](@ref) | Concatenate `InferenceData` objects. | | [`concat!`](@ref) | Concatenate `InferenceData` objects in-place. | diff --git a/src/mcmcchains.jl b/src/mcmcchains.jl index 049867ae0..f736937d2 100644 --- a/src/mcmcchains.jl +++ b/src/mcmcchains.jl @@ -124,6 +124,8 @@ Convert data in an `MCMCChains.Chains` format into an [`InferenceData`](@ref). Any keyword argument below without an an explicitly annotated type above is allowed, so long as it can be passed to [`convert_to_inference_data`](@ref). +For chains data from Turing, see [`from_turing`](@ref) for more options. + # Arguments - `posterior::Chains`: Draws from the posterior From 34084768a160193cd2fab760ebe1560150e5f6b6 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 20 May 2021 15:13:01 +0200 Subject: [PATCH 13/36] Also generate constant_data --- src/turing.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/turing.jl b/src/turing.jl index 0e618cc69..2f8fc579b 100644 --- a/src/turing.jl +++ b/src/turing.jl @@ -34,10 +34,10 @@ julia> using Turing, Random julia> rng = Random.seed!(42); -julia> @model function demo(xs, y) +julia> @model function demo(xs, y, n=length(xs)) s ~ InverseGamma(2, 3) m ~ Normal(0, √s) - for i in eachindex(xs) + for i in 1:n xs[i] ~ Normal(m, √s) end y ~ Normal(m, √s) @@ -56,6 +56,7 @@ InferenceData with groups: > log_likelihood > sample_stats > observed_data + > constant_data ``` """ function from_turing( From 180e56083b6b5985926bf9d1ef2bd47ae10db304 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 20 May 2021 16:46:46 +0200 Subject: [PATCH 14/36] Make code more modular --- src/turing.jl | 73 ++++++++++++++++++++++++++++++--------------------- 1 file changed, 43 insertions(+), 30 deletions(-) diff --git a/src/turing.jl b/src/turing.jl index 2f8fc579b..98bdbaf4d 100644 --- a/src/turing.jl +++ b/src/turing.jl @@ -1,4 +1,4 @@ -""" +@doc doc""" from_turing([posterior::Chains]; kwargs...) -> InferenceData Convert data from Turing into an [`InferenceData`](@ref). @@ -91,15 +91,7 @@ function from_turing( model === nothing && return from_mcmcchains(chns; library=library, groups..., kwargs...) if :prior in groups_to_generate - groups[:prior] = Turing.sample( - rng, - model, - Turing.Prior(), - Turing.MCMCThreads(), - ndraws, - nchains; - progress=false, - ) + groups[:prior] = _sample_prior(rng, model, nchains, ndraws) end if groups[:observed_data] === nothing @@ -112,21 +104,16 @@ function from_turing( ) if :constant_data in groups_to_generate - groups[:constant_data] = NamedTuple( + groups[:constant_data] = Dict( filter(p -> first(p) ∉ observed_data_keys, pairs(model.args)) ) end - # Instantiate the predictive model - data_var_names = filter(in(observed_data_keys), keys(model.args)) - model_predict = Turing.DynamicPPL.Model{data_var_names}( - model.name, model.f, model.args, model.defaults - ) - - # and then sample! if :prior_predictive in groups_to_generate if groups[:prior] isa Turing.MCMCChains.Chains - groups[:prior_predictive] = Turing.predict(rng, model_predict, groups[:prior]) + groups[:prior_predictive] = _sample_predictive( + rng, model, groups[:prior], observed_data_keys + ) elseif groups[:prior] !== nothing @warn "Could not generate group :prior_predictive because group :prior was not an MCMCChains.Chains." end @@ -134,7 +121,9 @@ function from_turing( if :posterior_predictive in groups_to_generate if chns isa Turing.MCMCChains.Chains - groups[:posterior_predictive] = Turing.predict(rng, model_predict, chns) + groups[:posterior_predictive] = _sample_predictive( + rng, model, chns, observed_data_keys + ) elseif chns !== nothing @warn "Could not generate group :posterior_predictive because group :posterior was not an MCMCChains.Chains." end @@ -142,16 +131,7 @@ function from_turing( if :log_likelihood in groups_to_generate if chns isa Turing.MCMCChains.Chains - chains_only_params = Turing.MCMCChains.get_sections(chns, :parameters) - loglikelihoods = Turing.pointwise_loglikelihoods(model, chains_only_params) - pred_names = sort(collect(keys(loglikelihoods)); by=split_locname) - loglikelihoods_vals = getindex.(Ref(loglikelihoods), pred_names) - # Bundle loglikelihoods into a `Chains` object so we can reuse our own variable - # name parsing - loglikelihoods_arr = permutedims(cat(loglikelihoods_vals...; dims=3), (1, 3, 2)) - groups[:log_likelihood] = Turing.MCMCChains.Chains( - loglikelihoods_arr, pred_names - ) + groups[:log_likelihood] = _compute_log_likelihood(model, chns) elseif chns !== nothing @warn "Could not generate log_likelihood because posterior must be an MCMCChains.Chains." end @@ -167,3 +147,36 @@ function from_turing( end return idata end + +function _sample_prior(rng::AbstractRNG, model::Turing.DynamicPPL.Model, nchains, ndraws) + return Turing.sample( + rng, model, Turing.Prior(), Turing.MCMCThreads(), ndraws, nchains; progress=false + ) +end + +function _build_predictive_model(model::Turing.DynamicPPL.Model, data_keys) + var_names = filter(in(data_keys), keys(model.args)) + return Turing.DynamicPPL.Model{var_names}( + model.name, model.f, model.args, model.defaults + ) +end + +function _sample_predictive( + rng::AbstractRNG, model::Turing.DynamicPPL.Model, chns, data_keys +) + model_predict = _build_predictive_model(model, data_keys) + return Turing.predict(rng, model_predict, chns) +end + +function _compute_log_likelihood( + model::Turing.DynamicPPL.Model, chns::Turing.MCMCChains.Chains +) + chains_only_params = Turing.MCMCChains.get_sections(chns, :parameters) + loglikelihoods = Turing.pointwise_loglikelihoods(model, chains_only_params) + pred_names = sort(collect(keys(loglikelihoods)); by=split_locname) + loglikelihoods_vals = getindex.(Ref(loglikelihoods), pred_names) + # Bundle loglikelihoods into a `Chains` object so we can reuse our own variable + # name parsing + loglikelihoods_arr = permutedims(cat(loglikelihoods_vals...; dims=3), (1, 3, 2)) + return Turing.MCMCChains.Chains(loglikelihoods_arr, pred_names) +end From 85ddd695c8313b6168b052decafc36f69c63d6a9 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 20 May 2021 16:47:01 +0200 Subject: [PATCH 15/36] Add Turing tests --- test/runtests.jl | 1 + test/test_turing.jl | 74 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+) create mode 100644 test/test_turing.jl diff --git a/test/runtests.jl b/test/runtests.jl index 20ffc06bf..18b64ad76 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,4 +13,5 @@ using Test include("test_plots.jl") include("test_namedtuple.jl") include("test_mcmcchains.jl") + include("test_turing.jl") end diff --git a/test/test_turing.jl b/test/test_turing.jl new file mode 100644 index 000000000..2dbe60c43 --- /dev/null +++ b/test/test_turing.jl @@ -0,0 +1,74 @@ +using Turing +using ArviZ +using ArviZ: groupnames +using Test +using Random + +@testset "from_turing" begin + nchains = 4 + ndraws = 10 + Turing.@model function demo(xs, y, n=length(xs)) + s ~ InverseGamma(2, 3) + m ~ Normal(0, √s) + for i in 1:n + xs[i] ~ Normal(m, √s) + end + return y ~ Normal(m, √s) + end + xs = randn(5) + y = randn() + observed_data = (xs=xs, y=y) + model = demo(observed_data...) + chn = Turing.sample( + model, Turing.MH(), Turing.MCMCThreads(), ndraws, nchains; progress=false + ) + @test size(chn) == (10, 3, 4) + + idata1 = from_turing(chn) + @test sort(groupnames(idata1)) == [:posterior, :sample_stats] + @test idata1.posterior.inference_library == "Turing" + + idata2 = from_turing(; model=model) + @test sort(groupnames(idata2)) == [:prior, :sample_stats_prior] + @test length(idata2.prior.chain.values) == 1 + @test length(idata2.prior.draw.values) == 1_000 + @test idata1.posterior.inference_library == "Turing" + + idata3 = from_turing(chn; model=model) + @test sort(groupnames(idata3)) == + sort([:posterior, :sample_stats, :prior, :sample_stats_prior]) + @test length(idata3.prior.chain.values) == nchains + @test length(idata3.prior.draw.values) == ndraws + + idata4 = from_turing(chn; model=model, prior=false) + @test sort(groupnames(idata4)) == [:posterior, :sample_stats] + + idata5 = from_turing( + chn; model=model, observed_data=observed_data, nchains=3, ndraws=100 + ) + @test sort(groupnames(idata5)) == sort([ + :posterior, + :posterior_predictive, + :log_likelihood, + :sample_stats, + :prior, + :prior_predictive, + :sample_stats_prior, + :observed_data, + :constant_data, + ]) + @test length(idata5.prior.chain.values) == 3 + @test length(idata5.prior.draw.values) == 100 + + rng1 = Random.MersenneTwister(42) + idata6 = from_turing(chn; model=model, observed_data=observed_data, rng=rng1) + rng2 = Random.MersenneTwister(42) + idata7 = from_turing(chn; model=model, observed_data=observed_data, rng=rng2) + @testset for name in groupnames(idata6) + group1 = getproperty(idata6, name) + group2 = getproperty(idata7, name) + @testset for var_name in group1.variables.keys() + @test group1[var_name].values == group2[var_name].values + end + end +end From 4f50d7a21096d35a9c4868081351fd7f06cb2bcc Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 20 May 2021 18:14:15 +0200 Subject: [PATCH 16/36] Force library to be Turing --- src/turing.jl | 8 ++++---- test/test_turing.jl | 1 + 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/turing.jl b/src/turing.jl index 98bdbaf4d..b9e85e596 100644 --- a/src/turing.jl +++ b/src/turing.jl @@ -65,7 +65,6 @@ function from_turing( rng::AbstractRNG=Random.default_rng(), nchains=ndraws = chns isa Turing.MCMCChains.Chains ? last(size(chns)) : 1, ndraws=chns isa Turing.MCMCChains.Chains ? first(size(chns)) : 1_000, - library=Turing, observed_data=true, constant_data=true, posterior_predictive=true, @@ -89,13 +88,13 @@ function from_turing( end end - model === nothing && return from_mcmcchains(chns; library=library, groups..., kwargs...) + model === nothing && return from_mcmcchains(chns; groups..., library=Turing, kwargs...) if :prior in groups_to_generate groups[:prior] = _sample_prior(rng, model, nchains, ndraws) end if groups[:observed_data] === nothing - return from_mcmcchains(chns; library=library, groups..., kwargs...) + return from_mcmcchains(chns; groups..., library=Turing, kwargs...) end observed_data = groups[:observed_data] @@ -137,13 +136,14 @@ function from_turing( end end - idata = from_mcmcchains(chns; library=library, groups..., kwargs...) + idata = from_mcmcchains(chns; groups..., kwargs...) # add model name to generated InferenceData groups for name in groupnames(idata) name in (:observed_data,) && continue ds = getproperty(idata, name) setattribute!(ds, :model_name, nameof(model)) + setattribute!(ds, :inference_library, :Turing) end return idata end diff --git a/test/test_turing.jl b/test/test_turing.jl index 2dbe60c43..7bcc97d0c 100644 --- a/test/test_turing.jl +++ b/test/test_turing.jl @@ -46,6 +46,7 @@ using Random idata5 = from_turing( chn; model=model, observed_data=observed_data, nchains=3, ndraws=100 ) + @test idata5.posterior.inference_library == "Turing" @test sort(groupnames(idata5)) == sort([ :posterior, :posterior_predictive, From 1b58b6b5b75622efb7df7e543a72eb6934666cc4 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 20 May 2021 21:39:25 +0200 Subject: [PATCH 17/36] Overload setattribute! for InferenceData --- src/data.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/data.jl b/src/data.jl index 058d38a1c..0efb513d3 100644 --- a/src/data.jl +++ b/src/data.jl @@ -165,3 +165,10 @@ function reorder_groups!(data::InferenceData; group_order=SUPPORTED_GROUPS) setproperty!(obj, :_groups, string.([sorted_names; other_names])) return data end + +function setattribute!(data::InferenceData, key, value) + for (_, group) in groups(data) + setattribute!(group, key, value) + end + return data +end From 12ca8742f809281945b9be5ea9e3e6907ffc88af Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 20 May 2021 21:40:03 +0200 Subject: [PATCH 18/36] Add function to add inference library info --- Project.toml | 2 ++ src/ArviZ.jl | 1 + src/data.jl | 10 ++++++++++ 3 files changed, 13 insertions(+) diff --git a/Project.toml b/Project.toml index 77ab4f190..bcdfc43f4 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ Conda = "8f4d0f93-b110-5947-807f-2305c1781a2d" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" NamedTupleTools = "d9ec5142-1e00-5aa0-9d6a-321866360f50" +PkgVersion = "eebad327-c553-4316-9ea0-9fa01ccd7688" PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee" REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" @@ -22,6 +23,7 @@ DataFrames = "0.20, 0.21, 0.22, 1.0" MCMCChains = "0.3.15, 0.4, 1.0, 2.0, 3.0, 4.0" MonteCarloMeasurements = "0.6.4, 0.7, 0.8" NamedTupleTools = "0.11.0, 0.12, 0.13" +PkgVersion = "0.1" PyCall = "1.91.2" PyPlot = "2.8.2" Requires = "0.5.2, 1.0" diff --git a/src/ArviZ.jl b/src/ArviZ.jl index 2ce2d1084..ab828b756 100644 --- a/src/ArviZ.jl +++ b/src/ArviZ.jl @@ -7,6 +7,7 @@ using Requires using REPL using NamedTupleTools using DataFrames +using PkgVersion: PkgVersion using PyCall using Conda diff --git a/src/data.jl b/src/data.jl index 0efb513d3..9ce412850 100644 --- a/src/data.jl +++ b/src/data.jl @@ -172,3 +172,13 @@ function setattribute!(data::InferenceData, key, value) end return data end + +_add_library_attributes!(data, ::Nothing) = ds +function _add_library_attributes!(data, library) + setattribute!(data, :inference_library, string(library)) + if library isa Module + lib_version = string(PkgVersion.Version(library)) + setattribute!(data, :inference_library_version, lib_version) + end + return data +end From a7bb79fb2d1ea5b7889e17cb171df55ac8ca9119 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 20 May 2021 21:42:24 +0200 Subject: [PATCH 19/36] Globally use library utility --- src/mcmcchains.jl | 25 +++++++------------------ src/namedtuple.jl | 14 +++----------- src/turing.jl | 14 ++++++-------- 3 files changed, 16 insertions(+), 37 deletions(-) diff --git a/src/mcmcchains.jl b/src/mcmcchains.jl index f736937d2..50bd7a1c4 100644 --- a/src/mcmcchains.jl +++ b/src/mcmcchains.jl @@ -172,7 +172,6 @@ function from_mcmcchains( kwargs..., ) kwargs = convert(Dict, merge((; dims=Dict()), kwargs)) - library = string(library) rekey_fun = d -> rekey(d, stats_key_map) # Convert chains to dicts @@ -203,23 +202,18 @@ function from_mcmcchains( group_data = popsubdict!(post_dict, group_data) end group_dataset = if group_data isa Chains - convert_to_dataset(group_data; library=library, eltypes=eltypes, kwargs...) + convert_to_dataset(group_data; eltypes=eltypes, kwargs...) else - convert_to_dataset(group_data; library=library, kwargs...) + convert_to_dataset(group_data; kwargs...) end - setattribute!(group_dataset, "inference_library", library) concat!(all_idata, InferenceData(; group => group_dataset)) end - attrs_library = Dict("inference_library" => library) - if posterior === nothing - attrs = attrs_library - else - attrs = merge(attributes_dict(posterior), attrs_library) - end + attrs = posterior === nothing ? Dict() : attributes_dict(posterior) kwargs = convert(Dict, merge((; attrs=attrs, dims=Dict()), kwargs)) post_idata = _from_dict(post_dict; sample_stats=stats_dict, kwargs...) concat!(all_idata, post_idata) + _add_library_attributes!(all_idata, library) return all_idata end function from_mcmcchains( @@ -243,18 +237,13 @@ function from_mcmcchains( posterior_predictive, predictions, log_likelihood; - library=library, eltypes=eltypes, kwargs..., ) if prior !== nothing pre_prior_idata = convert_to_inference_data( - prior; - posterior_predictive=prior_predictive, - library=library, - eltypes=eltypes, - kwargs..., + prior; posterior_predictive=prior_predictive, eltypes=eltypes, kwargs... ) prior_idata = rekey( pre_prior_idata, @@ -274,10 +263,10 @@ function from_mcmcchains( ] group_data === nothing && continue group_data = convert_to_eltypes(group_data, eltypes) - group_dataset = convert_to_constant_dataset(group_data; library=library, kwargs...) + group_dataset = convert_to_constant_dataset(group_data; kwargs...) concat!(all_idata, InferenceData(; group => group_dataset)) end - + _add_library_attributes!(all_idata, library) return all_idata end diff --git a/src/namedtuple.jl b/src/namedtuple.jl index 06b3c7b73..600bad340 100644 --- a/src/namedtuple.jl +++ b/src/namedtuple.jl @@ -140,20 +140,14 @@ function from_namedtuple( isempty(group_data) && continue end group_dataset = convert_to_dataset(group_data; kwargs...) - if library !== nothing - setattribute!(group_dataset, "inference_library", string(library)) - end concat!(all_idata, InferenceData(; group => group_dataset)) end (post_dict === nothing || isempty(post_dict)) && return all_idata group_dataset = convert_to_dataset(post_dict; kwargs...) - if library !== nothing - setattribute!(group_dataset, "inference_library", string(library)) - end concat!(all_idata, InferenceData(; posterior=group_dataset)) - + _add_library_attributes!(all_idata, library) return all_idata end function from_namedtuple( @@ -177,7 +171,6 @@ function from_namedtuple( sample_stats, predictions, log_likelihood; - library=library, kwargs..., ) @@ -186,7 +179,6 @@ function from_namedtuple( prior; posterior_predictive=prior_predictive, sample_stats=sample_stats_prior, - library=library, kwargs..., ) prior_idata = rekey( @@ -207,10 +199,10 @@ function from_namedtuple( ] group_data === nothing && continue group_dict = convert(Dict, group_data) - group_dataset = convert_to_constant_dataset(group_dict; library=library, kwargs...) + group_dataset = convert_to_constant_dataset(group_dict; kwargs...) concat!(all_idata, InferenceData(; group => group_dataset)) end - + _add_library_attributes!(all_idata, library) return all_idata end function from_namedtuple(data::AbstractVector{<:NamedTuple}; kwargs...) diff --git a/src/turing.jl b/src/turing.jl index b9e85e596..a52a996df 100644 --- a/src/turing.jl +++ b/src/turing.jl @@ -73,6 +73,9 @@ function from_turing( log_likelihood=true, kwargs..., ) + kwargs = Dict{Symbol,Any}(kwargs) + kwargs[:library] = Turing + groups = Dict{Symbol,Any}( :observed_data => observed_data, :constant_data => constant_data, @@ -88,13 +91,13 @@ function from_turing( end end - model === nothing && return from_mcmcchains(chns; groups..., library=Turing, kwargs...) + model === nothing && return from_mcmcchains(chns; groups..., kwargs...) if :prior in groups_to_generate groups[:prior] = _sample_prior(rng, model, nchains, ndraws) end if groups[:observed_data] === nothing - return from_mcmcchains(chns; groups..., library=Turing, kwargs...) + return from_mcmcchains(chns; groups..., kwargs...) end observed_data = groups[:observed_data] @@ -139,12 +142,7 @@ function from_turing( idata = from_mcmcchains(chns; groups..., kwargs...) # add model name to generated InferenceData groups - for name in groupnames(idata) - name in (:observed_data,) && continue - ds = getproperty(idata, name) - setattribute!(ds, :model_name, nameof(model)) - setattribute!(ds, :inference_library, :Turing) - end + setattribute!(idata, :model_name, nameof(model)) return idata end From 42c882382f32215bc51375d17f861c3f01ab245a Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 20 May 2021 21:42:40 +0200 Subject: [PATCH 20/36] Test library utility for Turing --- test/test_turing.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test/test_turing.jl b/test/test_turing.jl index 7bcc97d0c..a61ebb774 100644 --- a/test/test_turing.jl +++ b/test/test_turing.jl @@ -27,12 +27,14 @@ using Random idata1 = from_turing(chn) @test sort(groupnames(idata1)) == [:posterior, :sample_stats] @test idata1.posterior.inference_library == "Turing" + VersionNumber(idata1.posterior.inference_library_version) idata2 = from_turing(; model=model) @test sort(groupnames(idata2)) == [:prior, :sample_stats_prior] @test length(idata2.prior.chain.values) == 1 @test length(idata2.prior.draw.values) == 1_000 - @test idata1.posterior.inference_library == "Turing" + @test idata2.prior.inference_library == "Turing" + VersionNumber(idata2.prior.inference_library_version) idata3 = from_turing(chn; model=model) @test sort(groupnames(idata3)) == @@ -47,6 +49,8 @@ using Random chn; model=model, observed_data=observed_data, nchains=3, ndraws=100 ) @test idata5.posterior.inference_library == "Turing" + VersionNumber(idata5.posterior.inference_library_version) + @test sort(groupnames(idata5)) == sort([ :posterior, :posterior_predictive, From 842447f9c60efb3e985457c128b15f8c556c876c Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 20 May 2021 21:43:37 +0200 Subject: [PATCH 21/36] Increment version number --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index bcdfc43f4..ffea08692 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ArviZ" uuid = "131c737c-5715-5e2e-ad31-c244f01c1dc7" authors = ["Seth Axen "] -version = "0.5.4" +version = "0.5.5" [deps] Conda = "8f4d0f93-b110-5947-807f-2305c1781a2d" From c9b45624840fd0ca710fab1cd6f704c92d528a13 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 20 May 2021 21:58:04 +0200 Subject: [PATCH 22/36] Repair Turing example --- src/turing.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/turing.jl b/src/turing.jl index a52a996df..d6a23c12a 100644 --- a/src/turing.jl +++ b/src/turing.jl @@ -30,7 +30,7 @@ number of draws per chain in the posterior, if provided, else 1,000. # Examples ```jldoctest -julia> using Turing, Random +julia> using Turing, Random, ArviZ julia> rng = Random.seed!(42); @@ -47,7 +47,7 @@ julia> observed_data = (xs=[0.87, 0.08, 0.53], y=-0.85); julia> model = demo(observed_data...); -julia> chn = sample(rng, model, NUTS(), 1_000; progress=false); +julia> chn = sample(rng, model, MH(), 1_000; progress=false); julia> from_turing(chn; model=model, rng=rng, observed_data=observed_data, prior=false) InferenceData with groups: From e0d9ae309eb13d7230dc52973d6f6fa659e26af1 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 20 May 2021 22:05:55 +0200 Subject: [PATCH 23/36] Don't import Turing's exports Causes clash with ess --- docs/make.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/make.jl b/docs/make.jl index 4189e6cbf..bcae8bc22 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,5 +1,5 @@ using Documenter, ArviZ -using Turing # make `from_mcmcchains` and `from_turing` available for API docs +using Turing: Turing # make `from_mcmcchains` and `from_turing` available for API docs makedocs(; modules=[ArviZ], From a90df635497eb05ed9e7a06df7d0b7d2e9057a64 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 20 May 2021 22:06:04 +0200 Subject: [PATCH 24/36] Return correct variable name --- src/data.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/data.jl b/src/data.jl index 9ce412850..d407b1a25 100644 --- a/src/data.jl +++ b/src/data.jl @@ -173,7 +173,7 @@ function setattribute!(data::InferenceData, key, value) return data end -_add_library_attributes!(data, ::Nothing) = ds +_add_library_attributes!(data, ::Nothing) = data function _add_library_attributes!(data, library) setattribute!(data, :inference_library, string(library)) if library isa Module From ea273efc38b8856a0fd80c6a39d5903367e92c07 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 20 May 2021 22:32:34 +0200 Subject: [PATCH 25/36] Indent wrapped lines --- src/turing.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/turing.jl b/src/turing.jl index d6a23c12a..4457af78f 100644 --- a/src/turing.jl +++ b/src/turing.jl @@ -17,14 +17,14 @@ group data or set it to `false`. # Keywords - `model::Turing.DynamicPPL.Model`: A Turing model conditioned on observed and -constant data. `constant_data` must be provided for the model to be used. + constant data. `constant_data` must be provided for the model to be used. - `rng::AbstractRNG=Random.default_rng()`: a random number generator used for -sampling from the prior, posterior predictive and prior predictive -distributions. + sampling from the prior, posterior predictive and prior predictive + distributions. - `nchains::Int`: Number of chains for prior samples, defaulting to the number -of chains in the posterior, if provided, else 1. + of chains in the posterior, if provided, else 1. - `ndraws::Int`: Number of draws per chain for prior samples, defaulting to the -number of draws per chain in the posterior, if provided, else 1,000. + number of draws per chain in the posterior, if provided, else 1,000. - `kwargs`: For remaining keywords, see [`from_mcmcchains`](@ref). # Examples From 6e93fa72a8d42ab10ea9553d1922ebb33ed5b82b Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 20 May 2021 22:34:31 +0200 Subject: [PATCH 26/36] Update quickstart.md --- docs/src/quickstart.md | 49 ++++++------------------------------------ 1 file changed, 7 insertions(+), 42 deletions(-) diff --git a/docs/src/quickstart.md b/docs/src/quickstart.md index b72b5493f..6237313b5 100644 --- a/docs/src/quickstart.md +++ b/docs/src/quickstart.md @@ -121,7 +121,7 @@ idata = from_mcmcchains( turing_chns; coords=Dict("school" => schools), dims=Dict("y" => ["school"], "σ" => ["school"], "θ" => ["school"]), - library="Turing", + library=Turing, ) ``` @@ -154,52 +154,17 @@ gcf() ### Additional information in Turing.jl -With a few more steps, we can use Turing to compute additional useful groups to add to the [`InferenceData`](@ref). - -To sample from the prior, one simply calls `sample` but with the `Prior` sampler: - -```@example turing -prior = sample(param_mod, Prior(), nsamples; progress=false) -``` - -To draw from the prior and posterior predictive distributions we can instantiate a "predictive model", i.e. a Turing model but with the observations set to `missing`, and then calling `predict` on the predictive model and the previously drawn samples: - -```@example turing -# Instantiate the predictive model -param_mod_predict = turing_model(similar(y, Missing), σ) -# and then sample! -prior_predictive = predict(param_mod_predict, prior) -posterior_predictive = predict(param_mod_predict, turing_chns) -``` - -And to extract the pointwise log-likelihoods, which is useful if you want to compute metrics such as [`loo`](@ref), - -```@example turing -loglikelihoods = Turing.pointwise_loglikelihoods( - param_mod, MCMCChains.get_sections(turing_chns, :parameters) -) -``` - -This can then be included in the [`from_mcmcchains`](@ref) call from above: +We would like to compute additional useful groups to add to the [`InferenceData`](@ref). +ArviZ includes a Turing-specific converter [`from_turing`](@ref) that, given a model, posterior samples, and data, can add the missing groups: ```@example turing -using LinearAlgebra -# Ensure the ordering of the loglikelihoods matches the ordering of `posterior_predictive` -ynames = string.(keys(posterior_predictive)) -loglikelihoods_vals = getindex.(Ref(loglikelihoods), ynames) -# Reshape into `(nchains, nsamples, size(y)...)` -loglikelihoods_arr = permutedims(cat(loglikelihoods_vals...; dims=3), (2, 1, 3)) - -idata = from_mcmcchains( +idata = from_turing( turing_chns; - posterior_predictive=posterior_predictive, - log_likelihood=Dict("y" => loglikelihoods_arr), - prior=prior, - prior_predictive=prior_predictive, - observed_data=Dict("y" => y), + model=param_mod, + observed_data=Dict("y"=>y,), + rng=rng, coords=Dict("school" => schools), dims=Dict("y" => ["school"], "σ" => ["school"], "θ" => ["school"]), - library="Turing", ) ``` From 92e8a2549d9f37728ec80279de07406d090bb6fa Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 20 May 2021 23:03:12 +0200 Subject: [PATCH 27/36] Run formatter --- docs/src/quickstart.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/quickstart.md b/docs/src/quickstart.md index 6237313b5..3b028d18f 100644 --- a/docs/src/quickstart.md +++ b/docs/src/quickstart.md @@ -161,7 +161,7 @@ ArviZ includes a Turing-specific converter [`from_turing`](@ref) that, given a m idata = from_turing( turing_chns; model=param_mod, - observed_data=Dict("y"=>y,), + observed_data=Dict("y" => y), rng=rng, coords=Dict("school" => schools), dims=Dict("y" => ["school"], "σ" => ["school"], "θ" => ["school"]), From 1581eb0f89063d38c3d8b2f4df8a4d4258c0e3ad Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Fri, 21 May 2021 00:00:01 +0200 Subject: [PATCH 28/36] Deep copy arguments These are apparently overwritten by predict --- src/turing.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/turing.jl b/src/turing.jl index 4457af78f..20aa8cb8b 100644 --- a/src/turing.jl +++ b/src/turing.jl @@ -155,7 +155,7 @@ end function _build_predictive_model(model::Turing.DynamicPPL.Model, data_keys) var_names = filter(in(data_keys), keys(model.args)) return Turing.DynamicPPL.Model{var_names}( - model.name, model.f, model.args, model.defaults + model.name, model.f, deepcopy(model.args), deepcopy(model.defaults) ) end From b863095a01971f48128f08dc18b9d2fe39e1e3d8 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Fri, 21 May 2021 01:07:40 +0200 Subject: [PATCH 29/36] Capture status in string --- docs/src/quickstart.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/quickstart.md b/docs/src/quickstart.md index 3b028d18f..f452d450c 100644 --- a/docs/src/quickstart.md +++ b/docs/src/quickstart.md @@ -409,7 +409,7 @@ gcf() ```@example using Pkg -Pkg.status() +Text(sprint(io -> Pkg.status(io=io))) ``` ```@example From 13ad03a168089a8d067dede4063a87c5bad1deb0 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Fri, 21 May 2021 10:28:37 +0200 Subject: [PATCH 30/36] Better handle adding library info --- src/data.jl | 9 +++++++++ src/dataset.jl | 23 ++++++++++++++--------- 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/src/data.jl b/src/data.jl index d407b1a25..0a0182fd3 100644 --- a/src/data.jl +++ b/src/data.jl @@ -173,12 +173,21 @@ function setattribute!(data::InferenceData, key, value) return data end +function deleteattribute!(data::InferenceData, key) + for (_, group) in groups(data) + deleteattribute!(group, key) + end + return data +end + _add_library_attributes!(data, ::Nothing) = data function _add_library_attributes!(data, library) setattribute!(data, :inference_library, string(library)) if library isa Module lib_version = string(PkgVersion.Version(library)) setattribute!(data, :inference_library_version, lib_version) + else + deleteattribute!(data, :inference_library_version) end return data end diff --git a/src/dataset.jl b/src/dataset.jl index 4a22f29c1..b8cc9d212 100644 --- a/src/dataset.jl +++ b/src/dataset.jl @@ -69,7 +69,14 @@ end attributes(data::Dataset) = getproperty(PyObject(data), :_attrs) function setattribute!(data::Dataset, key, value) - attrs = merge(attributes(data), Dict(key => value)) + attrs = merge(attributes(data), Dict(string(key) => value)) + setproperty!(PyObject(data), :_attrs, attrs) + return attrs +end + +function deleteattribute!(data::Dataset, key) + attrs = attributes(data) + delete!(attrs, string(key)) setproperty!(PyObject(data), :_attrs, attrs) return attrs end @@ -132,11 +139,10 @@ function convert_to_constant_dataset( end default_attrs = base.make_attrs() - if library !== nothing - default_attrs = merge(default_attrs, Dict("inference_library" => string(library))) - end attrs = merge(default_attrs, attrs) - return Dataset(; data_vars=data, coords=coords, attrs=attrs) + ds = Dataset(; data_vars=data, coords=coords, attrs=attrs) + _add_library_attributes!(ds, library) + return ds end @doc doc""" @@ -164,10 +170,9 @@ ArviZ.dict_to_dataset(Dict("x" => randn(4, 100), "y" => randn(4, 100))) dict_to_dataset function dict_to_dataset(data; library=nothing, attrs=Dict(), kwargs...) - if library !== nothing - attrs = merge(attrs, Dict("inference_library" => string(library))) - end - return arviz.dict_to_dataset(data; attrs=attrs, kwargs...) + ds = arviz.dict_to_dataset(data; attrs=attrs, kwargs...) + _add_library_attributes!(ds, library) + return ds end @doc doc""" From a188e2769f4563e52cf2540a338b47cecabb909e Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Fri, 21 May 2021 10:29:59 +0200 Subject: [PATCH 31/36] Run formatter --- docs/src/quickstart.md | 2 +- src/dataset.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/src/quickstart.md b/docs/src/quickstart.md index f452d450c..abf66b38e 100644 --- a/docs/src/quickstart.md +++ b/docs/src/quickstart.md @@ -409,7 +409,7 @@ gcf() ```@example using Pkg -Text(sprint(io -> Pkg.status(io=io))) +Text(sprint(io -> Pkg.status(; io=io))) ``` ```@example diff --git a/src/dataset.jl b/src/dataset.jl index b8cc9d212..13d1bcf47 100644 --- a/src/dataset.jl +++ b/src/dataset.jl @@ -140,7 +140,7 @@ function convert_to_constant_dataset( default_attrs = base.make_attrs() attrs = merge(default_attrs, attrs) - ds = Dataset(; data_vars=data, coords=coords, attrs=attrs) + ds = Dataset(; data_vars=data, coords=coords, attrs=attrs) _add_library_attributes!(ds, library) return ds end From ee474aee9599164fe98e6b9002b1d9f0333d1265 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Fri, 21 May 2021 12:46:59 +0200 Subject: [PATCH 32/36] Add attribute and library tests --- test/test_data.jl | 29 +++++++++++++++++++++++++++++ test/test_dataset.jl | 16 ++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/test/test_data.jl b/test/test_data.jl index 84ba5d652..ed3af32df 100644 --- a/test/test_data.jl +++ b/test/test_data.jl @@ -32,6 +32,17 @@ using MonteCarloMeasurements: Particles @test :posterior in keys(g) end + @testset "attributes" begin + ArviZ.setattribute!(data, "tmp", 3) + for (_, group) in ArviZ.groups(data) + @test Pair("tmp", 3) ∈ ArviZ.attributes(group) + end + ArviZ.deleteattribute!(data, "tmp") + for (_, group) in ArviZ.groups(data) + @test "tmp" ∉ keys(ArviZ.attributes(group)) + end + end + @testset "isempty" begin @test !isempty(data) @test isempty(InferenceData()) @@ -67,6 +78,24 @@ using MonteCarloMeasurements: Particles end end +@testset "add library attributes" begin + idata = load_arviz_data("centered_eight") + ArviZ.deleteattribute!(idata, :inference_library) + ArviZ.setattribute!(idata, :inference_library_version, "3.0.0") + @test ArviZ._add_library_attributes!(idata, nothing) === idata + ArviZ._add_library_attributes!(idata, "MyLib") + for (_, group) in ArviZ.groups(idata) + @test Pair("inference_library", "MyLib") ∈ ArviZ.attributes(group) + @test "inference_library_version" ∉ keys(ArviZ.attributes(group)) + end + ArviZ._add_library_attributes!(idata, ArviZ) + arviz_version = string(ArviZ.PkgVersion.Version(ArviZ)) + for (_, group) in ArviZ.groups(idata) + @test Pair("inference_library", "ArviZ") ∈ ArviZ.attributes(group) + @test Pair("inference_library_version", arviz_version) ∈ ArviZ.attributes(group) + end +end + @testset "+(::InferenceData, ::InferenceData)" begin rng = MersenneTwister(42) idata1 = from_dict(; diff --git a/test/test_dataset.jl b/test/test_dataset.jl index eb5920093..f57d9e4ad 100644 --- a/test/test_dataset.jl +++ b/test/test_dataset.jl @@ -47,6 +47,22 @@ @test convert(ArviZ.Dataset, [1.0, 2.0, 3.0, 4.0]) isa ArviZ.Dataset end + @testset "attributes" begin + attrs = dataset.attrs + attrs2 = ArviZ.attributes(dataset) + @test attrs == attrs2 + ArviZ.setattribute!(dataset, "tmp1", 3) + ArviZ.setattribute!(dataset, :tmp2, 4) + attrs3 = ArviZ.attributes(dataset) + @test Pair("tmp1", 3) ∈ attrs3 + @test Pair("tmp2", 4) ∈ attrs3 + ArviZ.deleteattribute!(dataset, :tmp1) + ArviZ.deleteattribute!(dataset, "tmp2") + attrs4 = ArviZ.attributes(dataset) + @test "tmp1" ∉ keys(attrs4) + @test "tmp2" ∉ keys(attrs4) + end + @testset "show(::ArviZ.Dataset)" begin @testset "$mimetype" for mimetype in ("plain", "html") text = repr(MIME("text/$(mimetype)"), dataset) From 5ee652af8644b450392354d3a738f9db20e035bd Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Fri, 21 May 2021 23:49:20 +0200 Subject: [PATCH 33/36] Extract observed_data from model --- src/turing.jl | 41 ++++++++++++++++++++++++++++++++--------- test/test_turing.jl | 20 ++++++++++---------- 2 files changed, 42 insertions(+), 19 deletions(-) diff --git a/src/turing.jl b/src/turing.jl index 20aa8cb8b..c949ab97b 100644 --- a/src/turing.jl +++ b/src/turing.jl @@ -96,18 +96,21 @@ function from_turing( groups[:prior] = _sample_prior(rng, model, nchains, ndraws) end - if groups[:observed_data] === nothing - return from_mcmcchains(chns; groups..., kwargs...) + if :observed_data ∈ groups_to_generate + groups[:observed_data] = _get_observed_data( + model, chns isa Turing.MCMCChains.Chains ? chns : groups[:prior] + ) end observed_data = groups[:observed_data] - observed_data_keys = Set( - observed_data isa Dict ? Symbol.(keys(observed_data)) : propertynames(observed_data) - ) + if observed_data === nothing + return from_mcmcchains(chns; groups..., kwargs...) + end + observed_data_keys = Set(Symbol.(keys(observed_data))) if :constant_data in groups_to_generate groups[:constant_data] = Dict( - filter(p -> first(p) ∉ observed_data_keys, pairs(model.args)) + filter(∉(observed_data_keys) ∘ first, pairs(model.args)) ) end @@ -117,7 +120,7 @@ function from_turing( rng, model, groups[:prior], observed_data_keys ) elseif groups[:prior] !== nothing - @warn "Could not generate group :prior_predictive because group :prior was not an MCMCChains.Chains." + @warn "Could not generate group :prior_predictive because group :prior is not an MCMCChains.Chains." end end @@ -127,7 +130,7 @@ function from_turing( rng, model, chns, observed_data_keys ) elseif chns !== nothing - @warn "Could not generate group :posterior_predictive because group :posterior was not an MCMCChains.Chains." + @warn "Could not generate group :posterior_predictive because group :posterior is not an MCMCChains.Chains." end end @@ -135,7 +138,7 @@ function from_turing( if chns isa Turing.MCMCChains.Chains groups[:log_likelihood] = _compute_log_likelihood(model, chns) elseif chns !== nothing - @warn "Could not generate log_likelihood because posterior must be an MCMCChains.Chains." + @warn "Could not generate group :log_likelihood because group :posterior is not an MCMCChains.Chains." end end @@ -178,3 +181,23 @@ function _compute_log_likelihood( loglikelihoods_arr = permutedims(cat(loglikelihoods_vals...; dims=3), (1, 3, 2)) return Turing.MCMCChains.Chains(loglikelihoods_arr, pred_names) end + +function _get_observed_data(model::Turing.DynamicPPL.Model, chns) + # use likelihood to find nmes of data variables + if chns isa Turing.MCMCChains.Chains + chns_small = chns[1, :, 1] + else + chns_small = _sample_prior(Random.MersenneTwister(0), model, 1, 1) + end + log_like = _compute_log_likelihood(model, chns_small) + obs_data_keys = keys(Turing.MCMCChains.get_params(log_like)) + # get values of data variables stored in model + args = model.args + args_obs_data_keys = filter(in(keys(args)), obs_data_keys) + if args_obs_data_keys != obs_data_keys + @warn "Failed to extract group :observed_data from model. Expected keys " * + "$(obs_data_keys) but only found keys $(keys(obs_data)) in model arguments." + return nothing + end + return NamedTuple{args_obs_data_keys}(getproperty.(Ref(args), args_obs_data_keys)) +end diff --git a/test/test_turing.jl b/test/test_turing.jl index a61ebb774..6f0b6ffb6 100644 --- a/test/test_turing.jl +++ b/test/test_turing.jl @@ -5,7 +5,7 @@ using Test using Random @testset "from_turing" begin - nchains = 4 + nchains = 2 ndraws = 10 Turing.@model function demo(xs, y, n=length(xs)) s ~ InverseGamma(2, 3) @@ -22,14 +22,14 @@ using Random chn = Turing.sample( model, Turing.MH(), Turing.MCMCThreads(), ndraws, nchains; progress=false ) - @test size(chn) == (10, 3, 4) + @test size(chn) == (ndraws, 3, nchains) idata1 = from_turing(chn) @test sort(groupnames(idata1)) == [:posterior, :sample_stats] @test idata1.posterior.inference_library == "Turing" VersionNumber(idata1.posterior.inference_library_version) - idata2 = from_turing(; model=model) + idata2 = from_turing(; model=model, observed_data=false) @test sort(groupnames(idata2)) == [:prior, :sample_stats_prior] @test length(idata2.prior.chain.values) == 1 @test length(idata2.prior.draw.values) == 1_000 @@ -42,15 +42,12 @@ using Random @test length(idata3.prior.chain.values) == nchains @test length(idata3.prior.draw.values) == ndraws - idata4 = from_turing(chn; model=model, prior=false) + idata4 = from_turing(chn; model=model, prior=false, observed_data=false) @test sort(groupnames(idata4)) == [:posterior, :sample_stats] - idata5 = from_turing( - chn; model=model, observed_data=observed_data, nchains=3, ndraws=100 - ) + idata5 = from_turing(chn; model=model, nchains=3, ndraws=100) @test idata5.posterior.inference_library == "Turing" VersionNumber(idata5.posterior.inference_library_version) - @test sort(groupnames(idata5)) == sort([ :posterior, :posterior_predictive, @@ -64,11 +61,14 @@ using Random ]) @test length(idata5.prior.chain.values) == 3 @test length(idata5.prior.draw.values) == 100 + @test idata5.observed_data.xs.values == xs + @test only(idata5.observed_data.y.values) == y + @test only(idata5.constant_data.n.values) == 5 rng1 = Random.MersenneTwister(42) - idata6 = from_turing(chn; model=model, observed_data=observed_data, rng=rng1) + idata6 = from_turing(chn; model=model, rng=rng1) rng2 = Random.MersenneTwister(42) - idata7 = from_turing(chn; model=model, observed_data=observed_data, rng=rng2) + idata7 = from_turing(chn; model=model, rng=rng2) @testset for name in groupnames(idata6) group1 = getproperty(idata6, name) group2 = getproperty(idata7, name) From 09c37da7be947b0d23f3dcc89b142493eca22892 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Fri, 21 May 2021 23:52:36 +0200 Subject: [PATCH 34/36] Update example --- src/turing.jl | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/src/turing.jl b/src/turing.jl index c949ab97b..c296887d1 100644 --- a/src/turing.jl +++ b/src/turing.jl @@ -5,10 +5,9 @@ Convert data from Turing into an [`InferenceData`](@ref). This permits passing a Turing `Model` and a random number generator to `model` and `rng` keywords to automatically generate groups. By default, -if `posterior`, `observed_data`, and `model` are provided, then the -`prior`, `prior_predictive`, `posterior_predictive`, and `log_likelihood` -groups are automatically generated. To avoid generating a group, provide -group data or set it to `false`. +if `posterior` and `model` are provided, then all remaining groups are +automatically generated. To avoid generating a group, provide group data +or set it to `false`. # Arguments @@ -17,7 +16,7 @@ group data or set it to `false`. # Keywords - `model::Turing.DynamicPPL.Model`: A Turing model conditioned on observed and - constant data. `constant_data` must be provided for the model to be used. + constant data. - `rng::AbstractRNG=Random.default_rng()`: a random number generator used for sampling from the prior, posterior predictive and prior predictive distributions. @@ -43,13 +42,11 @@ julia> @model function demo(xs, y, n=length(xs)) y ~ Normal(m, √s) end; -julia> observed_data = (xs=[0.87, 0.08, 0.53], y=-0.85); +julia> model = demo(randn(3), randn()); -julia> model = demo(observed_data...); +julia> chn = sample(rng, model, MH(), 100; progress=false); -julia> chn = sample(rng, model, MH(), 1_000; progress=false); - -julia> from_turing(chn; model=model, rng=rng, observed_data=observed_data, prior=false) +julia> idata = from_turing(chn; model=model, rng=rng, prior=false) InferenceData with groups: > posterior > posterior_predictive From a05bfd230f1bf7b32ad2d733529f23c751ad5a8a Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 22 May 2021 00:43:58 +0200 Subject: [PATCH 35/36] Update quickstart --- docs/src/quickstart.md | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/src/quickstart.md b/docs/src/quickstart.md index abf66b38e..045d362b7 100644 --- a/docs/src/quickstart.md +++ b/docs/src/quickstart.md @@ -161,7 +161,6 @@ ArviZ includes a Turing-specific converter [`from_turing`](@ref) that, given a m idata = from_turing( turing_chns; model=param_mod, - observed_data=Dict("y" => y), rng=rng, coords=Dict("school" => schools), dims=Dict("y" => ["school"], "σ" => ["school"], "θ" => ["school"]), From bf474a11597432d9dadb5c495386d520122c1eea Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 22 May 2021 12:57:33 +0200 Subject: [PATCH 36/36] Fix test --- test/test_turing.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_turing.jl b/test/test_turing.jl index 6f0b6ffb6..444f84192 100644 --- a/test/test_turing.jl +++ b/test/test_turing.jl @@ -36,7 +36,7 @@ using Random @test idata2.prior.inference_library == "Turing" VersionNumber(idata2.prior.inference_library_version) - idata3 = from_turing(chn; model=model) + idata3 = from_turing(chn; model=model, observed_data=false) @test sort(groupnames(idata3)) == sort([:posterior, :sample_stats, :prior, :sample_stats_prior]) @test length(idata3.prior.chain.values) == nchains