|
16 | 16 | import numpy as np
|
17 | 17 | import pytensor.tensor as pt
|
18 | 18 | import pytest
|
| 19 | +import xarray |
19 | 20 |
|
20 | 21 | from arviz import InferenceData
|
21 | 22 | from arviz.tests.helpers import check_multiple_attrs
|
|
26 | 27 |
|
27 | 28 | from pymc.backends.arviz import (
|
28 | 29 | InferenceDataConverter,
|
| 30 | + dataset_to_point_list, |
29 | 31 | predictions_to_inference_data,
|
30 | 32 | to_inference_data,
|
31 | 33 | )
|
@@ -776,3 +778,34 @@ def test_save_warmup_issue_1208_after_3_9(self):
|
776 | 778 | assert not fails
|
777 | 779 | assert idata.posterior.sizes["chain"] == 2
|
778 | 780 | assert idata.posterior.sizes["draw"] == 30
|
| 781 | + |
| 782 | + |
| 783 | +class TestDatasetToPointList: |
| 784 | + @pytest.mark.parametrize("input_type", ("dict", "Dataset")) |
| 785 | + def test_dataset_to_point_list(self, input_type): |
| 786 | + if input_type == "dict": |
| 787 | + ds = {} |
| 788 | + elif input_type == "Dataset": |
| 789 | + ds = xarray.Dataset() |
| 790 | + ds["A"] = xarray.DataArray([[1, 2, 3]] * 2, dims=("chain", "draw")) |
| 791 | + pl, _ = dataset_to_point_list(ds, sample_dims=["chain", "draw"]) |
| 792 | + assert isinstance(pl, list) |
| 793 | + assert len(pl) == 6 |
| 794 | + assert isinstance(pl[0], dict) |
| 795 | + assert isinstance(pl[0]["A"], np.ndarray) |
| 796 | + |
| 797 | + def test_transposed_dataset_to_point_list(self): |
| 798 | + ds = xarray.Dataset() |
| 799 | + ds["A"] = xarray.DataArray([[[1, 2, 3], [2, 3, 4]]] * 5, dims=("team", "draw", "chain")) |
| 800 | + pl, _ = dataset_to_point_list(ds, sample_dims=["chain", "draw"]) |
| 801 | + assert isinstance(pl, list) |
| 802 | + assert len(pl) == 6 |
| 803 | + assert isinstance(pl[0], dict) |
| 804 | + assert isinstance(pl[0]["A"], np.ndarray) |
| 805 | + |
| 806 | + def test_dataset_to_point_list_str_key(self): |
| 807 | + # Check that non-str keys are caught |
| 808 | + ds = xarray.Dataset() |
| 809 | + ds[3] = xarray.DataArray([1, 2, 3]) |
| 810 | + with pytest.raises(ValueError, match="must be str"): |
| 811 | + dataset_to_point_list(ds, sample_dims=["chain", "draw"]) |
0 commit comments