diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index 71f08da82..ba8928215 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -618,11 +618,13 @@ def dataset_to_point_list( for vn in var_names: if not isinstance(vn, str): raise ValueError(f"Variable names must be str, but dataset key {vn} is a {type(vn)}.") + num_sample_dims = len(sample_dims) stacked_dims = {dim_name: ds[var_names[0]][dim_name] for dim_name in sample_dims} transposed_dict = {vn: da.transpose(*sample_dims, ...) for vn, da in ds.items()} + stacked_size = np.prod(transposed_dict[var_names[0]].shape[:num_sample_dims], dtype=int) stacked_dict = { - vn: da.values.reshape((-1, *da.shape[num_sample_dims:])) + vn: da.values.reshape((stacked_size, *da.shape[num_sample_dims:])) for vn, da in transposed_dict.items() } points = [ diff --git a/tests/backends/test_arviz.py b/tests/backends/test_arviz.py index 3c06288b3..f09a3a953 100644 --- a/tests/backends/test_arviz.py +++ b/tests/backends/test_arviz.py @@ -837,3 +837,14 @@ def test_dataset_to_point_list_str_key(self): ds[3] = xarray.DataArray([1, 2, 3]) with pytest.raises(ValueError, match="must be str"): dataset_to_point_list(ds, sample_dims=["chain", "draw"]) + + def test_zero_size(self): + ds = xarray.Dataset() + ds["x"] = xarray.DataArray( + np.zeros((4, 10, 0, 5)), dims=("chain", "draw", "dim_0", "dim_5") + ) + pl, _ = dataset_to_point_list(ds, sample_dims=("chain", "draw")) + assert len(pl) == 40 + assert tuple(pl[0]) == ("x",) + assert pl[0]["x"].shape == (0, 5) + assert pl[0]["x"].dtype == np.float64