Skip to content

Commit c3cc833

Browse files
committed
Arviz don't fail hard on incompatible coordinate lengths
1 parent fba64f0 commit c3cc833

File tree

2 files changed

+57
-1
lines changed

2 files changed

+57
-1
lines changed

pymc/backends/arviz.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,41 @@
5050

5151
_log = logging.getLogger(__name__)
5252

53+
54+
# Module-level switch read by `dict_to_dataset_drop_incompatible_coords`.
# When False (the default), a coordinate whose length does not match a variable's
# dimension is dropped with a UserWarning; when True, the coordinates are passed
# through untouched so that the downstream dataset construction can raise.
RAISE_ON_INCOMPATIBLE_COORD_LENGTHS = False
55+
56+
5357
# random variable object ...
5458
Var = Any
5559

5660

61+
def dict_to_dataset_drop_incompatible_coords(vars_dict, *args, dims, coords, **kwargs):
    """Call ``dict_to_dataset`` after dropping coordinates with mismatched lengths.

    For every variable, each named dimension is compared against the length of
    the coordinate registered under the same name. If they disagree, a
    ``UserWarning`` is emitted and that coordinate is removed from the mapping
    forwarded to ``dict_to_dataset``. The caller's ``coords`` object is never
    mutated; a copy is made lazily on the first mismatch.

    The check is skipped entirely when the module-level flag
    ``RAISE_ON_INCOMPATIBLE_COORD_LENGTHS`` is true, or when ``dims`` or
    ``coords`` is empty/``None`` (there is nothing to compare in that case).

    Parameters
    ----------
    vars_dict : mapping
        Variable name -> array-like exposing a ``.shape`` attribute.
    *args
        Forwarded positionally to ``dict_to_dataset``.
    dims : mapping or None
        Variable name -> sequence of dimension names.
    coords : mapping or None
        Dimension name -> sized collection of coordinate values.
    **kwargs
        Forwarded to ``dict_to_dataset``.

    Returns
    -------
    Whatever ``dict_to_dataset`` returns for the (possibly reduced) coordinates.
    """
    safe_coords = coords

    if not RAISE_ON_INCOMPATIBLE_COORD_LENGTHS and dims and coords:
        coords_lengths = {k: len(v) for k, v in coords.items()}
        for var_name, var in vars_dict.items():
            # Iterate in reversed because of chain/draw batch dimensions:
            # aligning from the trailing axis skips any leading batch axes
            # that are not listed in `dims`.
            for dim, dim_length in zip(reversed(dims.get(var_name, ())), reversed(var.shape)):
                coord_length = coords_lengths.get(dim)
                if coord_length is None or coord_length == dim_length:
                    continue

                warnings.warn(
                    f"Incompatible coordinate length of {coord_length} for dimension '{dim}' of variable '{var_name}'.\n"
                    "This usually happens when a sliced or concatenated variable is wrapped as a `pymc.dims.Deterministic`. "
                    "The original coordinates for this dim will not be included in the returned dataset for any of the variables. "
                    "Instead they will default to `np.arange(var_length)` and the shorter variables will be right-padded with nan.\n"
                    "To make this warning into an error set `pymc.backends.arviz.RAISE_ON_INCOMPATIBLE_COORD_LENGTHS` to `True`",
                    UserWarning,
                )
                # Copy lazily so the caller's mapping is left untouched.
                if safe_coords is coords:
                    safe_coords = dict(coords)
                safe_coords.pop(dim)
                # Forget the length too, so the same dim is reported only once.
                coords_lengths.pop(dim)

    # FIXME: Would be better to drop coordinates altogether, but arviz defaults to `np.arange(var_length)`
    return dict_to_dataset(vars_dict, *args, dims=dims, coords=safe_coords, **kwargs)
86+
87+
5788
def find_observations(model: "Model") -> dict[str, Var]:
5889
"""If there are observations available, return them as a dictionary."""
5990
observations = {}
@@ -366,7 +397,7 @@ def priors_to_xarray(self):
366397
priors_dict[group] = (
367398
None
368399
if var_names is None
369-
else dict_to_dataset(
400+
else dict_to_dataset_drop_incompatible_coords(
370401
{k: np.expand_dims(self.prior[k], 0) for k in var_names},
371402
library=pymc,
372403
coords=self.coords,

tests/backends/test_arviz.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import re
1415
import warnings
1516

1617
import numpy as np
@@ -848,3 +849,27 @@ def test_zero_size(self):
848849
assert tuple(pl[0]) == ("x",)
849850
assert pl[0]["x"].shape == (0, 5)
850851
assert pl[0]["x"].dtype == np.float64
852+
853+
854+
def test_incompatible_coordinate_lengths():
    """Mismatched coordinate lengths warn by default and raise when the flag is set."""
    with pm.Model(coords={"a": [-1, -2, -3]}) as m:
        x = pm.Normal("x", dims="a")
        # `y` has length 2 but is labelled with the length-3 dimension "a".
        y = pm.Deterministic("y", x[1:], dims=("a",))

        with pytest.warns(
            UserWarning,
            match=re.escape(
                "Incompatible coordinate length of 3 for dimension 'a' of variable 'y'"
            ),
        ):
            prior = pm.sample_prior_predictive(draws=1).prior.squeeze(("chain", "draw"))
        assert prior.x.dims == prior.y.dims == ("a",)
        assert prior.x.shape == prior.y.shape == (3,)
        # The shorter variable is right-padded with nan.
        assert np.isnan(prior.y.values[-1])
        # The user-supplied coordinates were dropped in favour of a default range.
        assert list(prior.coords["a"]) == [0, 1, 2]

        pm.backends.arviz.RAISE_ON_INCOMPATIBLE_COORD_LENGTHS = True
        try:
            with pytest.raises(ValueError):
                pm.sample_prior_predictive(draws=1)
        finally:
            # Always restore the module-level flag so a failure here cannot
            # leak the strict mode into unrelated tests.
            pm.backends.arviz.RAISE_ON_INCOMPATIBLE_COORD_LENGTHS = False

0 commit comments

Comments
 (0)