Skip to content

Commit 4c93338

Browse files
committed
Model with dims
1 parent 0f1bfa9 commit 4c93338

File tree

17 files changed

+738
-219
lines changed

17 files changed

+738
-219
lines changed

pymc/data.py

Lines changed: 15 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@
1313
# limitations under the License.
1414

1515
import io
16+
import typing
1617
import urllib.request
1718
import warnings
1819

1920
from collections.abc import Sequence
2021
from copy import copy
21-
from typing import cast
22+
from typing import Union, cast
2223

2324
import numpy as np
2425
import pandas as pd
@@ -33,17 +34,16 @@
3334
from pytensor.tensor.random.basic import IntegersRV
3435
from pytensor.tensor.variable import TensorConstant, TensorVariable
3536

36-
import pymc as pm
37-
38-
from pymc.logprob.utils import rvs_in_graph
39-
from pymc.pytensorf import convert_data
37+
from pymc.exceptions import ShapeError
38+
from pymc.pytensorf import convert_data, rvs_in_graph
4039
from pymc.vartypes import isgenerator
4140

41+
if typing.TYPE_CHECKING:
42+
from pymc.model.core import Model
43+
4244
__all__ = [
43-
"ConstantData",
4445
"Data",
4546
"Minibatch",
46-
"MutableData",
4747
"get_data",
4848
]
4949
BASE_URL = "https://raw.githubusercontent.com/pymc-devs/pymc-examples/main/examples/data/{filename}"
@@ -200,7 +200,7 @@ def determine_coords(
200200

201201
if isinstance(value, np.ndarray) and dims is not None:
202202
if len(dims) != value.ndim:
203-
raise pm.exceptions.ShapeError(
203+
raise ShapeError(
204204
"Invalid data shape. The rank of the dataset must match the length of `dims`.",
205205
actual=value.shape,
206206
expected=value.ndim,
@@ -218,66 +218,6 @@ def determine_coords(
218218
return coords, new_dims
219219

220220

221-
def ConstantData(
222-
name: str,
223-
value,
224-
*,
225-
dims: Sequence[str] | None = None,
226-
coords: dict[str, Sequence | np.ndarray] | None = None,
227-
infer_dims_and_coords=False,
228-
**kwargs,
229-
) -> TensorConstant:
230-
"""Alias for ``pm.Data``.
231-
232-
Registers the ``value`` as a :class:`~pytensor.tensor.TensorConstant` with the model.
233-
For more information, please reference :class:`pymc.Data`.
234-
"""
235-
warnings.warn(
236-
"ConstantData is deprecated. All Data variables are now mutable. Use Data instead.",
237-
FutureWarning,
238-
)
239-
240-
var = Data(
241-
name,
242-
value,
243-
dims=dims,
244-
coords=coords,
245-
infer_dims_and_coords=infer_dims_and_coords,
246-
**kwargs,
247-
)
248-
return cast(TensorConstant, var)
249-
250-
251-
def MutableData(
252-
name: str,
253-
value,
254-
*,
255-
dims: Sequence[str] | None = None,
256-
coords: dict[str, Sequence | np.ndarray] | None = None,
257-
infer_dims_and_coords=False,
258-
**kwargs,
259-
) -> SharedVariable:
260-
"""Alias for ``pm.Data``.
261-
262-
Registers the ``value`` as a :class:`~pytensor.compile.sharedvalue.SharedVariable`
263-
with the model. For more information, please reference :class:`pymc.Data`.
264-
"""
265-
warnings.warn(
266-
"MutableData is deprecated. All Data variables are now mutable. Use Data instead.",
267-
FutureWarning,
268-
)
269-
270-
var = Data(
271-
name,
272-
value,
273-
dims=dims,
274-
coords=coords,
275-
infer_dims_and_coords=infer_dims_and_coords,
276-
**kwargs,
277-
)
278-
return cast(SharedVariable, var)
279-
280-
281221
def Data(
282222
name: str,
283223
value,
@@ -286,6 +226,7 @@ def Data(
286226
coords: dict[str, Sequence | np.ndarray] | None = None,
287227
infer_dims_and_coords=False,
288228
mutable: bool | None = None,
229+
model: Union["Model", None] = None,
289230
**kwargs,
290231
) -> SharedVariable | TensorConstant:
291232
"""Create a data container that registers a data variable with the model.
@@ -350,15 +291,18 @@ def Data(
350291
... model.set_data("data", data_vals)
351292
... idatas.append(pm.sample())
352293
"""
294+
from pymc.model.core import modelcontext
295+
353296
if coords is None:
354297
coords = {}
355298

356299
if isinstance(value, list):
357300
value = np.array(value)
358301

359302
# Add data container to the named variables of the model.
360-
model = pm.Model.get_context(error_if_none=False)
361-
if model is None:
303+
try:
304+
model = modelcontext(model)
305+
except TypeError:
362306
raise TypeError(
363307
"No model on context stack, which is needed to instantiate a data container. "
364308
"Add variable inside a 'with model:' block."
@@ -390,7 +334,7 @@ def Data(
390334
if isinstance(dims, str):
391335
dims = (dims,)
392336
if not (dims is None or len(dims) == x.ndim):
393-
raise pm.exceptions.ShapeError(
337+
raise ShapeError(
394338
"Length of `dims` must match the dimensions of the dataset.",
395339
actual=len(dims),
396340
expected=x.ndim,

pymc/dims/__init__.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright 2025 - present The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import datetime
15+
16+
from pymc.initial_point import INITIAL_POINT_REWRITES
17+
18+
19+
def __init__():
20+
import warnings
21+
22+
from pytensor.compile import optdb
23+
24+
from pymc.logprob.rewriting import logprob_rewrites_db
25+
from pymc.pytensorf import _XRV
26+
27+
# Import xtensor so that lower_xtensor is registered in the pytensor database
28+
# Filter PyTensor level warning, we emmit our own warning
29+
with warnings.catch_warnings():
30+
warnings.simplefilter("ignore", UserWarning)
31+
import pytensor.xtensor
32+
33+
from pytensor.xtensor.vectorization import XRV
34+
35+
# Register lower_xtensor rewrites in logprob_rewrites_db
36+
logprob_rewrites_db.register("pre_lower_xtensor", optdb.query("+lower_xtensor"), "basic")
37+
_XRV.append(XRV)
38+
INITIAL_POINT_REWRITES.append("lower_xtensor")
39+
40+
# TODO: Better model of probability of bugs
41+
day_of_conception = datetime.date(2025, 6, 17)
42+
day_of_last_bug = datetime.date(2025, 6, 17)
43+
today = datetime.date.today()
44+
days_with_bugs = (day_of_last_bug - day_of_conception).days
45+
days_without_bugs = (today - day_of_last_bug).days
46+
p = 1 - (days_without_bugs / (days_without_bugs + days_with_bugs + 10))
47+
if p > 0.05:
48+
warnings.warn(
49+
f"The `pymc.dims` module is experimental and may contain critical bugs (p={p:.3f}).\n"
50+
"Please report any issues you encounter at https://github.com/pymc-devs/pymc/issues",
51+
UserWarning,
52+
stacklevel=2,
53+
)
54+
55+
56+
__init__()
57+
del __init__
58+
59+
from pymc.dims import math
60+
from pymc.dims.distributions import *
61+
from pymc.dims.model import Data

pymc/dims/distribution_core.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
# Copyright 2025 - present The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from collections.abc import Callable
15+
from itertools import chain
16+
17+
from pytensor.xtensor import as_xtensor
18+
from pytensor.xtensor.type import XTensorVariable
19+
20+
from pymc import modelcontext
21+
from pymc.distributions import transforms
22+
from pymc.distributions.shape_utils import Dims, convert_dims
23+
from pymc.util import UNSET
24+
25+
26+
class DimDistribution:
27+
"""Base class for PyMC distribution that wrap pytensor.xtensor.random operations, and follow xarray-like semantics."""
28+
29+
xrv_op: Callable
30+
default_transform: Callable | None
31+
32+
def __new__(
33+
cls,
34+
name: str,
35+
*dist_params,
36+
dims: Dims | None = None,
37+
initval=None,
38+
observed=None,
39+
total_size=None,
40+
transform=UNSET,
41+
default_transform=UNSET,
42+
model=None,
43+
**kwargs,
44+
) -> XTensorVariable:
45+
try:
46+
model = modelcontext(model)
47+
except TypeError:
48+
raise TypeError(
49+
"No model on context stack, which is needed to instantiate distributions. "
50+
"Add variable inside a 'with model:' block, or use the '.dist' syntax for a standalone distribution."
51+
)
52+
53+
if not isinstance(name, str):
54+
raise TypeError(f"Name needs to be a string but got: {name}")
55+
56+
if dims is None:
57+
dims_dict = {}
58+
else:
59+
dims = convert_dims(dims)
60+
try:
61+
dims_dict = {
62+
dim: model.dim_lengths[dim] for dim in dims if dim in model.dim_lengths
63+
}
64+
except KeyError:
65+
raise ValueError(
66+
f"Not all dims {dims} are part of the model coords. "
67+
f"Add them at initialization time or use `model.add_coord` before defining the distribution."
68+
)
69+
70+
if observed is not None:
71+
try:
72+
observed = as_xtensor(observed)
73+
except TypeError as e:
74+
raise TypeError(
75+
"Observed value must have dims associated with it.\n"
76+
"To avoid subtle bugs, PyMC does not make any assumptions about the dims of the observed value. "
77+
"Convert observed value to an xarray.DataArray, pymc.dims.Data or pytensor.xtensor.as_xtensor with explicit dims."
78+
) from e
79+
80+
# Propagate observed dims to dims_dict
81+
for observed_dim in observed.type.dims:
82+
if observed_dim not in dims_dict:
83+
dims_dict[observed_dim] = model.dim_lengths[observed_dim]
84+
85+
rv = cls.dist(*dist_params, dims_dict=dims_dict, **kwargs)
86+
87+
# User provided dims must specify all dims or use ellipsis
88+
if dims is not None:
89+
if (... not in dims) and (set(dims) != set(rv.type.dims)):
90+
raise ValueError(
91+
f"Provided dims {dims} do not match the distribution's output dims {rv.type.dims}. "
92+
"Use ellipsis to specify all other dimensions."
93+
)
94+
# Use provided dims to transpose the output to the desired order
95+
rv = rv.transpose(*dims)
96+
97+
rv_dims = rv.type.dims
98+
if observed is None:
99+
if default_transform is UNSET:
100+
default_transform = cls.default_transform
101+
else:
102+
# Align observed dims with those of the RV
103+
observed = observed.transpose(*rv_dims).values
104+
105+
rv = model.register_rv(
106+
rv.values,
107+
name=name,
108+
observed=observed,
109+
total_size=total_size,
110+
dims=rv_dims,
111+
transform=transform,
112+
default_transform=default_transform,
113+
initval=initval,
114+
)
115+
116+
xrv = as_xtensor(rv, dims=rv_dims)
117+
return xrv
118+
119+
@classmethod
120+
def dist(
121+
cls,
122+
dist_params,
123+
*,
124+
dims_dict: dict[str, int] | None = None,
125+
**kwargs,
126+
) -> XTensorVariable:
127+
for invalid_kwarg in ("size", "shape", "dims"):
128+
if invalid_kwarg in kwargs:
129+
raise TypeError(f"DimDistribution does not accept {invalid_kwarg} argument.")
130+
131+
# XRV requires only extra_dims, not dims
132+
try:
133+
dist_params = [as_xtensor(param) for param in dist_params]
134+
except TypeError as e:
135+
raise TypeError(
136+
"Distribution parameters must have dims associated with it.\n"
137+
"To avoid subtle bugs, PyMC does not make any assumptions about the dims of the parameters. "
138+
"Convert parameters to an xarray.DataArray, pymc.dims.Data or pytensor.xtensor.as_xtensor with explicit dims."
139+
)
140+
141+
if dims_dict is None:
142+
extra_dims = None
143+
else:
144+
parameter_implied_dims = set(
145+
chain.from_iterable(param.type.dims for param in dist_params)
146+
)
147+
extra_dims = {
148+
dim: length
149+
for dim, length in dims_dict.items()
150+
if dim not in parameter_implied_dims
151+
}
152+
return cls.xrv_op(*dist_params, extra_dims=extra_dims, **kwargs)
153+
154+
155+
class ContinuousDimDistribution(DimDistribution):
156+
"""Base class for real-valued distributions."""
157+
158+
default_transform = None
159+
160+
161+
class PositiveContinuousDimDistribution(DimDistribution):
162+
"""Base class for positive continuous distributions."""
163+
164+
default_transform = transforms.log
165+
166+
167+
class UnitDimDistribution(DimDistribution):
168+
"""Base class for unit-valued distributions."""
169+
170+
default_transform = transforms.logodds

0 commit comments

Comments
 (0)