diff --git a/Project.toml b/Project.toml index 5740e60c9..ffea08692 100644 --- a/Project.toml +++ b/Project.toml @@ -1,16 +1,18 @@ name = "ArviZ" uuid = "131c737c-5715-5e2e-ad31-c244f01c1dc7" authors = ["Seth Axen "] -version = "0.5.4" +version = "0.5.5" [deps] Conda = "8f4d0f93-b110-5947-807f-2305c1781a2d" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" NamedTupleTools = "d9ec5142-1e00-5aa0-9d6a-321866360f50" +PkgVersion = "eebad327-c553-4316-9ea0-9fa01ccd7688" PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee" REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" @@ -21,10 +23,12 @@ DataFrames = "0.20, 0.21, 0.22, 1.0" MCMCChains = "0.3.15, 0.4, 1.0, 2.0, 3.0, 4.0" MonteCarloMeasurements = "0.6.4, 0.7, 0.8" NamedTupleTools = "0.11.0, 0.12, 0.13" +PkgVersion = "0.1" PyCall = "1.91.2" PyPlot = "2.8.2" Requires = "0.5.2, 1.0" StatsBase = "0.32, 0.33" +Turing = "0.15" julia = "^1" [extras] @@ -33,6 +37,7 @@ MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [targets] -test = ["CmdStan", "MCMCChains", "MonteCarloMeasurements", "Random", "Test"] +test = ["CmdStan", "MCMCChains", "MonteCarloMeasurements", "Random", "Test", "Turing"] diff --git a/docs/make.jl b/docs/make.jl index 5a01f02cb..bcae8bc22 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,5 +1,5 @@ using Documenter, ArviZ -using MCMCChains: MCMCChains # make `from_mcmcchains` available for API docs +using Turing: Turing # make `from_mcmcchains` and `from_turing` available for API docs makedocs(; modules=[ArviZ], diff --git a/docs/src/api.md b/docs/src/api.md index 
e882101f7..42766a299 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -71,6 +71,7 @@ | [`from_namedtuple`](@ref) | Convert `NamedTuple` data into an `InferenceData`. | | [`from_dict`](@ref) | Convert `Dict` data into an `InferenceData`. | | [`from_cmdstan`](@ref) | Convert CmdStan data into an `InferenceData`. | +| [`from_turing`](@ref) | Convert data from Turing into an `InferenceData`. | | [`from_mcmcchains`](@ref) | Convert `MCMCChains` data into an `InferenceData`. | | [`concat`](@ref) | Concatenate `InferenceData` objects. | | [`concat!`](@ref) | Concatenate `InferenceData` objects in-place. | diff --git a/docs/src/quickstart.md b/docs/src/quickstart.md index b72b5493f..045d362b7 100644 --- a/docs/src/quickstart.md +++ b/docs/src/quickstart.md @@ -121,7 +121,7 @@ idata = from_mcmcchains( turing_chns; coords=Dict("school" => schools), dims=Dict("y" => ["school"], "σ" => ["school"], "θ" => ["school"]), - library="Turing", + library=Turing, ) ``` @@ -154,52 +154,16 @@ gcf() ### Additional information in Turing.jl -With a few more steps, we can use Turing to compute additional useful groups to add to the [`InferenceData`](@ref). - -To sample from the prior, one simply calls `sample` but with the `Prior` sampler: - -```@example turing -prior = sample(param_mod, Prior(), nsamples; progress=false) -``` - -To draw from the prior and posterior predictive distributions we can instantiate a "predictive model", i.e. a Turing model but with the observations set to `missing`, and then calling `predict` on the predictive model and the previously drawn samples: - -```@example turing -# Instantiate the predictive model -param_mod_predict = turing_model(similar(y, Missing), σ) -# and then sample! 
-prior_predictive = predict(param_mod_predict, prior) -posterior_predictive = predict(param_mod_predict, turing_chns) -``` - -And to extract the pointwise log-likelihoods, which is useful if you want to compute metrics such as [`loo`](@ref), - -```@example turing -loglikelihoods = Turing.pointwise_loglikelihoods( - param_mod, MCMCChains.get_sections(turing_chns, :parameters) -) -``` - -This can then be included in the [`from_mcmcchains`](@ref) call from above: +We would like to compute additional useful groups to add to the [`InferenceData`](@ref). +ArviZ includes a Turing-specific converter [`from_turing`](@ref) that, given a model, posterior samples, and data, can add the missing groups: ```@example turing -using LinearAlgebra -# Ensure the ordering of the loglikelihoods matches the ordering of `posterior_predictive` -ynames = string.(keys(posterior_predictive)) -loglikelihoods_vals = getindex.(Ref(loglikelihoods), ynames) -# Reshape into `(nchains, nsamples, size(y)...)` -loglikelihoods_arr = permutedims(cat(loglikelihoods_vals...; dims=3), (2, 1, 3)) - -idata = from_mcmcchains( +idata = from_turing( turing_chns; - posterior_predictive=posterior_predictive, - log_likelihood=Dict("y" => loglikelihoods_arr), - prior=prior, - prior_predictive=prior_predictive, - observed_data=Dict("y" => y), + model=param_mod, + rng=rng, coords=Dict("school" => schools), dims=Dict("y" => ["school"], "σ" => ["school"], "θ" => ["school"]), - library="Turing", ) ``` @@ -444,7 +408,7 @@ gcf() ```@example using Pkg -Pkg.status() +Text(sprint(io -> Pkg.status(; io=io))) ``` ```@example diff --git a/src/ArviZ.jl b/src/ArviZ.jl index eb6926ff3..ab828b756 100644 --- a/src/ArviZ.jl +++ b/src/ArviZ.jl @@ -2,10 +2,12 @@ __precompile__() module ArviZ using Base: @__doc__ +using Random using Requires using REPL using NamedTupleTools using DataFrames +using PkgVersion: PkgVersion using PyCall using Conda @@ -76,6 +78,7 @@ export InferenceData, from_dict, from_cmdstan, from_mcmcchains, + 
from_turing, concat, concat! @@ -109,6 +112,10 @@ function __init__() import .MCMCChains: Chains, sections include("mcmcchains.jl") end + @require Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" begin + import .Turing: Turing + include("turing.jl") + end return nothing end diff --git a/src/data.jl b/src/data.jl index 058d38a1c..0a0182fd3 100644 --- a/src/data.jl +++ b/src/data.jl @@ -165,3 +165,29 @@ function reorder_groups!(data::InferenceData; group_order=SUPPORTED_GROUPS) setproperty!(obj, :_groups, string.([sorted_names; other_names])) return data end + +function setattribute!(data::InferenceData, key, value) + for (_, group) in groups(data) + setattribute!(group, key, value) + end + return data +end + +function deleteattribute!(data::InferenceData, key) + for (_, group) in groups(data) + deleteattribute!(group, key) + end + return data +end + +_add_library_attributes!(data, ::Nothing) = data +function _add_library_attributes!(data, library) + setattribute!(data, :inference_library, string(library)) + if library isa Module + lib_version = string(PkgVersion.Version(library)) + setattribute!(data, :inference_library_version, lib_version) + else + deleteattribute!(data, :inference_library_version) + end + return data +end diff --git a/src/dataset.jl b/src/dataset.jl index 4a22f29c1..13d1bcf47 100644 --- a/src/dataset.jl +++ b/src/dataset.jl @@ -69,7 +69,14 @@ end attributes(data::Dataset) = getproperty(PyObject(data), :_attrs) function setattribute!(data::Dataset, key, value) - attrs = merge(attributes(data), Dict(key => value)) + attrs = merge(attributes(data), Dict(string(key) => value)) + setproperty!(PyObject(data), :_attrs, attrs) + return attrs +end + +function deleteattribute!(data::Dataset, key) + attrs = attributes(data) + delete!(attrs, string(key)) setproperty!(PyObject(data), :_attrs, attrs) return attrs end @@ -132,11 +139,10 @@ function convert_to_constant_dataset( end default_attrs = base.make_attrs() - if library !== nothing - default_attrs = 
merge(default_attrs, Dict("inference_library" => string(library))) - end attrs = merge(default_attrs, attrs) - return Dataset(; data_vars=data, coords=coords, attrs=attrs) + ds = Dataset(; data_vars=data, coords=coords, attrs=attrs) + _add_library_attributes!(ds, library) + return ds end @doc doc""" @@ -164,10 +170,9 @@ ArviZ.dict_to_dataset(Dict("x" => randn(4, 100), "y" => randn(4, 100))) dict_to_dataset function dict_to_dataset(data; library=nothing, attrs=Dict(), kwargs...) - if library !== nothing - attrs = merge(attrs, Dict("inference_library" => string(library))) - end - return arviz.dict_to_dataset(data; attrs=attrs, kwargs...) + ds = arviz.dict_to_dataset(data; attrs=attrs, kwargs...) + _add_library_attributes!(ds, library) + return ds end @doc doc""" diff --git a/src/mcmcchains.jl b/src/mcmcchains.jl index 049867ae0..50bd7a1c4 100644 --- a/src/mcmcchains.jl +++ b/src/mcmcchains.jl @@ -124,6 +124,8 @@ Convert data in an `MCMCChains.Chains` format into an [`InferenceData`](@ref). Any keyword argument below without an an explicitly annotated type above is allowed, so long as it can be passed to [`convert_to_inference_data`](@ref). +For chains data from Turing, see [`from_turing`](@ref) for more options. + # Arguments - `posterior::Chains`: Draws from the posterior @@ -170,7 +172,6 @@ function from_mcmcchains( kwargs..., ) kwargs = convert(Dict, merge((; dims=Dict()), kwargs)) - library = string(library) rekey_fun = d -> rekey(d, stats_key_map) # Convert chains to dicts @@ -201,23 +202,18 @@ function from_mcmcchains( group_data = popsubdict!(post_dict, group_data) end group_dataset = if group_data isa Chains - convert_to_dataset(group_data; library=library, eltypes=eltypes, kwargs...) + convert_to_dataset(group_data; eltypes=eltypes, kwargs...) else - convert_to_dataset(group_data; library=library, kwargs...) + convert_to_dataset(group_data; kwargs...) 
end - setattribute!(group_dataset, "inference_library", library) concat!(all_idata, InferenceData(; group => group_dataset)) end - attrs_library = Dict("inference_library" => library) - if posterior === nothing - attrs = attrs_library - else - attrs = merge(attributes_dict(posterior), attrs_library) - end + attrs = posterior === nothing ? Dict() : attributes_dict(posterior) kwargs = convert(Dict, merge((; attrs=attrs, dims=Dict()), kwargs)) post_idata = _from_dict(post_dict; sample_stats=stats_dict, kwargs...) concat!(all_idata, post_idata) + _add_library_attributes!(all_idata, library) return all_idata end function from_mcmcchains( @@ -241,18 +237,13 @@ function from_mcmcchains( posterior_predictive, predictions, log_likelihood; - library=library, eltypes=eltypes, kwargs..., ) if prior !== nothing pre_prior_idata = convert_to_inference_data( - prior; - posterior_predictive=prior_predictive, - library=library, - eltypes=eltypes, - kwargs..., + prior; posterior_predictive=prior_predictive, eltypes=eltypes, kwargs... ) prior_idata = rekey( pre_prior_idata, @@ -272,10 +263,10 @@ function from_mcmcchains( ] group_data === nothing && continue group_data = convert_to_eltypes(group_data, eltypes) - group_dataset = convert_to_constant_dataset(group_data; library=library, kwargs...) + group_dataset = convert_to_constant_dataset(group_data; kwargs...) concat!(all_idata, InferenceData(; group => group_dataset)) end - + _add_library_attributes!(all_idata, library) return all_idata end diff --git a/src/namedtuple.jl b/src/namedtuple.jl index 06b3c7b73..600bad340 100644 --- a/src/namedtuple.jl +++ b/src/namedtuple.jl @@ -140,20 +140,14 @@ function from_namedtuple( isempty(group_data) && continue end group_dataset = convert_to_dataset(group_data; kwargs...) 
- if library !== nothing - setattribute!(group_dataset, "inference_library", string(library)) - end concat!(all_idata, InferenceData(; group => group_dataset)) end (post_dict === nothing || isempty(post_dict)) && return all_idata group_dataset = convert_to_dataset(post_dict; kwargs...) - if library !== nothing - setattribute!(group_dataset, "inference_library", string(library)) - end concat!(all_idata, InferenceData(; posterior=group_dataset)) - + _add_library_attributes!(all_idata, library) return all_idata end function from_namedtuple( @@ -177,7 +171,6 @@ function from_namedtuple( sample_stats, predictions, log_likelihood; - library=library, kwargs..., ) @@ -186,7 +179,6 @@ function from_namedtuple( prior; posterior_predictive=prior_predictive, sample_stats=sample_stats_prior, - library=library, kwargs..., ) prior_idata = rekey( @@ -207,10 +199,10 @@ function from_namedtuple( ] group_data === nothing && continue group_dict = convert(Dict, group_data) - group_dataset = convert_to_constant_dataset(group_dict; library=library, kwargs...) + group_dataset = convert_to_constant_dataset(group_dict; kwargs...) concat!(all_idata, InferenceData(; group => group_dataset)) end - + _add_library_attributes!(all_idata, library) return all_idata end function from_namedtuple(data::AbstractVector{<:NamedTuple}; kwargs...) diff --git a/src/turing.jl b/src/turing.jl new file mode 100644 index 000000000..c296887d1 --- /dev/null +++ b/src/turing.jl @@ -0,0 +1,200 @@ +@doc doc""" + from_turing([posterior::Chains]; kwargs...) -> InferenceData + +Convert data from Turing into an [`InferenceData`](@ref). + +This permits passing a Turing `Model` and a random number generator to +`model` and `rng` keywords to automatically generate groups. By default, +if `posterior` and `model` are provided, then all remaining groups are +automatically generated. To avoid generating a group, provide group data +or set it to `false`. 
+ +# Arguments + +- `posterior::Chains`: Draws from the posterior + +# Keywords + +- `model::Turing.DynamicPPL.Model`: A Turing model conditioned on observed and + constant data. +- `rng::AbstractRNG=Random.default_rng()`: a random number generator used for + sampling from the prior, posterior predictive and prior predictive + distributions. +- `nchains::Int`: Number of chains for prior samples, defaulting to the number + of chains in the posterior, if provided, else 1. +- `ndraws::Int`: Number of draws per chain for prior samples, defaulting to the + number of draws per chain in the posterior, if provided, else 1,000. +- `kwargs`: For remaining keywords, see [`from_mcmcchains`](@ref). + +# Examples + +```jldoctest +julia> using Turing, Random, ArviZ + +julia> rng = Random.seed!(42); + +julia> @model function demo(xs, y, n=length(xs)) + s ~ InverseGamma(2, 3) + m ~ Normal(0, √s) + for i in 1:n + xs[i] ~ Normal(m, √s) + end + y ~ Normal(m, √s) + end; + +julia> model = demo(randn(3), randn()); + +julia> chn = sample(rng, model, MH(), 100; progress=false); + +julia> idata = from_turing(chn; model=model, rng=rng, prior=false) +InferenceData with groups: + > posterior + > posterior_predictive + > log_likelihood + > sample_stats + > observed_data + > constant_data +``` +""" +function from_turing( + chns=nothing; + model::Union{Nothing,Turing.DynamicPPL.Model}=nothing, + rng::AbstractRNG=Random.default_rng(), + nchains=chns isa Turing.MCMCChains.Chains ? last(size(chns)) : 1, + ndraws=chns isa Turing.MCMCChains.Chains ? 
first(size(chns)) : 1_000, + observed_data=true, + constant_data=true, + posterior_predictive=true, + prior=true, + prior_predictive=true, + log_likelihood=true, + kwargs..., +) + kwargs = Dict{Symbol,Any}(kwargs) + kwargs[:library] = Turing + + groups = Dict{Symbol,Any}( + :observed_data => observed_data, + :constant_data => constant_data, + :posterior_predictive => posterior_predictive, + :prior => prior, + :prior_predictive => prior_predictive, + :log_likelihood => log_likelihood, + ) + groups_to_generate = Set(k for (k, v) in groups if v === true) + for (name, value) in groups + if value isa Bool + groups[name] = nothing + end + end + + model === nothing && return from_mcmcchains(chns; groups..., kwargs...) + if :prior in groups_to_generate + groups[:prior] = _sample_prior(rng, model, nchains, ndraws) + end + + if :observed_data ∈ groups_to_generate + groups[:observed_data] = _get_observed_data( + model, chns isa Turing.MCMCChains.Chains ? chns : groups[:prior] + ) + end + + observed_data = groups[:observed_data] + if observed_data === nothing + return from_mcmcchains(chns; groups..., kwargs...) + end + observed_data_keys = Set(Symbol.(keys(observed_data))) + + if :constant_data in groups_to_generate + groups[:constant_data] = Dict( + filter(∉(observed_data_keys) ∘ first, pairs(model.args)) + ) + end + + if :prior_predictive in groups_to_generate + if groups[:prior] isa Turing.MCMCChains.Chains + groups[:prior_predictive] = _sample_predictive( + rng, model, groups[:prior], observed_data_keys + ) + elseif groups[:prior] !== nothing + @warn "Could not generate group :prior_predictive because group :prior is not an MCMCChains.Chains." + end + end + + if :posterior_predictive in groups_to_generate + if chns isa Turing.MCMCChains.Chains + groups[:posterior_predictive] = _sample_predictive( + rng, model, chns, observed_data_keys + ) + elseif chns !== nothing + @warn "Could not generate group :posterior_predictive because group :posterior is not an MCMCChains.Chains." 
+ end + end + + if :log_likelihood in groups_to_generate + if chns isa Turing.MCMCChains.Chains + groups[:log_likelihood] = _compute_log_likelihood(model, chns) + elseif chns !== nothing + @warn "Could not generate group :log_likelihood because group :posterior is not an MCMCChains.Chains." + end + end + + idata = from_mcmcchains(chns; groups..., kwargs...) + + # add model name to generated InferenceData groups + setattribute!(idata, :model_name, nameof(model)) + return idata +end + +function _sample_prior(rng::AbstractRNG, model::Turing.DynamicPPL.Model, nchains, ndraws) + return Turing.sample( + rng, model, Turing.Prior(), Turing.MCMCThreads(), ndraws, nchains; progress=false + ) +end + +function _build_predictive_model(model::Turing.DynamicPPL.Model, data_keys) + var_names = filter(in(data_keys), keys(model.args)) + return Turing.DynamicPPL.Model{var_names}( + model.name, model.f, deepcopy(model.args), deepcopy(model.defaults) + ) +end + +function _sample_predictive( + rng::AbstractRNG, model::Turing.DynamicPPL.Model, chns, data_keys +) + model_predict = _build_predictive_model(model, data_keys) + return Turing.predict(rng, model_predict, chns) +end + +function _compute_log_likelihood( + model::Turing.DynamicPPL.Model, chns::Turing.MCMCChains.Chains +) + chains_only_params = Turing.MCMCChains.get_sections(chns, :parameters) + loglikelihoods = Turing.pointwise_loglikelihoods(model, chains_only_params) + pred_names = sort(collect(keys(loglikelihoods)); by=split_locname) + loglikelihoods_vals = getindex.(Ref(loglikelihoods), pred_names) + # Bundle loglikelihoods into a `Chains` object so we can reuse our own variable + # name parsing + loglikelihoods_arr = permutedims(cat(loglikelihoods_vals...; dims=3), (1, 3, 2)) + return Turing.MCMCChains.Chains(loglikelihoods_arr, pred_names) +end + +function _get_observed_data(model::Turing.DynamicPPL.Model, chns) + # use likelihood to find names of data variables + if chns isa Turing.MCMCChains.Chains + chns_small = chns[1, :, 
1] + else + chns_small = _sample_prior(Random.MersenneTwister(0), model, 1, 1) + end + log_like = _compute_log_likelihood(model, chns_small) + obs_data_keys = keys(Turing.MCMCChains.get_params(log_like)) + # get values of data variables stored in model + args = model.args + args_obs_data_keys = filter(in(keys(args)), obs_data_keys) + if args_obs_data_keys != obs_data_keys + @warn "Failed to extract group :observed_data from model. Expected keys " * + "$(obs_data_keys) but only found keys $(args_obs_data_keys) in model arguments." + return nothing + end + return NamedTuple{args_obs_data_keys}(getproperty.(Ref(args), args_obs_data_keys)) +end diff --git a/src/utils.jl b/src/utils.jl index d907d5c13..486541970 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -213,12 +213,17 @@ _asstringkeydict(x::Dict{String}) = x enforce_stat_eltypes(stats) = convert_to_eltypes(stats, sample_stats_eltypes) +_convert_to_eltype(v::AbstractArray, T) = convert(Array{T}, v) +_convert_to_eltype(v, T) = convert(T, v) function convert_to_eltypes(data::Dict, data_eltypes) - return Dict(k => convert(Array{get(data_eltypes, k, eltype(v))}, v) for (k, v) in data) + return Dict( + k => _convert_to_eltype(v, get(data_eltypes, k, eltype(v))) for (k, v) in data + ) end function convert_to_eltypes(data::NamedTuple, data_eltypes) return NamedTuple( - k => convert(Array{get(data_eltypes, k, eltype(v))}, v) for (k, v) in pairs(data) + k => _convert_to_eltype(v, get(data_eltypes, k, eltype(v))) for + (k, v) in pairs(data) ) end diff --git a/test/runtests.jl b/test/runtests.jl index 20ffc06bf..18b64ad76 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,4 +13,5 @@ using Test include("test_plots.jl") include("test_namedtuple.jl") include("test_mcmcchains.jl") + include("test_turing.jl") end diff --git a/test/test_data.jl b/test/test_data.jl index 84ba5d652..ed3af32df 100644 --- a/test/test_data.jl +++ b/test/test_data.jl @@ -32,6 +32,17 @@ using MonteCarloMeasurements: Particles @test :posterior in 
keys(g) end + @testset "attributes" begin + ArviZ.setattribute!(data, "tmp", 3) + for (_, group) in ArviZ.groups(data) + @test Pair("tmp", 3) ∈ ArviZ.attributes(group) + end + ArviZ.deleteattribute!(data, "tmp") + for (_, group) in ArviZ.groups(data) + @test "tmp" ∉ keys(ArviZ.attributes(group)) + end + end + @testset "isempty" begin @test !isempty(data) @test isempty(InferenceData()) @@ -67,6 +78,24 @@ using MonteCarloMeasurements: Particles end end +@testset "add library attributes" begin + idata = load_arviz_data("centered_eight") + ArviZ.deleteattribute!(idata, :inference_library) + ArviZ.setattribute!(idata, :inference_library_version, "3.0.0") + @test ArviZ._add_library_attributes!(idata, nothing) === idata + ArviZ._add_library_attributes!(idata, "MyLib") + for (_, group) in ArviZ.groups(idata) + @test Pair("inference_library", "MyLib") ∈ ArviZ.attributes(group) + @test "inference_library_version" ∉ keys(ArviZ.attributes(group)) + end + ArviZ._add_library_attributes!(idata, ArviZ) + arviz_version = string(ArviZ.PkgVersion.Version(ArviZ)) + for (_, group) in ArviZ.groups(idata) + @test Pair("inference_library", "ArviZ") ∈ ArviZ.attributes(group) + @test Pair("inference_library_version", arviz_version) ∈ ArviZ.attributes(group) + end +end + @testset "+(::InferenceData, ::InferenceData)" begin rng = MersenneTwister(42) idata1 = from_dict(; diff --git a/test/test_dataset.jl b/test/test_dataset.jl index eb5920093..f57d9e4ad 100644 --- a/test/test_dataset.jl +++ b/test/test_dataset.jl @@ -47,6 +47,22 @@ @test convert(ArviZ.Dataset, [1.0, 2.0, 3.0, 4.0]) isa ArviZ.Dataset end + @testset "attributes" begin + attrs = dataset.attrs + attrs2 = ArviZ.attributes(dataset) + @test attrs == attrs2 + ArviZ.setattribute!(dataset, "tmp1", 3) + ArviZ.setattribute!(dataset, :tmp2, 4) + attrs3 = ArviZ.attributes(dataset) + @test Pair("tmp1", 3) ∈ attrs3 + @test Pair("tmp2", 4) ∈ attrs3 + ArviZ.deleteattribute!(dataset, :tmp1) + ArviZ.deleteattribute!(dataset, "tmp2") + attrs4 = 
ArviZ.attributes(dataset) + @test "tmp1" ∉ keys(attrs4) + @test "tmp2" ∉ keys(attrs4) + end + @testset "show(::ArviZ.Dataset)" begin @testset "$mimetype" for mimetype in ("plain", "html") text = repr(MIME("text/$(mimetype)"), dataset) diff --git a/test/test_turing.jl b/test/test_turing.jl new file mode 100644 index 000000000..444f84192 --- /dev/null +++ b/test/test_turing.jl @@ -0,0 +1,79 @@ +using Turing +using ArviZ +using ArviZ: groupnames +using Test +using Random + +@testset "from_turing" begin + nchains = 2 + ndraws = 10 + Turing.@model function demo(xs, y, n=length(xs)) + s ~ InverseGamma(2, 3) + m ~ Normal(0, √s) + for i in 1:n + xs[i] ~ Normal(m, √s) + end + return y ~ Normal(m, √s) + end + xs = randn(5) + y = randn() + observed_data = (xs=xs, y=y) + model = demo(observed_data...) + chn = Turing.sample( + model, Turing.MH(), Turing.MCMCThreads(), ndraws, nchains; progress=false + ) + @test size(chn) == (ndraws, 3, nchains) + + idata1 = from_turing(chn) + @test sort(groupnames(idata1)) == [:posterior, :sample_stats] + @test idata1.posterior.inference_library == "Turing" + VersionNumber(idata1.posterior.inference_library_version) + + idata2 = from_turing(; model=model, observed_data=false) + @test sort(groupnames(idata2)) == [:prior, :sample_stats_prior] + @test length(idata2.prior.chain.values) == 1 + @test length(idata2.prior.draw.values) == 1_000 + @test idata2.prior.inference_library == "Turing" + VersionNumber(idata2.prior.inference_library_version) + + idata3 = from_turing(chn; model=model, observed_data=false) + @test sort(groupnames(idata3)) == + sort([:posterior, :sample_stats, :prior, :sample_stats_prior]) + @test length(idata3.prior.chain.values) == nchains + @test length(idata3.prior.draw.values) == ndraws + + idata4 = from_turing(chn; model=model, prior=false, observed_data=false) + @test sort(groupnames(idata4)) == [:posterior, :sample_stats] + + idata5 = from_turing(chn; model=model, nchains=3, ndraws=100) + @test 
idata5.posterior.inference_library == "Turing" + VersionNumber(idata5.posterior.inference_library_version) + @test sort(groupnames(idata5)) == sort([ + :posterior, + :posterior_predictive, + :log_likelihood, + :sample_stats, + :prior, + :prior_predictive, + :sample_stats_prior, + :observed_data, + :constant_data, + ]) + @test length(idata5.prior.chain.values) == 3 + @test length(idata5.prior.draw.values) == 100 + @test idata5.observed_data.xs.values == xs + @test only(idata5.observed_data.y.values) == y + @test only(idata5.constant_data.n.values) == 5 + + rng1 = Random.MersenneTwister(42) + idata6 = from_turing(chn; model=model, rng=rng1) + rng2 = Random.MersenneTwister(42) + idata7 = from_turing(chn; model=model, rng=rng2) + @testset for name in groupnames(idata6) + group1 = getproperty(idata6, name) + group2 = getproperty(idata7, name) + @testset for var_name in group1.variables.keys() + @test group1[var_name].values == group2[var_name].values + end + end +end