Skip to content

Commit 1719ce7

Browse files
committed
Do not fail with zero-sized arrays in dataset_to_point_list
Numpy does not support reshape(-1, ...) when size is zero
1 parent 8a436d8 commit 1719ce7

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

pymc/backends/arviz.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -619,10 +619,12 @@ def dataset_to_point_list(
619619
if not isinstance(vn, str):
620620
raise ValueError(f"Variable names must be str, but dataset key {vn} is a {type(vn)}.")
621621
num_sample_dims = len(sample_dims)
622+
622623
stacked_dims = {dim_name: ds[var_names[0]][dim_name] for dim_name in sample_dims}
624+
stacked_size = np.prod([ds.sizes[dim_name] for dim_name in sample_dims], dtype=int)
623625
transposed_dict = {vn: da.transpose(*sample_dims, ...) for vn, da in ds.items()}
624626
stacked_dict = {
625-
vn: da.values.reshape((-1, *da.shape[num_sample_dims:]))
627+
vn: da.values.reshape((stacked_size, *da.shape[num_sample_dims:]))
626628
for vn, da in transposed_dict.items()
627629
}
628630
points = [

tests/backends/test_arviz.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -837,3 +837,14 @@ def test_dataset_to_point_list_str_key(self):
837837
ds[3] = xarray.DataArray([1, 2, 3])
838838
with pytest.raises(ValueError, match="must be str"):
839839
dataset_to_point_list(ds, sample_dims=["chain", "draw"])
840+
841+
def test_zero_size(self):
842+
ds = xarray.Dataset()
843+
ds["x"] = xarray.DataArray(
844+
np.zeros((4, 10, 0, 5)), dims=("chain", "draw", "dim_0", "dim_5")
845+
)
846+
pl, _ = dataset_to_point_list(ds, sample_dims=("chain", "draw"))
847+
assert len(pl) == 40
848+
assert tuple(pl[0]) == ("x",)
849+
assert pl[0]["x"].shape == (0, 5)
850+
assert pl[0]["x"].dtype == np.float64

0 commit comments

Comments
 (0)