Skip to content

Commit 6c0a5d1

Browse files
committed
Implement xarray like semantics in dims module
1 parent 0373dc1 commit 6c0a5d1

37 files changed

+1711
-150
lines changed

.github/workflows/tests.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,12 @@ jobs:
136136
tests/logprob/test_transforms.py
137137
tests/logprob/test_utils.py
138138
139+
- |
140+
tests/dims/distributions/test_core.py
141+
tests/dims/distributions/test_scalar.py
142+
tests/dims/distributions/test_vector.py
143+
tests/dims/test_model.py
144+
139145
fail-fast: false
140146
runs-on: ${{ matrix.os }}
141147
env:

pymc/data.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@
1313
# limitations under the License.
1414

1515
import io
16+
import typing
1617
import urllib.request
1718

1819
from collections.abc import Sequence
1920
from copy import copy
20-
from typing import cast
21+
from typing import Union, cast
2122

2223
import numpy as np
2324
import pandas as pd
@@ -32,12 +33,13 @@
3233
from pytensor.tensor.random.basic import IntegersRV
3334
from pytensor.tensor.variable import TensorConstant, TensorVariable
3435

35-
import pymc as pm
36-
37-
from pymc.logprob.utils import rvs_in_graph
38-
from pymc.pytensorf import convert_data
36+
from pymc.exceptions import ShapeError
37+
from pymc.pytensorf import convert_data, rvs_in_graph
3938
from pymc.vartypes import isgenerator
4039

40+
if typing.TYPE_CHECKING:
41+
from pymc.model.core import Model
42+
4143
__all__ = [
4244
"Data",
4345
"Minibatch",
@@ -197,7 +199,7 @@ def determine_coords(
197199

198200
if isinstance(value, np.ndarray) and dims is not None:
199201
if len(dims) != value.ndim:
200-
raise pm.exceptions.ShapeError(
202+
raise ShapeError(
201203
"Invalid data shape. The rank of the dataset must match the length of `dims`.",
202204
actual=value.shape,
203205
expected=value.ndim,
@@ -222,6 +224,7 @@ def Data(
222224
dims: Sequence[str] | None = None,
223225
coords: dict[str, Sequence | np.ndarray] | None = None,
224226
infer_dims_and_coords=False,
227+
model: Union["Model", None] = None,
225228
**kwargs,
226229
) -> SharedVariable | TensorConstant:
227230
"""Create a data container that registers a data variable with the model.
@@ -286,15 +289,18 @@ def Data(
286289
... model.set_data("data", data_vals)
287290
... idatas.append(pm.sample())
288291
"""
292+
from pymc.model.core import modelcontext
293+
289294
if coords is None:
290295
coords = {}
291296

292297
if isinstance(value, list):
293298
value = np.array(value)
294299

295300
# Add data container to the named variables of the model.
296-
model = pm.Model.get_context(error_if_none=False)
297-
if model is None:
301+
try:
302+
model = modelcontext(model)
303+
except TypeError:
298304
raise TypeError(
299305
"No model on context stack, which is needed to instantiate a data container. "
300306
"Add variable inside a 'with model:' block."
@@ -321,7 +327,7 @@ def Data(
321327
if isinstance(dims, str):
322328
dims = (dims,)
323329
if not (dims is None or len(dims) == x.ndim):
324-
raise pm.exceptions.ShapeError(
330+
raise ShapeError(
325331
"Length of `dims` must match the dimensions of the dataset.",
326332
actual=len(dims),
327333
expected=x.ndim,

pymc/dims/__init__.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Copyright 2025 - present The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
def __init__():
    """Make PyMC aware of the xtensor functionality.

    This should be done eagerly once development matures.

    Registers ``XRV`` as a ``MeasurableOp``, registers the ``lower_xtensor``
    rewrites in the logprob and initial-point rewrite databases, and emits a
    ``UserWarning`` about the experimental status of ``pymc.dims`` while the
    estimated bug probability ``p`` is above 0.05.
    """
    import datetime
    import warnings

    from pytensor.compile import optdb

    from pymc.initial_point import initial_point_rewrites_db
    from pymc.logprob.abstract import MeasurableOp
    from pymc.logprob.rewriting import logprob_rewrites_db

    # Filter PyTensor xtensor warning, we emit our own warning
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", UserWarning)
        import pytensor.xtensor

        from pytensor.xtensor.vectorization import XRV

    # Make PyMC aware of xtensor functionality
    MeasurableOp.register(XRV)
    logprob_rewrites_db.register(
        "pre_lower_xtensor", optdb.query("+lower_xtensor"), "basic", position=0.1
    )
    initial_point_rewrites_db.register(
        "lower_xtensor", optdb.query("+lower_xtensor"), "basic", position=0.1
    )

    # TODO: Better model of probability of bugs
    day_of_conception = datetime.date(2025, 6, 17)
    day_of_last_bug = datetime.date(2025, 6, 30)
    today = datetime.date.today()
    days_with_bugs = (day_of_last_bug - day_of_conception).days
    # Clamp at zero: a system clock set before `day_of_last_bug` would make this
    # negative, which can zero the denominator below (ZeroDivisionError at import)
    # or push `p` under the threshold and silence the warning.
    days_without_bugs = max(0, (today - day_of_last_bug).days)
    p = 1 - (days_without_bugs / (days_without_bugs + days_with_bugs + 10))
    if p > 0.05:
        warnings.warn(
            f"The `pymc.dims` module is experimental and may contain critical bugs (p={p:.3f}).\n"
            "Please report any issues you encounter at https://github.com/pymc-devs/pymc/issues.\n"
            "API changes are expected in future releases.\n",
            UserWarning,
            stacklevel=2,
        )
61+
62+
63+
# Run the registration once at import time, then remove the helper's name
# from the module namespace.
__init__()
del __init__
65+
66+
from pytensor.xtensor import as_xtensor, broadcast, concat, dot, full_like, ones_like, zeros_like
67+
from pytensor.xtensor.basic import tensor_from_xtensor
68+
69+
from pymc.dims import math
70+
from pymc.dims.distributions import *
71+
from pymc.dims.model import Data, Deterministic, Potential

pymc/dims/distributions/__init__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright 2025 - present The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from pymc.dims.distributions.scalar import *
15+
from pymc.dims.distributions.vector import *

pymc/dims/distributions/core.py

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
# Copyright 2025 - present The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from collections.abc import Callable, Sequence
15+
from itertools import chain
16+
17+
from pytensor.graph.basic import Variable
18+
from pytensor.tensor.elemwise import DimShuffle
19+
from pytensor.xtensor import as_xtensor
20+
from pytensor.xtensor.type import XTensorVariable
21+
22+
from pymc import modelcontext
23+
from pymc.dims.model import with_dims
24+
from pymc.distributions import transforms
25+
from pymc.distributions.distribution import _support_point, support_point
26+
from pymc.distributions.shape_utils import DimsWithEllipsis, convert_dims_with_ellipsis
27+
from pymc.logprob.transforms import Transform
28+
from pymc.util import UNSET
29+
30+
31+
@_support_point.register(DimShuffle)
def dimshuffle_support_point(ds_op, _, rv):
    """Return the support point of ``rv`` passed through the DimShuffle ``ds_op``.

    A support point is implemented for DimShuffle because DimDistribution
    can register a transposed version of a variable.
    """
    inner_support_point = support_point(rv)
    return ds_op(inner_support_point)
37+
38+
39+
class DimDistribution:
    """Base class for PyMC distributions that wrap pytensor.xtensor.random operations and follow xarray-like semantics."""

    # Callable that builds the random variable; invoked by `dist` with the
    # converted parameters plus `extra_dims` and `core_dims`.
    xrv_op: Callable
    # Transform used by `__new__` for unobserved variables when the caller
    # leaves `default_transform` as UNSET.
    default_transform: Transform | None = None

    @staticmethod
    def _as_xtensor(x):
        """Convert ``x`` to an xtensor variable.

        Tries ``as_xtensor`` first and falls back to ``with_dims`` when that
        raises ``TypeError``.

        Raises
        ------
        ValueError
            If the ``with_dims`` fallback raises ``ValueError``.
        """
        try:
            return as_xtensor(x)
        except TypeError:
            try:
                return with_dims(x)
            except ValueError as err:
                raise ValueError(
                    f"Variable {x} must have dims associated with it.\n"
                    "To avoid subtle bugs, PyMC does not make any assumptions about the dims of parameters.\n"
                    "Use `as_xtensor` with the `dims` keyword argument to specify the dims explicitly."
                ) from err

    def __new__(
        cls,
        name: str,
        *dist_params,
        dims: DimsWithEllipsis | None = None,
        initval=None,
        observed=None,
        total_size=None,
        transform=UNSET,
        default_transform=UNSET,
        model=None,
        **kwargs,
    ):
        """Create the random variable, register it in the model and return it with dims.

        Parameters
        ----------
        name : str
            Name under which the variable is registered in the model.
        *dist_params
            Positional arguments forwarded to ``cls.dist``.
        dims : optional
            Desired output dims, possibly containing an ellipsis. Every named
            dim must be a key of ``model.dim_lengths``.
        initval, total_size, transform
            Forwarded unchanged to ``model.register_rv``.
        observed : optional
            Observed data. It is converted with ``_as_xtensor`` and transposed
            to the dims of the random variable before registration.
        default_transform : optional
            Forwarded to ``model.register_rv``. When left as ``UNSET`` and the
            variable is unobserved, ``cls.default_transform`` is used.
        model : optional
            Model to register in; resolved through ``modelcontext``.
        **kwargs
            Extra keyword arguments forwarded to ``cls.dist``.

        Returns
        -------
        The registered variable, wrapped by ``as_xtensor`` with the dims of
        the random variable.

        Raises
        ------
        TypeError
            If no model can be resolved or ``name`` is not a string.
        ValueError
            If a requested dim is missing from ``model.dim_lengths``, or if
            ``dims`` has no ellipsis and does not match the output dims.
        """
        try:
            model = modelcontext(model)
        except TypeError as err:
            raise TypeError(
                "No model on context stack, which is needed to instantiate distributions. "
                "Add variable inside a 'with model:' block, or use the '.dist' syntax for a standalone distribution."
            ) from err

        if not isinstance(name, str):
            raise TypeError(f"Name needs to be a string but got: {name}")

        dims = convert_dims_with_ellipsis(dims)
        if dims is None:
            dim_lengths = {}
        else:
            try:
                dim_lengths = {dim: model.dim_lengths[dim] for dim in dims if dim is not Ellipsis}
            except KeyError as err:
                raise ValueError(
                    f"Not all dims {dims} are part of the model coords. "
                    "Add them at initialization time or use `model.add_coord` before defining the distribution."
                ) from err

        if observed is not None:
            observed = cls._as_xtensor(observed)

            # Propagate observed dims to dim_lengths
            for observed_dim in observed.type.dims:
                if observed_dim not in dim_lengths:
                    dim_lengths[observed_dim] = model.dim_lengths[observed_dim]

        rv = cls.dist(*dist_params, dim_lengths=dim_lengths, **kwargs)

        # User provided dims must specify all dims or use ellipsis
        if dims is not None:
            if (... not in dims) and (set(dims) != set(rv.type.dims)):
                raise ValueError(
                    f"Provided dims {dims} do not match the distribution's output dims {rv.type.dims}. "
                    "Use ellipsis to specify all other dimensions."
                )
            # Use provided dims to transpose the output to the desired order
            rv = rv.transpose(*dims)

        rv_dims = rv.type.dims
        if observed is None:
            if default_transform is UNSET:
                default_transform = cls.default_transform
        else:
            # Align observed dims with those of the RV
            # TODO: If this fails give a more informative error message
            observed = observed.transpose(*rv_dims).values

        rv = model.register_rv(
            rv.values,
            name=name,
            observed=observed,
            total_size=total_size,
            dims=rv_dims,
            transform=transform,
            default_transform=default_transform,
            initval=initval,
        )

        # Re-attach the dims to whatever the model returned for the registered variable.
        return as_xtensor(rv, dims=rv_dims)

    @classmethod
    def dist(
        cls,
        dist_params,
        *,
        dim_lengths: dict[str, Variable | int] | None = None,
        core_dims: str | Sequence[str] | None = None,
        **kwargs,
    ) -> XTensorVariable:
        """Build the random variable without registering it in a model.

        Parameters
        ----------
        dist_params
            Iterable of distribution parameters; each one is converted with
            ``_as_xtensor``.
        dim_lengths : dict, optional
            Mapping from dim name to length. Dims already present on a
            parameter or listed in ``core_dims`` are dropped; the remainder is
            passed to ``xrv_op`` as ``extra_dims``.
        core_dims : str or sequence of str, optional
            Forwarded to ``xrv_op``.
        **kwargs
            Forwarded to ``xrv_op``. ``size``, ``shape`` and ``dims`` are
            rejected.

        Raises
        ------
        TypeError
            If ``size``, ``shape`` or ``dims`` is passed as a keyword argument.
        """
        for invalid_kwarg in ("size", "shape", "dims"):
            if invalid_kwarg in kwargs:
                raise TypeError(f"DimDistribution does not accept {invalid_kwarg} argument.")

        # XRV requires only extra_dims, not dims
        dist_params = [cls._as_xtensor(param) for param in dist_params]

        if dim_lengths is None:
            extra_dims = None
        else:
            # Exclude dims that are implied by the parameters or core_dims
            implied_dims = set(chain.from_iterable(param.type.dims for param in dist_params))
            if core_dims is not None:
                if isinstance(core_dims, str):
                    implied_dims.add(core_dims)
                else:
                    implied_dims.update(core_dims)

            extra_dims = {
                dim: length for dim, length in dim_lengths.items() if dim not in implied_dims
            }
        return cls.xrv_op(*dist_params, extra_dims=extra_dims, core_dims=core_dims, **kwargs)
168+
169+
170+
class VectorDimDistribution(DimDistribution):
    """Base class for dim distributions that require ``core_dims`` to be given."""

    @classmethod
    def dist(cls, *args, core_dims: str | Sequence[str] | None = None, **kwargs):
        """Validate that ``core_dims`` was provided, then defer to the parent ``dist``.

        Parameters
        ----------
        *args, **kwargs
            Forwarded unchanged to the parent ``dist``.
        core_dims : str or sequence of str
            Required; forwarded to the parent ``dist``.

        Raises
        ------
        ValueError
            If ``core_dims`` is ``None``.
        """
        # Add a helpful error message if core_dims is not provided
        if core_dims is None:
            raise ValueError(
                f"{cls.__name__} requires core_dims to be specified, as it involves non-scalar inputs or outputs. "
                "Check the documentation of the distribution for details."
            )
        return super().dist(*args, core_dims=core_dims, **kwargs)
180+
181+
182+
class PositiveDimDistribution(DimDistribution):
    """Base class for positive continuous distributions.

    Sets ``transforms.log`` as the class-level default transform.
    """

    # Used by DimDistribution.__new__ for unobserved variables when the caller
    # leaves `default_transform` as UNSET.
    default_transform = transforms.log
186+
187+
188+
class UnitDimDistribution(DimDistribution):
    """Base class for unit-valued distributions.

    Sets ``transforms.logodds`` as the class-level default transform.
    """

    # Used by DimDistribution.__new__ for unobserved variables when the caller
    # leaves `default_transform` as UNSET.
    default_transform = transforms.logodds

0 commit comments

Comments
 (0)