diff --git a/Project.toml b/Project.toml index f6a00edb..884af994 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "InferenceObjects" uuid = "b5cf5a8d-e756-4ee3-b014-01d49d192c00" authors = ["Seth Axen and contributors"] -version = "0.4.6" +version = "0.4.7" [deps] Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" @@ -23,7 +23,7 @@ InferenceObjectsPosteriorStatsExt = ["PosteriorStats", "StatsBase"] [compat] ArviZExampleData = "0.1.10" Dates = "1.9" -DimensionalData = "0.27, 0.28" +DimensionalData = "0.27, 0.28, 0.29" EvoTrees = "0.16" MCMCDiagnosticTools = "0.3.4" MLJBase = "1" diff --git a/ext/InferenceObjectsMCMCDiagnosticToolsExt/InferenceObjectsMCMCDiagnosticToolsExt.jl b/ext/InferenceObjectsMCMCDiagnosticToolsExt/InferenceObjectsMCMCDiagnosticToolsExt.jl index 3b4ef55c..d26041eb 100644 --- a/ext/InferenceObjectsMCMCDiagnosticToolsExt/InferenceObjectsMCMCDiagnosticToolsExt.jl +++ b/ext/InferenceObjectsMCMCDiagnosticToolsExt/InferenceObjectsMCMCDiagnosticToolsExt.jl @@ -5,6 +5,8 @@ using DimensionalData: DimensionalData, Dimensions, LookupArrays using InferenceObjects: InferenceObjects, Random using MCMCDiagnosticTools: MCMCDiagnosticTools +maplayers = isdefined(DimensionalData, :maplayers) ? DimensionalData.maplayers : map + include("utils.jl") include("bfmi.jl") include("ess_rhat.jl") diff --git a/ext/InferenceObjectsMCMCDiagnosticToolsExt/ess_rhat.jl b/ext/InferenceObjectsMCMCDiagnosticToolsExt/ess_rhat.jl index 85d10cc8..38908487 100644 --- a/ext/InferenceObjectsMCMCDiagnosticToolsExt/ess_rhat.jl +++ b/ext/InferenceObjectsMCMCDiagnosticToolsExt/ess_rhat.jl @@ -21,7 +21,7 @@ end for f in (:ess, :rhat) @eval begin function MCMCDiagnosticTools.$f(data::InferenceObjects.Dataset; kwargs...) - ds = map(data) do var + ds = maplayers(data) do var return _as_dimarray(MCMCDiagnosticTools.$f(_params_array(var); kwargs...), var) end return DimensionalData.rebuild(ds; metadata=DimensionalData.NoMetadata()) diff --git a/ext/InferenceObjectsMCMCDiagnosticToolsExt/mcse.jl b/ext/InferenceObjectsMCMCDiagnosticToolsExt/mcse.jl index be9cb466..739ceed4 100644 --- a/ext/InferenceObjectsMCMCDiagnosticToolsExt/mcse.jl +++ b/ext/InferenceObjectsMCMCDiagnosticToolsExt/mcse.jl @@ -8,7 +8,7 @@ function MCMCDiagnosticTools.mcse(data::InferenceObjects.InferenceData; kwargs.. return MCMCDiagnosticTools.mcse(data.posterior; kwargs...) end function MCMCDiagnosticTools.mcse(data::InferenceObjects.Dataset; kwargs...) - ds = map(data) do var + ds = maplayers(data) do var return _as_dimarray(MCMCDiagnosticTools.mcse(_params_array(var); kwargs...), var) end return DimensionalData.rebuild(ds; metadata=DimensionalData.NoMetadata()) diff --git a/ext/InferenceObjectsMCMCDiagnosticToolsExt/rstar.jl b/ext/InferenceObjectsMCMCDiagnosticToolsExt/rstar.jl index 9cf9d04d..8992e860 100644 --- a/ext/InferenceObjectsMCMCDiagnosticToolsExt/rstar.jl +++ b/ext/InferenceObjectsMCMCDiagnosticToolsExt/rstar.jl @@ -16,7 +16,7 @@ end function MCMCDiagnosticTools.rstar( rng::Random.AbstractRNG, clf, data::InferenceObjects.Dataset; kwargs... ) - data_array = cat(map(_as_3d_array ∘ _params_array, data)...; dims=3) + data_array = cat(maplayers(_as_3d_array ∘ _params_array, data)...; dims=3) return MCMCDiagnosticTools.rstar(rng, clf, data_array; kwargs...) end function MCMCDiagnosticTools.rstar( diff --git a/src/dataset.jl b/src/dataset.jl index 22554791..7995a907 100644 --- a/src/dataset.jl +++ b/src/dataset.jl @@ -126,6 +126,8 @@ for f in [:data, :dims, :refdims, :metadata, :layerdims, :layermetadata] end end +DimensionalData.modify(f, s::Dataset) = Dataset(DimensionalData.modify(f, parent(s))) + # Warning: this is not an API function and probably should be implemented abstractly upstream DimensionalData.show_after(io, mime, ::Dataset) = nothing diff --git a/test/mcmcdiagnostictools.jl b/test/mcmcdiagnostictools.jl index 908e691a..6e5d6870 100644 --- a/test/mcmcdiagnostictools.jl +++ b/test/mcmcdiagnostictools.jl @@ -7,6 +7,10 @@ using Random using Statistics using Test +if !isdefined(DimensionalData, :maplayers) + maplayers = map +end + @testset "MCMCDiagnosticTools integration" begin nchains, ndraws = 4, 10 sizes = (x=(), y=(2,), z=(3, 5)) @@ -16,12 +20,12 @@ using Test dict1 = Dict(Symbol(k) => randn(ndraws, nchains, sz...) for (k, sz) in pairs(sizes)) idata1 = from_dict(dict1; dims, coords, sample_stats=Dict(:energy => energy)) # permute dimensions to test that diagnostics are invariant to dimension order - post2 = map(idata1.posterior) do var + post2 = maplayers(idata1.posterior) do var n = ndims(var) permdims = ((3:n)..., 2, 1) return permutedims(var, permdims) end - sample_stats2 = map(permutedims, idata1.sample_stats) + sample_stats2 = maplayers(permutedims, idata1.sample_stats) idata2 = InferenceData(; posterior=post2, sample_stats=sample_stats2) @testset for f in (ess, rhat, ess_rhat, mcse) @@ -35,7 +39,7 @@ using Test @test issetequal(keys(metric), keys(idata1.posterior)) @test metric == f(idata1.posterior; kind) @test metric2 == f(idata2.posterior; kind) - @test all(map(≈, metric2, metric)) + @test all(maplayers(≈, metric2, metric)) for k in keys(sizes) @test all( hasdim( @@ -81,7 +85,9 @@ using Test r4 = rstar(rng, classifier(rng), idata2.posterior; subset) rng = Random.seed!(123) post_mat = cat( - map(var -> reshape(parent(var), ndraws, nchains, :), idata1.posterior)...; + maplayers( + var -> reshape(parent(var), ndraws, nchains, :), idata1.posterior + )...; dims=3, ) r5 = rstar(rng, classifier(rng), post_mat; subset)