Skip to content

Commit fedb85c

Browse files
authored
Raise error when dims match the variable name (#5518)
* Added assertion that a variable doesn't have the same name as its dimensions. * Fixed error in tests triggered by var-dim name collision. Closes #5309
1 parent 2c6abf6 commit fedb85c

File tree

3 files changed

+15
-5
lines changed

3 files changed

+15
-5
lines changed

pymc/model.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1442,7 +1442,10 @@ def add_random_variable(self, var, dims: Optional[Tuple[Union[str, None], ...]]
14421442
if dims is not None:
14431443
if isinstance(dims, str):
14441444
dims = (dims,)
1445-
assert all(dim in self.coords or dim is None for dim in dims)
1445+
if any(dim not in self.coords and dim is not None for dim in dims):
1446+
raise ValueError(f"Dimension {dim} is not specified in `coords`.")
1447+
if any(var.name == dim for dim in dims):
1448+
raise ValueError(f"Variable `{var.name}` has the same name as its dimension label.")
14461449
self._RV_dims[var.name] = dims
14471450

14481451
self.named_vars[var.name] = var

pymc/tests/test_idata_conversion.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from typing import Dict, Tuple
44

5+
import aesara.tensor as at
56
import numpy as np
67
import pandas as pd
78
import pytest
@@ -602,6 +603,12 @@ def test_issue_5043_autoconvert_coord_values(self):
602603
)
603604
assert isinstance(converter.coords["city"], pd.MultiIndex)
604605

606+
def test_variable_dimension_name_collision(self):
607+
with pytest.raises(ValueError, match="same name as its dimension"):
608+
with pm.Model() as pmodel:
609+
var = at.as_tensor([1, 2, 3])
610+
pmodel.register_rv(var, name="time", dims=("time",))
611+
605612

606613
class TestPyMCWarmupHandling:
607614
@pytest.mark.parametrize("save_warmup", [False, True])

pymc/tests/test_model_graph.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def model_with_dims():
104104

105105
population = pm.HalfNormal("population", sd=5, dims=("city"))
106106

107-
time = pm.ConstantData("year", [2014, 2015, 2016], dims="year")
107+
time = pm.ConstantData("time", [2014, 2015, 2016], dims="year")
108108

109109
n = pm.Deterministic(
110110
"tax revenue", economics * population[None, :] * time[:, None], dims=("year", "city")
@@ -116,15 +116,15 @@ def model_with_dims():
116116
compute_graph = {
117117
"economics": set(),
118118
"population": set(),
119-
"year": set(),
120-
"tax revenue": {"economics", "population", "year"},
119+
"time": set(),
120+
"tax revenue": {"economics", "population", "time"},
121121
"L": {"tax revenue"},
122122
"observed": {"L"},
123123
}
124124
plates = {
125125
"1": {"economics"},
126126
"city (4)": {"population"},
127-
"year (3)": {"year"},
127+
"year (3)": {"time"},
128128
"year (3) x city (4)": {"tax revenue"},
129129
"3 x 4": {"L", "observed"},
130130
}

0 commit comments

Comments
 (0)