Skip to content

Commit ad0120a

Browse files
committed
Implement xarray like semantics in dims module
1 parent 0f1bfa9 commit ad0120a

File tree

18 files changed

+755
-119
lines changed

18 files changed

+755
-119
lines changed

pymc/data.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@
1313
# limitations under the License.
1414

1515
import io
16+
import typing
1617
import urllib.request
1718
import warnings
1819

1920
from collections.abc import Sequence
2021
from copy import copy
21-
from typing import cast
22+
from typing import Union, cast
2223

2324
import numpy as np
2425
import pandas as pd
@@ -33,12 +34,13 @@
3334
from pytensor.tensor.random.basic import IntegersRV
3435
from pytensor.tensor.variable import TensorConstant, TensorVariable
3536

36-
import pymc as pm
37-
38-
from pymc.logprob.utils import rvs_in_graph
39-
from pymc.pytensorf import convert_data
37+
from pymc.exceptions import ShapeError
38+
from pymc.pytensorf import convert_data, rvs_in_graph
4039
from pymc.vartypes import isgenerator
4140

41+
if typing.TYPE_CHECKING:
42+
from pymc.model.core import Model
43+
4244
__all__ = [
4345
"ConstantData",
4446
"Data",
@@ -200,7 +202,7 @@ def determine_coords(
200202

201203
if isinstance(value, np.ndarray) and dims is not None:
202204
if len(dims) != value.ndim:
203-
raise pm.exceptions.ShapeError(
205+
raise ShapeError(
204206
"Invalid data shape. The rank of the dataset must match the length of `dims`.",
205207
actual=value.shape,
206208
expected=value.ndim,
@@ -286,6 +288,7 @@ def Data(
286288
coords: dict[str, Sequence | np.ndarray] | None = None,
287289
infer_dims_and_coords=False,
288290
mutable: bool | None = None,
291+
model: Union["Model", None] = None,
289292
**kwargs,
290293
) -> SharedVariable | TensorConstant:
291294
"""Create a data container that registers a data variable with the model.
@@ -350,15 +353,18 @@ def Data(
350353
... model.set_data("data", data_vals)
351354
... idatas.append(pm.sample())
352355
"""
356+
from pymc.model.core import modelcontext
357+
353358
if coords is None:
354359
coords = {}
355360

356361
if isinstance(value, list):
357362
value = np.array(value)
358363

359364
# Add data container to the named variables of the model.
360-
model = pm.Model.get_context(error_if_none=False)
361-
if model is None:
365+
try:
366+
model = modelcontext(model)
367+
except TypeError:
362368
raise TypeError(
363369
"No model on context stack, which is needed to instantiate a data container. "
364370
"Add variable inside a 'with model:' block."
@@ -390,7 +396,7 @@ def Data(
390396
if isinstance(dims, str):
391397
dims = (dims,)
392398
if not (dims is None or len(dims) == x.ndim):
393-
raise pm.exceptions.ShapeError(
399+
raise ShapeError(
394400
"Length of `dims` must match the dimensions of the dataset.",
395401
actual=len(dims),
396402
expected=x.ndim,

pymc/dims/__init__.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Copyright 2025 - present The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
def __init__():
    """Make PyMC aware of the xtensor functionality.

    Registers ``XRV`` as a ``MeasurableOp``, registers the ``+lower_xtensor``
    query of PyTensor's ``optdb`` in both the logprob and the initial point
    rewrite databases, and emits a ``UserWarning`` flagging the module as
    experimental while the estimated bug probability is above 5%.

    This should be done eagerly once development matures.
    """
    import datetime
    import warnings

    from pytensor.compile import optdb

    from pymc.initial_point import initial_point_rewrites_db
    from pymc.logprob.abstract import MeasurableOp
    from pymc.logprob.rewriting import logprob_rewrites_db

    # Filter PyTensor xtensor warning, we emit our own warning
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", UserWarning)
        import pytensor.xtensor

        from pytensor.xtensor.vectorization import XRV

    # Make PyMC aware of xtensor functionality
    MeasurableOp.register(XRV)
    lower_xtensor_query = optdb.query("+lower_xtensor")
    logprob_rewrites_db.register("lower_xtensor", lower_xtensor_query, "basic")
    initial_point_rewrites_db.register("lower_xtensor", lower_xtensor_query, "basic")

    # TODO: Better model of probability of bugs
    day_of_conception = datetime.date(2025, 6, 17)
    day_of_last_bug = datetime.date(2025, 6, 17)
    today = datetime.date.today()
    days_with_bugs = (day_of_last_bug - day_of_conception).days
    # Clamp at zero: with a system clock set before `day_of_last_bug` the raw
    # difference is negative, which can zero the denominator below (crashing
    # the import) or push `p` out of [0, 1] and silently hide the warning.
    days_without_bugs = max(0, (today - day_of_last_bug).days)
    p = 1 - (days_without_bugs / (days_without_bugs + days_with_bugs + 10))
    if p > 0.05:
        warnings.warn(
            f"The `pymc.dims` module is experimental and may contain critical bugs (p={p:.3f}).\n"
            "Please report any issues you encounter at https://github.com/pymc-devs/pymc/issues",
            UserWarning,
            stacklevel=2,
        )


# Run the registration once at import time, then drop the helper so it does
# not linger in the package namespace.
__init__()
del __init__
61+
62+
from pymc.dims import math
63+
from pymc.dims.distributions import *
64+
from pymc.dims.model import Data

pymc/dims/distribution_core.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
# Copyright 2025 - present The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from collections.abc import Callable
15+
from itertools import chain
16+
17+
from pytensor.xtensor import as_xtensor
18+
from pytensor.xtensor.type import XTensorVariable
19+
20+
from pymc import modelcontext
21+
from pymc.distributions import transforms
22+
from pymc.distributions.shape_utils import Dims, convert_dims
23+
from pymc.util import UNSET
24+
25+
26+
class DimDistribution:
    """Base class for PyMC distributions that wrap pytensor.xtensor.random operations, and follow xarray-like semantics.

    Subclasses provide ``xrv_op``, the callable invoked by :meth:`dist` to build
    the random variable, and ``default_transform``, which is used when the
    variable is unobserved and no ``default_transform`` argument is given.
    """

    xrv_op: Callable
    default_transform: Callable | None

    def __new__(
        cls,
        name: str,
        *dist_params,
        dims: Dims | None = None,
        initval=None,
        observed=None,
        total_size=None,
        transform=UNSET,
        default_transform=UNSET,
        model=None,
        **kwargs,
    ) -> XTensorVariable:
        """Create a random variable, register it in the model and return it as an xtensor.

        Parameters
        ----------
        name : str
            Name under which the variable is registered in the model.
        *dist_params
            Positional arguments forwarded to ``cls.dist``.
        dims : str or sequence, optional
            Dims of the output. Every named dim must exist in
            ``model.dim_lengths``. Unless an ellipsis is included, the dims must
            match the dims of the created variable; they also determine the
            order of the output dims.
        initval, total_size, transform
            Forwarded unchanged to ``model.register_rv``.
        observed : optional
            Observed values. Must be convertible by ``as_xtensor``; its dims are
            looked up in the model and aligned with those of the variable.
        default_transform : optional
            Forwarded to ``model.register_rv``. If left ``UNSET`` and the
            variable is unobserved, ``cls.default_transform`` is used.
        model : optional
            Model to register the variable in, resolved through ``modelcontext``.
        **kwargs
            Extra keyword arguments forwarded to ``cls.dist``.

        Raises
        ------
        TypeError
            If no model can be resolved, ``name`` is not a string, or
            ``observed`` cannot be converted to an xtensor.
        ValueError
            If a dim in ``dims`` is not part of the model, or ``dims`` (without
            ellipsis) does not match the dims of the created variable.
        """
        try:
            model = modelcontext(model)
        except TypeError:
            raise TypeError(
                "No model on context stack, which is needed to instantiate distributions. "
                "Add variable inside a 'with model:' block, or use the '.dist' syntax for a standalone distribution."
            )

        if not isinstance(name, str):
            raise TypeError(f"Name needs to be a string but got: {name}")

        if dims is None:
            dims_dict = {}
        else:
            dims = convert_dims(dims)
            try:
                # Only the ellipsis placeholder is skipped. Unknown dim names must
                # hit the KeyError below rather than being silently dropped.
                dims_dict = {dim: model.dim_lengths[dim] for dim in dims if dim is not Ellipsis}
            except KeyError:
                raise ValueError(
                    f"Not all dims {dims} are part of the model coords. "
                    "Add them at initialization time or use `model.add_coord` before defining the distribution."
                )

        if observed is not None:
            try:
                observed = as_xtensor(observed)
            except TypeError as e:
                raise TypeError(
                    "Observed value must have dims associated with it.\n"
                    "To avoid subtle bugs, PyMC does not make any assumptions about the dims of the observed value. "
                    "Convert observed value to an xarray.DataArray, pymc.dims.Data or pytensor.xtensor.as_xtensor with explicit dims."
                ) from e

            # Propagate observed dims to dims_dict
            for observed_dim in observed.type.dims:
                if observed_dim not in dims_dict:
                    dims_dict[observed_dim] = model.dim_lengths[observed_dim]

        rv = cls.dist(*dist_params, dims_dict=dims_dict, **kwargs)

        # User provided dims must specify all dims or use ellipsis
        if dims is not None:
            if (... not in dims) and (set(dims) != set(rv.type.dims)):
                raise ValueError(
                    f"Provided dims {dims} do not match the distribution's output dims {rv.type.dims}. "
                    "Use ellipsis to specify all other dimensions."
                )
            # Use provided dims to transpose the output to the desired order
            rv = rv.transpose(*dims)

        rv_dims = rv.type.dims
        if observed is None:
            if default_transform is UNSET:
                default_transform = cls.default_transform
        else:
            # Align observed dims with those of the RV
            observed = observed.transpose(*rv_dims).values

        # The model is handed the `.values` of the xtensor variables; the dims
        # are passed separately and re-attached to the registered variable below.
        rv = model.register_rv(
            rv.values,
            name=name,
            observed=observed,
            total_size=total_size,
            dims=rv_dims,
            transform=transform,
            default_transform=default_transform,
            initval=initval,
        )

        xrv = as_xtensor(rv, dims=rv_dims)
        return xrv

    @classmethod
    def dist(
        cls,
        dist_params,
        *,
        dims_dict: dict[str, int] | None = None,
        **kwargs,
    ) -> XTensorVariable:
        """Build the random variable with ``cls.xrv_op`` without registering it in a model.

        Parameters
        ----------
        dist_params : iterable
            Distribution parameters. Each one is converted with ``as_xtensor``.
        dims_dict : dict of str to int, optional
            Mapping from dim name to length. Entries whose dim is not already
            present in any parameter are passed to ``xrv_op`` as ``extra_dims``.
        **kwargs
            Extra keyword arguments forwarded to ``cls.xrv_op``. ``size``,
            ``shape`` and ``dims`` are rejected.

        Raises
        ------
        TypeError
            If ``size``, ``shape`` or ``dims`` is passed, or a parameter cannot
            be converted to an xtensor.
        """
        for invalid_kwarg in ("size", "shape", "dims"):
            if invalid_kwarg in kwargs:
                raise TypeError(f"DimDistribution does not accept {invalid_kwarg} argument.")

        try:
            dist_params = [as_xtensor(param) for param in dist_params]
        except TypeError as e:
            raise TypeError(
                "Distribution parameters must have dims associated with it.\n"
                "To avoid subtle bugs, PyMC does not make any assumptions about the dims of the parameters. "
                "Convert parameters to an xarray.DataArray, pymc.dims.Data or pytensor.xtensor.as_xtensor with explicit dims."
            ) from e

        # XRV requires only extra_dims, not dims
        if dims_dict is None:
            extra_dims = None
        else:
            parameter_implied_dims = set(
                chain.from_iterable(param.type.dims for param in dist_params)
            )
            extra_dims = {
                dim: length
                for dim, length in dims_dict.items()
                if dim not in parameter_implied_dims
            }
        return cls.xrv_op(*dist_params, extra_dims=extra_dims, **kwargs)
153+
154+
155+
class ContinuousDimDistribution(DimDistribution):
    """Base class for dim distributions over real values.

    ``default_transform`` is ``None`` for this family.
    """

    default_transform = None
159+
160+
161+
class PositiveContinuousDimDistribution(DimDistribution):
    """Base class for dim distributions over positive continuous values.

    ``default_transform`` is ``transforms.log`` for this family.
    """

    default_transform = transforms.log
165+
166+
167+
class UnitDimDistribution(DimDistribution):
    """Base class for dim distributions over unit-valued variables.

    ``default_transform`` is ``transforms.logodds`` for this family.
    """

    default_transform = transforms.logodds

0 commit comments

Comments
 (0)