Skip to content

Commit 245d822

Browse files
github-actions[bot]CompatHelper Juliasethaxen
authored
CompatHelper: bump compat for DimensionalData to 0.26, (keep existing compat) (#76)
* CompatHelper: bump compat for DimensionalData to 0.26, (keep existing compat) * Use correct combinedims syntax * Update docstrings with DimData v0.26 style * Improve type inference of r2_score * Increment patch number * Fix stack overflow on Julia v1.6 --------- Co-authored-by: CompatHelper Julia <[email protected]> Co-authored-by: Seth Axen <[email protected]>
1 parent 4424fd0 commit 245d822

File tree

7 files changed

+74
-53
lines changed

7 files changed

+74
-53
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "InferenceObjects"
22
uuid = "b5cf5a8d-e756-4ee3-b014-01d49d192c00"
33
authors = ["Seth Axen <[email protected]> and contributors"]
4-
version = "0.3.15"
4+
version = "0.3.16"
55

66
[deps]
77
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
@@ -27,7 +27,7 @@ InferenceObjectsPosteriorStatsExt = ["PosteriorStats", "StatsBase"]
2727
ArviZExampleData = "0.1"
2828
Compat = "3.46.0, 4.2.0"
2929
Dates = "1.6"
30-
DimensionalData = "0.24, 0.25"
30+
DimensionalData = "0.24, 0.25, 0.26"
3131
EvoTrees = "0.16"
3232
MCMCDiagnosticTools = "0.3.4"
3333
MLJBase = "1"

ext/InferenceObjectsMCMCDiagnosticToolsExt/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
function _params_array(data)
33
sample_dims = Dimensions.dims(data, InferenceObjects.DEFAULT_SAMPLE_DIMS)
44
param_dims = Dimensions.otherdims(data, sample_dims)
5-
dims_combined = Dimensions.combinedims(sample_dims, param_dims)
5+
dims_combined = Dimensions.combinedims((sample_dims..., param_dims...))
66
Dimensions.dimsmatch(Dimensions.dims(data), dims_combined) && return data
77
return PermutedDimsArray(data, dims_combined)
88
end

ext/InferenceObjectsPosteriorStatsExt/hdi.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
# this pattern ensures that the type is completely specified at compile time
2-
const HDI_BOUND_DIM = Dimensions.format(
3-
Dimensions.Dim{:hdi_bound}([:lower, :upper]), Base.OneTo(2)
4-
)
2+
const HDI_BOUND_DIM = let
3+
dims = Dimensions.format(Dimensions.Dim{:hdi_bound}([:lower, :upper]), Base.OneTo(2))
4+
# some versions of DimensionalData return a tuple here, others return a Dim
5+
dims isa Tuple ? only(dims) : dims
6+
end
57

68
@doc """
79
hdi(data::InferenceData; kwargs...) -> Dataset
@@ -19,9 +21,9 @@ function PosteriorStats.hdi(data::InferenceObjects.Dataset; kwargs...)
1921
lower, upper = map(Base.Fix2(_as_dimarray, x), r)
2022
return cat(lower, upper; dims=HDI_BOUND_DIM)
2123
end
22-
dims = Dimensions.combinedims(
23-
Dimensions.otherdims(data, InferenceObjects.DEFAULT_SAMPLE_DIMS), HDI_BOUND_DIM
24-
)
24+
dims = Dimensions.combinedims((
25+
Dimensions.otherdims(data, InferenceObjects.DEFAULT_SAMPLE_DIMS)..., HDI_BOUND_DIM
26+
))
2527
return DimensionalData.rebuild(
2628
data;
2729
data=map(parent, results),

ext/InferenceObjectsPosteriorStatsExt/loo_pit.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,11 @@ julia> idata = load_example_data("centered_eight");
2323
julia> loo_result = loo(idata; var_name=:obs);
2424
2525
julia> loo_pit(idata, loo_result.psis_result.log_weights; y_name=:obs)
26-
8-element DimArray{Float64,1} loo_pit_obs with dimensions:
27-
Dim{:school} Categorical{String} String[Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
26+
╭───────────────────────────────────────────╮
27+
│ 8-element DimArray{Float64,1} loo_pit_obs │
28+
├───────────────────────────────────────────┴──────────────────────────── dims ┐
29+
↓ school Categorical{String} [Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
30+
└──────────────────────────────────────────────────────────────────────────────┘
2831
"Choate" 0.943511
2932
"Deerfield" 0.63797
3033
"Phillips Andover" 0.316697
@@ -78,8 +81,11 @@ julia> using ArviZExampleData, PosteriorStats
7881
julia> idata = load_example_data("centered_eight");
7982
8083
julia> loo_pit(idata; y_name=:obs)
81-
8-element DimArray{Float64,1} loo_pit_obs with dimensions:
82-
Dim{:school} Categorical{String} String[Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
84+
╭───────────────────────────────────────────╮
85+
│ 8-element DimArray{Float64,1} loo_pit_obs │
86+
├───────────────────────────────────────────┴──────────────────────────── dims ┐
87+
↓ school Categorical{String} [Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
88+
└──────────────────────────────────────────────────────────────────────────────┘
8389
"Choate" 0.943511
8490
"Deerfield" 0.63797
8591
"Phillips Andover" 0.316697

ext/InferenceObjectsPosteriorStatsExt/r2_score.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,9 @@ function PosteriorStats.r2_score(
2929
y_pred_name::Union{Symbol,Nothing}=nothing,
3030
)
3131
(_, y), (_, y_pred) = observations_and_predictions(idata, y_name, y_pred_name)
32-
return PosteriorStats.r2_score(y, _draw_chains_params_array(y_pred))
32+
y_data = y isa DimensionalData.AbstractDimArray ? parent(y) : y
33+
y_data, y_pred_data = map((y, _draw_chains_params_array(y_pred))) do arr
34+
return arr isa DimensionalData.AbstractDimArray ? parent(arr) : arr
35+
end
36+
return PosteriorStats.r2_score(y_data, y_pred_data)
3337
end

ext/InferenceObjectsPosteriorStatsExt/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ end
167167
function _draw_chains_params_array(x::DimensionalData.AbstractDimArray)
168168
sample_dims = Dimensions.dims(x, InferenceObjects.DEFAULT_SAMPLE_DIMS)
169169
param_dims = Dimensions.otherdims(x, sample_dims)
170-
dims_combined = Dimensions.combinedims(sample_dims, param_dims)
170+
dims_combined = Dimensions.combinedims((sample_dims..., param_dims...))
171171
Dimensions.dimsmatch(Dimensions.dims(x), dims_combined) && return x
172172
return PermutedDimsArray(x, dims_combined)
173173
end

src/inference_data.jl

Lines changed: 47 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -300,16 +300,18 @@ InferenceData with groups:
300300
> posterior
301301
302302
julia> idata_cat1.posterior
303-
Dataset with dimensions:
304-
Dim{:draw},
305-
Dim{:chain},
306-
Dim{:a_dim} Categorical{String} String["x", "y", "z"] ForwardOrdered
307-
and 2 layers:
308-
:a Float64 dims: Dim{:draw}, Dim{:chain}, Dim{:a_dim} (100×8×3)
309-
:b Float64 dims: Dim{:draw}, Dim{:chain} (100×8)
310-
311-
with metadata Dict{String, Any} with 1 entry:
312-
"created_at" => "2023-04-03T18:41:35.779"
303+
╭─────────────────╮
304+
│ 100×8×3 Dataset │
305+
├─────────────────┴──────────────────────────────────── dims ┐
306+
↓ draw ,
307+
→ chain,
308+
↗ a_dim Categorical{String} ["x", "y", "z"] ForwardOrdered
309+
├──────────────────────────────────────────────────── layers ┤
310+
:a eltype: Float64 dims: draw, chain, a_dim size: 100×8×3
311+
:b eltype: Float64 dims: draw, chain size: 100×8
312+
├────────────────────────────────────────────────── metadata ┤
313+
Dict{String, Any} with 1 entry:
314+
"created_at" => "2024-03-11T14:10:48.434"
313315
```
314316
315317
Alternatively, we can concatenate along a new `run` dimension, which will be created.
@@ -320,17 +322,19 @@ InferenceData with groups:
320322
> posterior
321323
322324
julia> idata_cat2.posterior
323-
Dataset with dimensions:
324-
Dim{:draw},
325-
Dim{:chain},
326-
Dim{:a_dim} Categorical{String} String["x", "y", "z"] ForwardOrdered,
327-
Dim{:run}
328-
and 2 layers:
329-
:a Float64 dims: Dim{:draw}, Dim{:chain}, Dim{:a_dim}, Dim{:run} (100×4×3×2)
330-
:b Float64 dims: Dim{:draw}, Dim{:chain}, Dim{:run} (100×4×2)
331-
332-
with metadata Dict{String, Any} with 1 entry:
333-
"created_at" => "2023-04-03T18:41:35.779"
325+
╭───────────────────╮
326+
│ 100×4×3×2 Dataset │
327+
├───────────────────┴─────────────────────────────────── dims ┐
328+
↓ draw ,
329+
→ chain,
330+
↗ a_dim Categorical{String} ["x", "y", "z"] ForwardOrdered,
331+
⬔ run
332+
├─────────────────────────────────────────────────────────────┴ layers ┐
333+
:a eltype: Float64 dims: draw, chain, a_dim, run size: 100×4×3×2
334+
:b eltype: Float64 dims: draw, chain, run size: 100×4×2
335+
├──────────────────────────────────────────────────────────── metadata ┤
336+
Dict{String, Any} with 1 entry:
337+
"created_at" => "2024-03-11T14:10:48.434"
334338
```
335339
336340
We can also concatenate only a subset of groups and merge the rest, which is useful when
@@ -351,25 +355,30 @@ InferenceData with groups:
351355
> observed_data
352356
353357
julia> idata_cat3.posterior
354-
Dataset with dimensions:
355-
Dim{:draw},
356-
Dim{:chain},
357-
Dim{:a_dim} Categorical{String} String["x", "y", "z"] ForwardOrdered,
358-
Dim{:run}
359-
and 2 layers:
360-
:a Float64 dims: Dim{:draw}, Dim{:chain}, Dim{:a_dim}, Dim{:run} (100×4×3×2)
361-
:b Float64 dims: Dim{:draw}, Dim{:chain}, Dim{:run} (100×4×2)
362-
363-
with metadata Dict{String, Any} with 1 entry:
364-
"created_at" => "2023-04-03T18:41:35.779"
358+
╭───────────────────╮
359+
│ 100×4×3×2 Dataset │
360+
├───────────────────┴─────────────────────────────────── dims ┐
361+
↓ draw ,
362+
→ chain,
363+
↗ a_dim Categorical{String} ["x", "y", "z"] ForwardOrdered,
364+
⬔ run
365+
├─────────────────────────────────────────────────────────────┴ layers ┐
366+
:a eltype: Float64 dims: draw, chain, a_dim, run size: 100×4×3×2
367+
:b eltype: Float64 dims: draw, chain, run size: 100×4×2
368+
├──────────────────────────────────────────────────────────── metadata ┤
369+
Dict{String, Any} with 1 entry:
370+
"created_at" => "2024-03-11T14:10:48.434"
365371
366372
julia> idata_cat3.observed_data
367-
Dataset with dimensions: Dim{:y_dim_1}
368-
and 1 layer:
369-
:y Float64 dims: Dim{:y_dim_1} (10)
370-
371-
with metadata Dict{String, Any} with 1 entry:
372-
"created_at" => "2023-02-17T15:11:00.59"
373+
╭────────────────────╮
374+
│ 10-element Dataset │
375+
├────────────── dims ┤
376+
↓ y_dim_1
377+
├────────────────────┴─────────────── layers ┐
378+
:y eltype: Float64 dims: y_dim_1 size: 10
379+
├────────────────────────────────────────────┴ metadata ┐
380+
Dict{String, Any} with 1 entry:
381+
"created_at" => "2024-03-11T14:10:53.539"
373382
```
374383
"""
375384
function Base.cat(data::InferenceData, others::InferenceData...; groups=keys(data), dims)

0 commit comments

Comments
 (0)