Skip to content

Commit 35493f3

Browse files
committed
Update/unify hdi/ci implementations
1 parent 0e8aaa9 commit 35493f3

File tree

5 files changed

+54
-56
lines changed

5 files changed

+54
-56
lines changed

ext/InferenceObjectsPosteriorStatsExt/InferenceObjectsPosteriorStatsExt.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@ using InferenceObjects: InferenceObjects
66
using PosteriorStats: PosteriorStats
77
using StatsBase: StatsBase
88

9-
import PosteriorStats: hdi, loo, loo_pit, r2_score, summarize, waic
9+
import PosteriorStats: eti, hdi, loo, loo_pit, r2_score, summarize, waic
1010
import StatsBase: summarystats
1111

12-
export hdi, loo, loo_pit, r2_score, summarize, waic, summarystats
12+
export eti, hdi, loo, loo_pit, r2_score, summarize, waic, summarystats
13+
14+
maplayers = isdefined(DimensionalData, :maplayers) ? DimensionalData.maplayers : map
1315

1416
include("utils.jl")
15-
include("hdi.jl")
17+
include("ci.jl")
1618
include("loo.jl")
1719
include("waic.jl")
1820
include("loo_pit.jl")
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
2+
for (ci_fun, ci_desc) in
3+
(:eti => "equal-tailed interval (ETI)", :hdi => "highest density interval (HDI)")
4+
@eval begin
5+
# this pattern ensures that the type is completely specified at compile time
6+
@doc """
7+
$($ci_fun)(data::InferenceData; kwargs...) -> Dataset
8+
$($ci_fun)(data::Dataset; kwargs...) -> Dataset
9+
10+
Calculate the $($ci_desc) for each parameter in the data.
11+
12+
For more details and a description of the `kwargs`, see
13+
[`PosteriorStats.$($ci_fun)`](@extref).
14+
"""
15+
function PosteriorStats.$(ci_fun)(data::InferenceObjects.InferenceData; kwargs...)
16+
return PosteriorStats.$(ci_fun)(data.posterior; kwargs...)
17+
end
18+
function PosteriorStats.$(ci_fun)(data::InferenceObjects.Dataset; kwargs...)
19+
ds = maplayers(data) do var
20+
return _as_dimarray(
21+
PosteriorStats.$(ci_fun)(_params_array(var); kwargs...), var
22+
)
23+
end
24+
return DimensionalData.rebuild(ds; metadata=DimensionalData.NoMetadata())
25+
end
26+
end
27+
end

ext/InferenceObjectsPosteriorStatsExt/hdi.jl

Lines changed: 0 additions & 37 deletions
This file was deleted.

ext/InferenceObjectsPosteriorStatsExt/utils.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,17 @@ function observations_and_predictions(
159159
end
160160
end
161161

162+
# reshape to (ndraws, nchains, nparams...) and drop the dimensions
163+
function _params_array(data)
164+
sample_dims = Dimensions.dims(data, InferenceObjects.DEFAULT_SAMPLE_DIMS)
165+
param_dims = Dimensions.otherdims(data, sample_dims)
166+
dims_combined = Dimensions.combinedims((sample_dims..., param_dims...))
167+
Dimensions.dimsmatch(Dimensions.dims(data), dims_combined) && return data
168+
return PermutedDimsArray(data, dims_combined)
169+
end
170+
162171
_as_dimarray(x::DimensionalData.AbstractDimArray, ::DimensionalData.AbstractDimArray) = x
163-
function _as_dimarray(x::Union{Real,Missing}, arr::DimensionalData.AbstractDimArray)
172+
function _as_dimarray(x, arr::DimensionalData.AbstractDimArray)
164173
return Dimensions.rebuild(arr, fill(x), ())
165174
end
166175

test/posteriorstats.jl

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,11 @@ using Statistics
66
using StatsBase
77
using Test
88

9+
_as_array(x) = fill(x)
10+
_as_array(x::AbstractArray) = x
11+
912
@testset "PosteriorStats integration" begin
10-
@testset "hdi" begin
13+
@testset for ci_fun in (eti, hdi)
1114
nt = (x=randn(1000, 3), y=randn(1000, 3, 4), z=randn(1000, 3, 4, 2))
1215
posterior = convert_to_dataset(nt)
1316
posterior_perm = convert_to_dataset((
@@ -17,23 +20,17 @@ using Test
1720
))
1821
idata = InferenceData(; posterior)
1922
@testset for prob in (0.76, 0.93)
20-
if VERSION v"1.9"
21-
@test_broken @inferred hdi(posterior; prob)
22-
end
23-
r1 = hdi(posterior; prob)
24-
r1_perm = hdi(posterior_perm; prob)
23+
@test_broken @inferred ci_fun(posterior; prob)
24+
r1 = ci_fun(posterior; prob)
25+
r1_perm = ci_fun(posterior_perm; prob)
2526
for k in (:x, :y, :z)
26-
rk = hdi(posterior[k]; prob)
27-
@test r1[k][hdi_bound=At(:lower)] == rk.lower
28-
@test r1[k][hdi_bound=At(:upper)] == rk.upper
27+
rk = ci_fun(posterior[k]; prob)
28+
@test r1[k] == _as_array(rk)
2929
# equality check is safe because these are always data values
30-
@test r1_perm[k][hdi_bound=At(:lower)] == rk.lower
31-
@test r1_perm[k][hdi_bound=At(:upper)] == rk.upper
32-
end
33-
if VERSION v"1.9"
34-
@test_broken @inferred hdi(idata; prob)
30+
@test r1_perm[k] == _as_array(rk)
3531
end
36-
r2 = hdi(idata; prob)
32+
@test_broken @inferred ci_fun(idata; prob)
33+
r2 = ci_fun(idata; prob)
3734
@test r1 == r2
3835
end
3936
end

0 commit comments

Comments
 (0)