
Commit 04bc0ac

.WIP
1 parent 244931e commit 04bc0ac

File tree: 7 files changed (+768, -156 lines)

docs/source/learn/core_notebooks/dims_module.ipynb

Lines changed: 570 additions & 0 deletions
Large diffs are not rendered by default.
New file (path not rendered in this view; its contents suggest pymc/dims/distributions/__init__.py)

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+# Copyright 2025 - present The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from pymc.dims.distributions.scalar import *
+from pymc.dims.distributions.vector import *
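As a quick orientation, a minimal sketch of what this file buys (hypothetical user code; it assumes the star imports above re-export the public distribution classes without an __all__ restriction):

# Both families become importable from the single pymc.dims.distributions namespace.
from pymc.dims.distributions import Normal, HalfNormal      # defined in scalar.py below
from pymc.dims.distributions import Categorical, MvNormal   # defined in vector.py below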

pymc/dims/distribution_core.py renamed to pymc/dims/distributions/core.py

Lines changed: 2 additions & 2 deletions
@@ -257,13 +257,13 @@ def dist(
         return cls.xrv_op(*dist_params, extra_dims=extra_dims, core_dims=core_dims, **kwargs)


-class MultivariateDimDistribution(DimDistribution):
+class VectorDimDistribution(DimDistribution):
     @classmethod
     def dist(self, *args, core_dims: str | Sequence[str] | None = None, **kwargs):
         # Add a helpful error message if core_dims is not provided
         if core_dims is None:
             raise ValueError(
-                f"{self.__name__} requires core_dims to be specified, as it is a multivariate distribution."
+                f"{self.__name__} requires core_dims to be specified, as it involves non-scalar inputs or outputs."
                 "Check the documentation of the distribution for details."
             )
         return super().dist(*args, core_dims=core_dims, **kwargs)
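A small sketch of the behaviour this guard enforces (illustrative only; it uses the Categorical subclass added later in this commit, and a fully valid call would additionally need inputs with named dims):

from pymc.dims.distributions.vector import Categorical

try:
    # core_dims is omitted, so the guard in VectorDimDistribution.dist raises.
    Categorical.dist(p=[0.1, 0.9])
except ValueError as err:
    print(err)  # "Categorical requires core_dims to be specified, ..."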

pymc/dims/distributions/scalar.py

Lines changed: 164 additions & 0 deletions
@@ -0,0 +1,164 @@
+# Copyright 2025 - present The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pytensor.xtensor as ptx
+import pytensor.xtensor.random as pxr
+
+from pymc.dims.distributions.core import (
+    DimDistribution,
+    PositiveDimDistribution,
+    UnitDimDistribution,
+)
+from pymc.distributions.continuous import Beta as RegularBeta
+from pymc.distributions.continuous import Gamma as RegularGamma
+from pymc.distributions.continuous import HalfStudentTRV, flat, halfflat
+from pymc.distributions.continuous import InverseGamma as RegularInverseGamma
+
+
+def _get_sigma_from_either_sigma_or_tau(*, sigma, tau):
+    if sigma is not None and tau is not None:
+        raise ValueError("Can't pass both tau and sigma")
+
+    if sigma is None and tau is None:
+        return 1.0
+
+    if sigma is not None:
+        return sigma
+
+    # Only tau was given; tau is the precision (1 / sigma**2), so sigma = tau**-0.5
+    return ptx.math.reciprocal(ptx.math.sqrt(tau))
+
+
+class Flat(DimDistribution):
+    xrv_op = pxr._as_xrv(flat)
+
+    @classmethod
+    def dist(cls, **kwargs):
+        return super().dist([], **kwargs)
+
+
+class HalfFlat(PositiveDimDistribution):
+    xrv_op = pxr._as_xrv(halfflat, [], ())
+
+    @classmethod
+    def dist(cls, **kwargs):
+        return super().dist([], **kwargs)
+
+
+class Normal(DimDistribution):
+    xrv_op = pxr.normal
+
+    @classmethod
+    def dist(cls, mu=0, sigma=None, *, tau=None, **kwargs):
+        sigma = _get_sigma_from_either_sigma_or_tau(sigma=sigma, tau=tau)
+        return super().dist([mu, sigma], **kwargs)
+
+
+class HalfNormal(PositiveDimDistribution):
+    xrv_op = pxr.halfnormal
+
+    @classmethod
+    def dist(cls, sigma=None, *, tau=None, **kwargs):
+        sigma = _get_sigma_from_either_sigma_or_tau(sigma=sigma, tau=tau)
+        return super().dist([0.0, sigma], **kwargs)
+
+
+class LogNormal(PositiveDimDistribution):
+    xrv_op = pxr.lognormal
+
+    @classmethod
+    def dist(cls, mu=0, sigma=None, *, tau=None, **kwargs):
+        sigma = _get_sigma_from_either_sigma_or_tau(sigma=sigma, tau=tau)
+        return super().dist([mu, sigma], **kwargs)
+
+
+class StudentT(DimDistribution):
+    xrv_op = pxr.t
+
+    @classmethod
+    def dist(cls, nu, mu=0, sigma=None, *, lam=None, **kwargs):
+        sigma = _get_sigma_from_either_sigma_or_tau(sigma=sigma, tau=lam)
+        return super().dist([nu, mu, sigma], **kwargs)
+
+
+class HalfStudentT(PositiveDimDistribution):
+    xrv_op = pxr._as_xrv(HalfStudentTRV.rv_op, [(), ()], ())
+
+    @classmethod
+    def dist(cls, nu, sigma=None, *, lam=None, **kwargs):
+        sigma = _get_sigma_from_either_sigma_or_tau(sigma=sigma, tau=lam)
+        return super().dist([nu, sigma], **kwargs)
+
+
+class Cauchy(DimDistribution):
+    xrv_op = pxr.cauchy
+
+    @classmethod
+    def dist(cls, alpha, beta, **kwargs):
+        return super().dist([alpha, beta], **kwargs)
+
+
+class HalfCauchy(PositiveDimDistribution):
+    xrv_op = pxr.halfcauchy
+
+    @classmethod
+    def dist(cls, beta, **kwargs):
+        return super().dist([0.0, beta], **kwargs)
+
+
+class Beta(UnitDimDistribution):
+    xrv_op = pxr.beta
+
+    @classmethod
+    def dist(cls, alpha=None, beta=None, *, mu=None, sigma=None, nu=None, **kwargs):
+        alpha, beta = RegularBeta.get_alpha_beta(alpha=alpha, beta=beta, mu=mu, sigma=sigma, nu=nu)
+        return super().dist([alpha, beta], **kwargs)
+
+
+class Laplace(DimDistribution):
+    xrv_op = pxr.laplace
+
+    @classmethod
+    def dist(cls, mu=0, b=1, **kwargs):
+        return super().dist([mu, b], **kwargs)
+
+
+class Exponential(PositiveDimDistribution):
+    xrv_op = pxr.exponential
+
+    @classmethod
+    def dist(cls, lam=None, *, scale=None, **kwargs):
+        if lam is None and scale is None:
+            scale = 1.0
+        elif lam is not None and scale is not None:
+            raise ValueError("Cannot pass both 'lam' and 'scale'. Use one of them.")
+        elif lam is not None:
+            scale = 1 / lam
+        return super().dist([scale], **kwargs)
+
+
+class Gamma(PositiveDimDistribution):
+    xrv_op = pxr.gamma
+
+    @classmethod
+    def dist(cls, alpha=None, beta=None, *, mu=None, sigma=None, **kwargs):
+        alpha, beta = RegularGamma.get_alpha_beta(alpha=alpha, beta=beta, mu=mu, sigma=sigma)
+        return super().dist([alpha, 1 / beta], **kwargs)
+
+
+class InverseGamma(PositiveDimDistribution):
+    xrv_op = pxr.invgamma
+
+    @classmethod
+    def dist(cls, alpha=None, beta=None, *, mu=None, sigma=None, **kwargs):
+        alpha, beta = RegularInverseGamma.get_alpha_beta(alpha=alpha, beta=beta, mu=mu, sigma=sigma)
+        return super().dist([alpha, beta], **kwargs)
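A quick NumPy sanity check of the sigma/tau relationship the helper above converts between: tau is the precision, tau = 1/sigma**2, so recovering sigma from tau takes a reciprocal square root (independent of PyMC):

import numpy as np

sigma = 2.5
tau = 1 / sigma**2                          # precision parametrization
assert np.isclose(1 / np.sqrt(tau), sigma)  # reciprocal(sqrt(tau)) recovers sigma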

pymc/dims/distributions.py renamed to pymc/dims/distributions/vector.py

Lines changed: 16 additions & 147 deletions
@@ -12,167 +12,36 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import pytensor.xtensor as ptx
-import pytensor.xtensor.random as pxr
+import pytensor.xtensor.random as ptxr

 from pytensor.tensor import as_tensor
 from pytensor.tensor.random.utils import normalize_size_param
 from pytensor.xtensor import as_xtensor
+from pytensor.xtensor import random as pxr

-from pymc.dims.distribution_core import (
-    DimDistribution,
-    MultivariateDimDistribution,
-    PositiveDimDistribution,
-    UnitDimDistribution,
-)
+from pymc.dims.distributions.core import VectorDimDistribution
 from pymc.dims.transforms import ZeroSumTransform
-from pymc.distributions.continuous import Beta as RegularBeta
-from pymc.distributions.continuous import Gamma as RegularGamma
-from pymc.distributions.continuous import HalfStudentTRV, flat, halfflat
-from pymc.distributions.continuous import InverseGamma as RegularInverseGamma
 from pymc.distributions.multivariate import ZeroSumNormalRV
 from pymc.util import UNSET


[… deleted lines omitted: the _get_sigma_from_either_sigma_or_tau helper and the Flat, HalfFlat, Normal, HalfNormal, LogNormal, StudentT, HalfStudentT, Cauchy, HalfCauchy, Beta, Laplace, Exponential, Gamma, and InverseGamma classes, identical to the code added in pymc/dims/distributions/scalar.py above …]

+# FIXME: Find a better name for this class, which is based on needed core_dims for inputs/outputs
+class Categorical(VectorDimDistribution):
+    xrv_op = ptxr.categorical

     @classmethod
+    def dist(cls, p=None, *, logit_p=None, core_dims=None, **kwargs):
+        if p is not None and logit_p is not None:
+            raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.")
+        elif p is None and logit_p is None:
+            raise ValueError("Incompatible parametrization. Must specify either p or logit_p.")

+        if logit_p is not None:
+            p = ptx.math.softmax(logit_p, dims=core_dims)
+        return super().dist([p], core_dims=core_dims, **kwargs)


-class MvNormal(MultivariateDimDistribution):
+class MvNormal(VectorDimDistribution):
     """Multivariate Normal distribution.

     Parameters
@@ -238,7 +107,7 @@ def make_node(self, rng, size, sigma, support_shape):
         return super().make_node(rng, size, sigma, support_shape)


-class ZeroSumNormal(MultivariateDimDistribution):
+class ZeroSumNormal(VectorDimDistribution):
     @classmethod
     def __new__(
         cls, *args, core_dims=None, dims=None, default_transform=UNSET, observed=None, **kwargs
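The logit_p branch of the new Categorical leans on softmax to turn unnormalised log-odds into probabilities along the core dimension; a standalone NumPy sketch of that transformation (illustrative, not the xtensor call itself):

import numpy as np

logit_p = np.array([0.0, 1.0, -1.0])    # unnormalised log-odds for 3 categories
p = np.exp(logit_p - logit_p.max())     # subtract the max for numerical stability
p /= p.sum()                            # normalise along the category axis
assert np.isclose(p.sum(), 1.0)         # valid Categorical probabilities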

pymc/dims/model.py

Lines changed: 0 additions & 2 deletions
@@ -55,8 +55,6 @@ def _register_and_return_xtensor_variable(
         value = value.transpose(*dims)
     # Regardless of whether dims are provided, we now have them
     dims = value.type.dims
-    # Register the equivalent TensorVariable with the model so it doesn't see XTensorVariables directly.
-    value = value.values

     value = registration_func(name, value, dims=dims, model=model)

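For context, a schematic before/after of the registration step touched here (hypothetical helper names; only value.values and value.type.dims are taken from the surrounding code). The removed lines unwrapped the XTensorVariable to its underlying TensorVariable before registration; after this commit the XTensorVariable is passed to the registration function directly:

def register_before(value, registration_func, name, model):
    dims = value.type.dims   # dims carried on the xtensor type
    value = value.values     # unwrap to the plain TensorVariable (removed line)
    return registration_func(name, value, dims=dims, model=model)

def register_after(value, registration_func, name, model):
    dims = value.type.dims
    # the XTensorVariable itself is registered; no unwrapping
    return registration_func(name, value, dims=dims, model=model)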
