Skip to content

Commit e3021ee

Browse files
committed
Use DimensionalData.maplayers if available
1 parent 15fe295 commit e3021ee

File tree

6 files changed

+15
-7
lines changed

6 files changed

+15
-7
lines changed

ext/InferenceObjectsMCMCDiagnosticToolsExt/InferenceObjectsMCMCDiagnosticToolsExt.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ else # using Requires
1313
using ..Random: Random
1414
end
1515

16+
maplayers = isdefined(DimensionalData, :maplayers) ? DimensionalData.maplayers : map
17+
1618
include("utils.jl")
1719
include("bfmi.jl")
1820
include("ess_rhat.jl")

ext/InferenceObjectsMCMCDiagnosticToolsExt/ess_rhat.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ end
2121
for f in (:ess, :rhat)
2222
@eval begin
2323
function MCMCDiagnosticTools.$f(data::InferenceObjects.Dataset; kwargs...)
24-
ds = map(data) do var
24+
ds = maplayers(data) do var
2525
return _as_dimarray(MCMCDiagnosticTools.$f(_params_array(var); kwargs...), var)
2626
end
2727
return DimensionalData.rebuild(ds; metadata=DimensionalData.NoMetadata())

ext/InferenceObjectsMCMCDiagnosticToolsExt/mcse.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ function MCMCDiagnosticTools.mcse(data::InferenceObjects.InferenceData; kwargs..
88
return MCMCDiagnosticTools.mcse(data.posterior; kwargs...)
99
end
1010
function MCMCDiagnosticTools.mcse(data::InferenceObjects.Dataset; kwargs...)
11-
ds = map(data) do var
11+
ds = maplayers(data) do var
1212
return _as_dimarray(MCMCDiagnosticTools.mcse(_params_array(var); kwargs...), var)
1313
end
1414
return DimensionalData.rebuild(ds; metadata=DimensionalData.NoMetadata())

ext/InferenceObjectsMCMCDiagnosticToolsExt/rstar.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ end
1616
function MCMCDiagnosticTools.rstar(
1717
rng::Random.AbstractRNG, clf, data::InferenceObjects.Dataset; kwargs...
1818
)
19-
data_array = cat(map(_as_3d_array _params_array, data)...; dims=3)
19+
data_array = cat(maplayers(_as_3d_array _params_array, data)...; dims=3)
2020
return MCMCDiagnosticTools.rstar(rng, clf, data_array; kwargs...)
2121
end
2222
function MCMCDiagnosticTools.rstar(

src/dataset.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,8 @@ for f in [:data, :dims, :refdims, :metadata, :layerdims, :layermetadata]
126126
end
127127
end
128128

129+
DimensionalData.modify(f, s::Dataset) = Dataset(DimensionalData.modify(f, parent(s)))
130+
129131
# Warning: this is not an API function and probably should be implemented abstractly upstream
130132
DimensionalData.show_after(io, mime, ::Dataset) = nothing
131133

test/mcmcdiagnostictools.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ using Random
77
using Statistics
88
using Test
99

10+
if !isdefined(DimensionalData, :maplayers)
11+
maplayers = map
12+
end
13+
1014
@testset "MCMCDiagnosticTools integration" begin
1115
nchains, ndraws = 4, 10
1216
sizes = (x=(), y=(2,), z=(3, 5))
@@ -16,12 +20,12 @@ using Test
1620
dict1 = Dict(Symbol(k) => randn(ndraws, nchains, sz...) for (k, sz) in pairs(sizes))
1721
idata1 = from_dict(dict1; dims, coords, sample_stats=Dict(:energy => energy))
1822
# permute dimensions to test that diagnostics are invariant to dimension order
19-
post2 = map(idata1.posterior) do var
23+
post2 = maplayers(idata1.posterior) do var
2024
n = ndims(var)
2125
permdims = ((3:n)..., 2, 1)
2226
return permutedims(var, permdims)
2327
end
24-
sample_stats2 = map(permutedims, idata1.sample_stats)
28+
sample_stats2 = maplayers(permutedims, idata1.sample_stats)
2529
idata2 = InferenceData(; posterior=post2, sample_stats=sample_stats2)
2630

2731
@testset for f in (ess, rhat, ess_rhat, mcse)
@@ -35,7 +39,7 @@ using Test
3539
@test issetequal(keys(metric), keys(idata1.posterior))
3640
@test metric == f(idata1.posterior; kind)
3741
@test metric2 == f(idata2.posterior; kind)
38-
@test all(map(, metric2, metric))
42+
@test all(maplayers(, metric2, metric))
3943
for k in keys(sizes)
4044
@test all(
4145
hasdim(
@@ -81,7 +85,7 @@ using Test
8185
r4 = rstar(rng, classifier(rng), idata2.posterior; subset)
8286
rng = Random.seed!(123)
8387
post_mat = cat(
84-
map(var -> reshape(parent(var), ndraws, nchains, :), idata1.posterior)...;
88+
maplayers(var -> reshape(parent(var), ndraws, nchains, :), idata1.posterior)...;
8589
dims=3,
8690
)
8791
r5 = rstar(rng, classifier(rng), post_mat; subset)

0 commit comments

Comments
 (0)