@@ -7,6 +7,10 @@ using Random
77using Statistics
88using Test
99
10+ if ! isdefined (DimensionalData, :maplayers )
11+ maplayers = map
12+ end
13+
1014@testset " MCMCDiagnosticTools integration" begin
1115 nchains, ndraws = 4 , 10
1216 sizes = (x= (), y= (2 ,), z= (3 , 5 ))
@@ -16,12 +20,12 @@ using Test
1620 dict1 = Dict (Symbol (k) => randn (ndraws, nchains, sz... ) for (k, sz) in pairs (sizes))
1721 idata1 = from_dict (dict1; dims, coords, sample_stats= Dict (:energy => energy))
1822 # permute dimensions to test that diagnostics are invariant to dimension order
19- post2 = map (idata1. posterior) do var
23+ post2 = maplayers (idata1. posterior) do var
2024 n = ndims (var)
2125 permdims = ((3 : n). .. , 2 , 1 )
2226 return permutedims (var, permdims)
2327 end
24- sample_stats2 = map (permutedims, idata1. sample_stats)
28+ sample_stats2 = maplayers (permutedims, idata1. sample_stats)
2529 idata2 = InferenceData (; posterior= post2, sample_stats= sample_stats2)
2630
2731 @testset for f in (ess, rhat, ess_rhat, mcse)
@@ -35,7 +39,7 @@ using Test
3539 @test issetequal (keys (metric), keys (idata1. posterior))
3640 @test metric == f (idata1. posterior; kind)
3741 @test metric2 == f (idata2. posterior; kind)
38- @test all (map (≈ , metric2, metric))
42+ @test all (maplayers (≈ , metric2, metric))
3943 for k in keys (sizes)
4044 @test all (
4145 hasdim (
@@ -81,7 +85,7 @@ using Test
8185 r4 = rstar (rng, classifier (rng), idata2. posterior; subset)
8286 rng = Random. seed! (123 )
8387 post_mat = cat (
84- map (var -> reshape (parent (var), ndraws, nchains, :), idata1. posterior)... ;
88+ maplayers (var -> reshape (parent (var), ndraws, nchains, :), idata1. posterior)... ;
8589 dims= 3 ,
8690 )
8791 r5 = rstar (rng, classifier (rng), post_mat; subset)
0 commit comments