Skip to content

Commit 69b64cc

Browse files
committed
Implement xarray like semantics in dims module
1 parent 9253343 commit 69b64cc

37 files changed

+1630
-140
lines changed

.github/workflows/tests.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,12 @@ jobs:
135135
tests/logprob/test_transforms.py
136136
tests/logprob/test_utils.py
137137
138+
- |
139+
tests/dims/distributions/test_core.py
140+
tests/dims/distributions/test_scalar.py
141+
tests/dims/distributions/test_vector.py
142+
tests/dims/test_model.py
143+
138144
fail-fast: false
139145
runs-on: ${{ matrix.os }}
140146
env:

conda-envs/environment-alternative-backends.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ dependencies:
2222
- numpyro>=0.8.0
2323
- pandas>=0.24.0
2424
- pip
25-
- pytensor>=2.31.2,<2.32
25+
- pytensor>=2.31.5,<2.32
2626
- python-graphviz
2727
- networkx
2828
- rich>=13.7.1

conda-envs/environment-dev.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ dependencies:
1212
- numpy>=1.25.0
1313
- pandas>=0.24.0
1414
- pip
15-
- pytensor>=2.31.2,<2.32
15+
- pytensor>=2.31.5,<2.32
1616
- python-graphviz
1717
- networkx
1818
- scipy>=1.4.1

conda-envs/environment-docs.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ dependencies:
1111
- numpy>=1.25.0
1212
- pandas>=0.24.0
1313
- pip
14-
- pytensor>=2.31.2,<2.32
14+
- pytensor>=2.31.5,<2.32
1515
- python-graphviz
1616
- rich>=13.7.1
1717
- scipy>=1.4.1

conda-envs/environment-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ dependencies:
1414
- pandas>=0.24.0
1515
- pip
1616
- polyagamma
17-
- pytensor>=2.31.2,<2.32
17+
- pytensor>=2.31.5,<2.32
1818
- python-graphviz
1919
- networkx
2020
- rich>=13.7.1

conda-envs/windows-environment-dev.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ dependencies:
1212
- numpy>=1.25.0
1313
- pandas>=0.24.0
1414
- pip
15-
- pytensor>=2.31.2,<2.32
15+
- pytensor>=2.31.5,<2.32
1616
- python-graphviz
1717
- networkx
1818
- rich>=13.7.1

conda-envs/windows-environment-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ dependencies:
1515
- pandas>=0.24.0
1616
- pip
1717
- polyagamma
18-
- pytensor>=2.31.2,<2.32
18+
- pytensor>=2.31.5,<2.32
1919
- python-graphviz
2020
- networkx
2121
- rich>=13.7.1

pymc/data.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@
1313
# limitations under the License.
1414

1515
import io
16+
import typing
1617
import urllib.request
1718

1819
from collections.abc import Sequence
1920
from copy import copy
20-
from typing import cast
21+
from typing import Union, cast
2122

2223
import numpy as np
2324
import pandas as pd
@@ -32,12 +33,13 @@
3233
from pytensor.tensor.random.basic import IntegersRV
3334
from pytensor.tensor.variable import TensorConstant, TensorVariable
3435

35-
import pymc as pm
36-
37-
from pymc.logprob.utils import rvs_in_graph
38-
from pymc.pytensorf import convert_data
36+
from pymc.exceptions import ShapeError
37+
from pymc.pytensorf import convert_data, rvs_in_graph
3938
from pymc.vartypes import isgenerator
4039

40+
if typing.TYPE_CHECKING:
41+
from pymc.model.core import Model
42+
4143
__all__ = [
4244
"Data",
4345
"Minibatch",
@@ -197,7 +199,7 @@ def determine_coords(
197199

198200
if isinstance(value, np.ndarray) and dims is not None:
199201
if len(dims) != value.ndim:
200-
raise pm.exceptions.ShapeError(
202+
raise ShapeError(
201203
"Invalid data shape. The rank of the dataset must match the length of `dims`.",
202204
actual=value.shape,
203205
expected=value.ndim,
@@ -222,6 +224,7 @@ def Data(
222224
dims: Sequence[str] | None = None,
223225
coords: dict[str, Sequence | np.ndarray] | None = None,
224226
infer_dims_and_coords=False,
227+
model: Union["Model", None] = None,
225228
**kwargs,
226229
) -> SharedVariable | TensorConstant:
227230
"""Create a data container that registers a data variable with the model.
@@ -286,15 +289,18 @@ def Data(
286289
... model.set_data("data", data_vals)
287290
... idatas.append(pm.sample())
288291
"""
292+
from pymc.model.core import modelcontext
293+
289294
if coords is None:
290295
coords = {}
291296

292297
if isinstance(value, list):
293298
value = np.array(value)
294299

295300
# Add data container to the named variables of the model.
296-
model = pm.Model.get_context(error_if_none=False)
297-
if model is None:
301+
try:
302+
model = modelcontext(model)
303+
except TypeError:
298304
raise TypeError(
299305
"No model on context stack, which is needed to instantiate a data container. "
300306
"Add variable inside a 'with model:' block."
@@ -321,7 +327,7 @@ def Data(
321327
if isinstance(dims, str):
322328
dims = (dims,)
323329
if not (dims is None or len(dims) == x.ndim):
324-
raise pm.exceptions.ShapeError(
330+
raise ShapeError(
325331
"Length of `dims` must match the dimensions of the dataset.",
326332
actual=len(dims),
327333
expected=x.ndim,

pymc/dims/__init__.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# Copyright 2025 - present The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
def __init__():
    """Make PyMC aware of the xtensor functionality.

    Registers PyTensor's ``XRV`` op as a ``MeasurableOp`` and registers the
    ``+lower_xtensor`` rewrite query in both the logprob and the initial-point
    rewrite databases. Afterwards a ``UserWarning`` is emitted to flag the
    module as experimental, as long as the heuristic bug probability ``p``
    stays above 0.05.

    This should be done eagerly once development matures.
    """
    import datetime
    import warnings

    from pytensor.compile import optdb

    from pymc.initial_point import initial_point_rewrites_db
    from pymc.logprob.abstract import MeasurableOp
    from pymc.logprob.rewriting import logprob_rewrites_db

    # Filter PyTensor xtensor warning, we emit our own warning
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", UserWarning)
        import pytensor.xtensor

        from pytensor.xtensor.vectorization import XRV

    # Make PyMC aware of xtensor functionality
    MeasurableOp.register(XRV)
    lower_xtensor_query = optdb.query("+lower_xtensor")
    logprob_rewrites_db.register("lower_xtensor", lower_xtensor_query, "basic", position=0.1)
    initial_point_rewrites_db.register("lower_xtensor", lower_xtensor_query, "basic", position=0.1)

    # TODO: Better model of probability of bugs
    day_of_conception = datetime.date(2025, 6, 17)
    day_of_last_bug = datetime.date(2025, 6, 30)
    today = datetime.date.today()
    days_with_bugs = (day_of_last_bug - day_of_conception).days
    # Clamp at zero: a system clock set before `day_of_last_bug` would otherwise
    # yield a negative count, giving p > 1 or (for one specific date) a zero
    # denominator and a ZeroDivisionError at import time.
    days_without_bugs = max((today - day_of_last_bug).days, 0)
    p = 1 - (days_without_bugs / (days_without_bugs + days_with_bugs + 10))
    if p > 0.05:
        warnings.warn(
            f"The `pymc.dims` module is experimental and may contain critical bugs (p={p:.3f}).\n"
            "Please report any issues you encounter at https://github.com/pymc-devs/pymc/issues.\n"
            "Disclaimer: This is an experimental API and may change at any time.",
            UserWarning,
            stacklevel=2,
        )
58+
59+
60+
__init__()
61+
del __init__
62+
63+
from pytensor.xtensor import as_xtensor, concat
64+
65+
from pymc.dims import math
66+
from pymc.dims.distributions import *
67+
from pymc.dims.model import Data, Deterministic, Potential, with_dims

pymc/dims/distributions/__init__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright 2025 - present The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from pymc.dims.distributions.scalar import *
15+
from pymc.dims.distributions.vector import *

0 commit comments

Comments
 (0)