Skip to content

Commit e1fc12a

Browse files
committed
Bump PyTensor dependency
1 parent 227da2c commit e1fc12a

File tree

10 files changed

+355
-8
lines changed

10 files changed

+355
-8
lines changed

conda-envs/environment-alternative-backends.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ dependencies:
2222
- numpyro>=0.8.0
2323
- pandas>=0.24.0
2424
- pip
25-
- pytensor>=2.31.2,<2.32
25+
- pytensor>=2.31.7,<2.32
2626
- python-graphviz
2727
- networkx
2828
- rich>=13.7.1

conda-envs/environment-dev.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ dependencies:
1212
- numpy>=1.25.0
1313
- pandas>=0.24.0
1414
- pip
15-
- pytensor>=2.31.2,<2.32
15+
- pytensor>=2.31.7,<2.32
1616
- python-graphviz
1717
- networkx
1818
- scipy>=1.4.1

conda-envs/environment-docs.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ dependencies:
1111
- numpy>=1.25.0
1212
- pandas>=0.24.0
1313
- pip
14-
- pytensor>=2.31.2,<2.32
14+
- pytensor>=2.31.7,<2.32
1515
- python-graphviz
1616
- rich>=13.7.1
1717
- scipy>=1.4.1

conda-envs/environment-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ dependencies:
1414
- pandas>=0.24.0
1515
- pip
1616
- polyagamma
17-
- pytensor>=2.31.2,<2.32
17+
- pytensor>=2.31.7,<2.32
1818
- python-graphviz
1919
- networkx
2020
- rich>=13.7.1

conda-envs/windows-environment-dev.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ dependencies:
1212
- numpy>=1.25.0
1313
- pandas>=0.24.0
1414
- pip
15-
- pytensor>=2.31.2,<2.32
15+
- pytensor>=2.31.7,<2.32
1616
- python-graphviz
1717
- networkx
1818
- rich>=13.7.1

conda-envs/windows-environment-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ dependencies:
1515
- pandas>=0.24.0
1616
- pip
1717
- polyagamma
18-
- pytensor>=2.31.2,<2.32
18+
- pytensor>=2.31.7,<2.32
1919
- python-graphviz
2020
- networkx
2121
- rich>=13.7.1

pymc/dims/distributions/scalar.py

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
# Copyright 2025 - present The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import pytensor.xtensor as ptx
15+
import pytensor.xtensor.random as pxr
16+
17+
from pytensor.xtensor import as_xtensor
18+
19+
from pymc.dims.distributions.core import (
20+
DimDistribution,
21+
PositiveDimDistribution,
22+
UnitDimDistribution,
23+
)
24+
from pymc.distributions.continuous import Beta as RegularBeta
25+
from pymc.distributions.continuous import Gamma as RegularGamma
26+
from pymc.distributions.continuous import HalfStudentTRV, flat, halfflat
27+
28+
29+
def _get_sigma_from_either_sigma_or_tau(*, sigma, tau):
30+
if sigma is not None and tau is not None:
31+
raise ValueError("Can't pass both tau and sigma")
32+
33+
if sigma is None and tau is None:
34+
return 1.0
35+
36+
if sigma is not None:
37+
return sigma
38+
39+
return ptx.math.reciprocal(ptx.math.sqrt(tau))
40+
41+
42+
class Flat(DimDistribution):
43+
xrv_op = pxr._as_xrv(flat)
44+
45+
@classmethod
46+
def dist(cls, **kwargs):
47+
return super().dist([], **kwargs)
48+
49+
50+
class HalfFlat(PositiveDimDistribution):
51+
xrv_op = pxr._as_xrv(halfflat, [], ())
52+
53+
@classmethod
54+
def dist(cls, **kwargs):
55+
return super().dist([], **kwargs)
56+
57+
58+
class Normal(DimDistribution):
59+
xrv_op = pxr.normal
60+
61+
@classmethod
62+
def dist(cls, mu=0, sigma=None, *, tau=None, **kwargs):
63+
sigma = _get_sigma_from_either_sigma_or_tau(sigma=sigma, tau=tau)
64+
return super().dist([mu, sigma], **kwargs)
65+
66+
67+
class HalfNormal(PositiveDimDistribution):
68+
xrv_op = pxr.halfnormal
69+
70+
@classmethod
71+
def dist(cls, sigma=None, *, tau=None, **kwargs):
72+
sigma = _get_sigma_from_either_sigma_or_tau(sigma=sigma, tau=tau)
73+
return super().dist([0.0, sigma], **kwargs)
74+
75+
76+
class LogNormal(PositiveDimDistribution):
77+
xrv_op = pxr.lognormal
78+
79+
@classmethod
80+
def dist(cls, mu=0, sigma=None, *, tau=None, **kwargs):
81+
sigma = _get_sigma_from_either_sigma_or_tau(sigma=sigma, tau=tau)
82+
return super().dist([mu, sigma], **kwargs)
83+
84+
85+
class StudentT(DimDistribution):
86+
xrv_op = pxr.t
87+
88+
@classmethod
89+
def dist(cls, nu, mu=0, sigma=None, *, lam=None, **kwargs):
90+
sigma = _get_sigma_from_either_sigma_or_tau(sigma=sigma, tau=lam)
91+
return super().dist([nu, mu, sigma], **kwargs)
92+
93+
94+
class HalfStudentT(PositiveDimDistribution):
95+
@classmethod
96+
def dist(cls, nu, sigma=None, *, lam=None, **kwargs):
97+
sigma = _get_sigma_from_either_sigma_or_tau(sigma=sigma, tau=lam)
98+
return super().dist([nu, sigma], **kwargs)
99+
100+
@classmethod
101+
def xrv_op(self, nu, sigma, core_dims=None, extra_dims=None, rng=None):
102+
nu = as_xtensor(nu)
103+
sigma = as_xtensor(sigma)
104+
core_rv = HalfStudentTRV.rv_op(nu=nu.values, sigma=sigma.values).owner.op
105+
xop = pxr._as_xrv(core_rv)
106+
return xop(nu, sigma, core_dims=core_dims, extra_dims=extra_dims, rng=rng)
107+
108+
109+
class Cauchy(DimDistribution):
110+
xrv_op = pxr.cauchy
111+
112+
@classmethod
113+
def dist(cls, alpha, beta, **kwargs):
114+
return super().dist([alpha, beta], **kwargs)
115+
116+
117+
class HalfCauchy(PositiveDimDistribution):
118+
xrv_op = pxr.halfcauchy
119+
120+
@classmethod
121+
def dist(cls, beta, **kwargs):
122+
return super().dist([0.0, beta], **kwargs)
123+
124+
125+
class Beta(UnitDimDistribution):
126+
xrv_op = pxr.beta
127+
128+
@classmethod
129+
def dist(cls, alpha=None, beta=None, *, mu=None, sigma=None, nu=None, **kwargs):
130+
alpha, beta = RegularBeta.get_alpha_beta(alpha=alpha, beta=beta, mu=mu, sigma=sigma, nu=nu)
131+
return super().dist([alpha, beta], **kwargs)
132+
133+
134+
class Laplace(DimDistribution):
135+
xrv_op = pxr.laplace
136+
137+
@classmethod
138+
def dist(cls, mu=0, b=1, **kwargs):
139+
return super().dist([mu, b], **kwargs)
140+
141+
142+
class Exponential(PositiveDimDistribution):
143+
xrv_op = pxr.exponential
144+
145+
@classmethod
146+
def dist(cls, lam=None, *, scale=None, **kwargs):
147+
if lam is None and scale is None:
148+
scale = 1.0
149+
elif lam is not None and scale is not None:
150+
raise ValueError("Cannot pass both 'lam' and 'scale'. Use one of them.")
151+
elif lam is not None:
152+
scale = 1 / lam
153+
return super().dist([scale], **kwargs)
154+
155+
156+
class Gamma(PositiveDimDistribution):
157+
xrv_op = pxr.gamma
158+
159+
@classmethod
160+
def dist(cls, alpha=None, beta=None, *, mu=None, sigma=None, **kwargs):
161+
if (alpha is not None) and (beta is not None):
162+
pass
163+
elif (mu is not None) and (sigma is not None):
164+
# Use sign of sigma to not let negative sigma fly by
165+
alpha = (mu**2 / sigma**2) * ptx.math.sign(sigma)
166+
beta = mu / sigma**2
167+
else:
168+
raise ValueError(
169+
"Incompatible parameterization. Either use alpha and beta, or mu and sigma."
170+
)
171+
alpha, beta = RegularGamma.get_alpha_beta(alpha=alpha, beta=beta, mu=mu, sigma=sigma)
172+
return super().dist([alpha, ptx.math.reciprocal(beta)], **kwargs)
173+
174+
175+
class InverseGamma(PositiveDimDistribution):
176+
xrv_op = pxr.invgamma
177+
178+
@classmethod
179+
def dist(cls, alpha=None, beta=None, *, mu=None, sigma=None, **kwargs):
180+
if alpha is not None:
181+
if beta is None:
182+
beta = 1.0
183+
elif (mu is not None) and (sigma is not None):
184+
# Use sign of sigma to not let negative sigma fly by
185+
alpha = ((2 * sigma**2 + mu**2) / sigma**2) * ptx.math.sign(sigma)
186+
beta = mu * (mu**2 + sigma**2) / sigma**2
187+
else:
188+
raise ValueError(
189+
"Incompatible parameterization. Either use alpha and (optionally) beta, or mu and sigma"
190+
)
191+
return super().dist([alpha, beta], **kwargs)

0 commit comments

Comments
 (0)