Skip to content

Commit 244931e

Browse files
committed
.Remove with dims (merge this before)
1 parent d93c1c9 commit 244931e

File tree

3 files changed

+6
-34
lines changed

3 files changed

+6
-34
lines changed

pymc/dims/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,4 +68,4 @@ def __init__():
6868

6969
from pymc.dims import math
7070
from pymc.dims.distributions import *
71-
from pymc.dims.model import Data, with_dims
71+
from pymc.dims.model import Data, Deterministic, Potential

pymc/dims/distribution_core.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from pytensor.xtensor.type import XTensorVariable
2525

2626
from pymc import SymbolicRandomVariable, modelcontext
27-
from pymc.dims.model import with_dims
2827
from pymc.dims.transforms import DimTransform, log_odds_transform, log_transform
2928
from pymc.distributions.distribution import _support_point, support_point
3029
from pymc.distributions.shape_utils import DimsWithEllipsis, convert_dims
@@ -129,14 +128,11 @@ def _as_xtensor(x):
129128
try:
130129
return as_xtensor(x)
131130
except TypeError:
132-
try:
133-
return with_dims(x)
134-
except ValueError:
135-
raise ValueError(
136-
f"Variable {x} must have dims associated with it.\n"
137-
"To avoid subtle bugs, PyMC does not make any assumptions about the dims of the parameters.\n"
138-
"Convert parameters to an xarray.DataArray, pymc.dims.Data or pytensor.xtensor.as_xtensor with explicit dims."
139-
)
131+
raise ValueError(
132+
f"Variable {x} must have dims associated with it.\n"
133+
"To avoid subtle bugs, PyMC does not make any assumptions about the dims of the parameters.\n"
134+
"Convert parameters to an xarray.DataArray, pymc.dims.Data or pytensor.xtensor.as_xtensor with explicit dims."
135+
)
140136

141137
def __new__(
142138
cls,

pymc/dims/model.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
from pytensor.tensor import TensorVariable
1717
from pytensor.xtensor import as_xtensor
18-
from pytensor.xtensor.basic import TensorFromXTensor
1918
from pytensor.xtensor.type import XTensorVariable
2019

2120
from pymc.data import Data as RegularData
@@ -25,29 +24,6 @@
2524
from pymc.model.core import Potential as RegularPotential
2625

2726

28-
def with_dims(x: TensorVariable | XTensorVariable, model: Model | None = None) -> XTensorVariable:
29-
"""Recover the dims of a variable that was registered in the Model."""
30-
if isinstance(x, XTensorVariable):
31-
return x
32-
33-
if (x.owner is not None) and isinstance(x.owner.op, TensorFromXTensor):
34-
dims = x.owner.inputs[0].type.dims
35-
return as_xtensor(x, dims=dims, name=x.name)
36-
37-
# Try accessing the model context to get dims
38-
try:
39-
model = modelcontext(model)
40-
if (
41-
model.named_vars.get(x.name, None) is x
42-
and (dims := model.named_vars_to_dims.get(x.name, None)) is not None
43-
):
44-
return as_xtensor(x, dims=dims, name=x.name)
45-
except TypeError:
46-
pass
47-
48-
raise ValueError(f"variable {x} doesn't have dims associated with it")
49-
50-
5127
def Data(
5228
name: str, value, dims: DimsWithEllipsis = None, model: Model | None = None, **kwargs
5329
) -> XTensorVariable:

0 commit comments

Comments
 (0)