|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | import pytensor.xtensor as ptx |
15 | | -import pytensor.xtensor.random as pxr |
| 15 | +import pytensor.xtensor.random as ptxr |
16 | 16 |
|
17 | 17 | from pytensor.tensor import as_tensor |
18 | 18 | from pytensor.tensor.random.utils import normalize_size_param |
19 | 19 | from pytensor.xtensor import as_xtensor |
| 20 | +from pytensor.xtensor import random as pxr |
20 | 21 |
|
21 | | -from pymc.dims.distribution_core import ( |
22 | | - DimDistribution, |
23 | | - MultivariateDimDistribution, |
24 | | - PositiveDimDistribution, |
25 | | - UnitDimDistribution, |
26 | | -) |
| 22 | +from pymc.dims.distributions.core import VectorDimDistribution |
27 | 23 | from pymc.dims.transforms import ZeroSumTransform |
28 | | -from pymc.distributions.continuous import Beta as RegularBeta |
29 | | -from pymc.distributions.continuous import Gamma as RegularGamma |
30 | | -from pymc.distributions.continuous import HalfStudentTRV, flat, halfflat |
31 | | -from pymc.distributions.continuous import InverseGamma as RegularInverseGamma |
32 | 24 | from pymc.distributions.multivariate import ZeroSumNormalRV |
33 | 25 | from pymc.util import UNSET |
34 | 26 |
|
35 | 27 |
|
36 | | -def _get_sigma_from_either_sigma_or_tau(*, sigma, tau): |
37 | | - if sigma is not None and tau is not None: |
38 | | - raise ValueError("Can't pass both tau and sigma") |
39 | | - |
40 | | - if sigma is None and tau is None: |
41 | | - return 1.0 |
42 | | - |
43 | | - if sigma is not None: |
44 | | - return sigma |
45 | | - |
46 | | - return ptx.math.reciprocal(ptx.math.square(sigma)) |
47 | | - |
48 | | - |
49 | | -class Flat(DimDistribution): |
50 | | - xrv_op = pxr._as_xrv(flat) |
51 | | - |
52 | | - @classmethod |
53 | | - def dist(cls, **kwargs): |
54 | | - return super().dist([], **kwargs) |
55 | | - |
56 | | - |
57 | | -class HalfFlat(PositiveDimDistribution): |
58 | | - xrv_op = pxr._as_xrv(halfflat, [], ()) |
59 | | - |
60 | | - @classmethod |
61 | | - def dist(cls, **kwargs): |
62 | | - return super().dist([], **kwargs) |
63 | | - |
64 | | - |
65 | | -class Normal(DimDistribution): |
66 | | - xrv_op = pxr.normal |
67 | | - |
68 | | - @classmethod |
69 | | - def dist(cls, mu=0, sigma=None, *, tau=None, **kwargs): |
70 | | - sigma = _get_sigma_from_either_sigma_or_tau(sigma=sigma, tau=tau) |
71 | | - return super().dist([mu, sigma], **kwargs) |
72 | | - |
73 | | - |
74 | | -class HalfNormal(PositiveDimDistribution): |
75 | | - xrv_op = pxr.halfnormal |
76 | | - |
77 | | - @classmethod |
78 | | - def dist(cls, sigma=None, *, tau=None, **kwargs): |
79 | | - sigma = _get_sigma_from_either_sigma_or_tau(sigma=sigma, tau=tau) |
80 | | - return super().dist([0.0, sigma], **kwargs) |
81 | | - |
82 | | - |
83 | | -class LogNormal(PositiveDimDistribution): |
84 | | - xrv_op = pxr.lognormal |
85 | | - |
86 | | - @classmethod |
87 | | - def dist(cls, mu=0, sigma=None, *, tau=None, **kwargs): |
88 | | - sigma = _get_sigma_from_either_sigma_or_tau(sigma=sigma, tau=tau) |
89 | | - return super().dist([mu, sigma], **kwargs) |
90 | | - |
91 | | - |
92 | | -class StudentT(DimDistribution): |
93 | | - xrv_op = pxr.t |
94 | | - |
95 | | - @classmethod |
96 | | - def dist(cls, nu, mu=0, sigma=None, *, lam=None, **kwargs): |
97 | | - sigma = _get_sigma_from_either_sigma_or_tau(sigma=sigma, tau=lam) |
98 | | - return super().dist([nu, mu, sigma], **kwargs) |
99 | | - |
100 | | - |
101 | | -class HalfStudentT(PositiveDimDistribution): |
102 | | - xrv_op = pxr._as_xrv(HalfStudentTRV.rv_op, [(), ()], ()) |
103 | | - |
104 | | - @classmethod |
105 | | - def dist(cls, nu, sigma=None, *, lam=None, **kwargs): |
106 | | - sigma = _get_sigma_from_either_sigma_or_tau(sigma=sigma, tau=lam) |
107 | | - return super().dist([nu, sigma], **kwargs) |
108 | | - |
109 | | - |
110 | | -class Cauchy(DimDistribution): |
111 | | - xrv_op = pxr.cauchy |
112 | | - |
113 | | - @classmethod |
114 | | - def dist(cls, alpha, beta, **kwargs): |
115 | | - return super().dist([alpha, beta], **kwargs) |
116 | | - |
117 | | - |
118 | | -class HalfCauchy(PositiveDimDistribution): |
119 | | - xrv_op = pxr.halfcauchy |
120 | | - |
121 | | - @classmethod |
122 | | - def dist(cls, beta, **kwargs): |
123 | | - return super().dist([0.0, beta], **kwargs) |
124 | | - |
125 | | - |
126 | | -class Beta(UnitDimDistribution): |
127 | | - xrv_op = pxr.beta |
128 | | - |
129 | | - @classmethod |
130 | | - def dist(cls, alpha=None, beta=None, *, mu=None, sigma=None, nu=None, **kwargs): |
131 | | - alpha, beta = RegularBeta.get_alpha_beta(alpha=alpha, beta=beta, mu=mu, sigma=sigma, nu=nu) |
132 | | - return super().dist([alpha, beta], **kwargs) |
133 | | - |
134 | | - |
135 | | -class Laplace(DimDistribution): |
136 | | - xrv_op = pxr.laplace |
137 | | - |
138 | | - @classmethod |
139 | | - def dist(cls, mu=0, b=1, **kwargs): |
140 | | - return super().dist([mu, b], **kwargs) |
141 | | - |
142 | | - |
143 | | -class Exponential(PositiveDimDistribution): |
144 | | - xrv_op = pxr.exponential |
| 28 | +# FIXME: Find a better name for this class, which is based on needed core_dims for inputs/outputs |
| 29 | +class Categorical(VectorDimDistribution): |
| 30 | + xrv_op = ptxr.categorical |
145 | 31 |
|
146 | 32 | @classmethod |
147 | | - def dist(cls, lam=None, *, scale=None, **kwargs): |
148 | | - if lam is None and scale is None: |
149 | | - scale = 1.0 |
150 | | - elif lam is not None and scale is not None: |
151 | | - raise ValueError("Cannot pass both 'lam' and 'scale'. Use one of them.") |
152 | | - elif lam is not None: |
153 | | - scale = 1 / lam |
154 | | - return super().dist([scale], **kwargs) |
| 33 | + def dist(cls, p=None, *, logit_p=None, core_dims=None, **kwargs): |
| 34 | + if p is not None and logit_p is not None: |
| 35 | + raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.") |
| 36 | + elif p is None and logit_p is None: |
| 37 | + raise ValueError("Incompatible parametrization. Must specify either p or logit_p.") |
155 | 38 |
|
156 | | - |
157 | | -class Gamma(PositiveDimDistribution): |
158 | | - xrv_op = pxr.gamma |
159 | | - |
160 | | - @classmethod |
161 | | - def dist(cls, alpha=None, beta=None, *, mu=None, sigma=None, **kwargs): |
162 | | - alpha, beta = RegularGamma.get_alpha_beta(alpha=alpha, beta=beta, mu=None, sigma=None) |
163 | | - return super().dist([alpha, 1 / beta], **kwargs) |
164 | | - |
165 | | - |
166 | | -class InverseGamma(PositiveDimDistribution): |
167 | | - xrv_op = pxr.invgamma |
168 | | - |
169 | | - @classmethod |
170 | | - def dist(cls, alpha=None, beta=None, *, mu=None, sigma=None, **kwargs): |
171 | | - alpha, beta = RegularInverseGamma.get_alpha_beta(alpha=alpha, beta=beta, mu=mu, sigma=sigma) |
172 | | - return super().dist([alpha, beta], **kwargs) |
| 39 | + if logit_p is not None: |
| 40 | + p = ptx.math.softmax(logit_p, dims=core_dims) |
| 41 | + return super().dist([p], core_dims=core_dims, **kwargs) |
173 | 42 |
|
174 | 43 |
|
175 | | -class MvNormal(MultivariateDimDistribution): |
| 44 | +class MvNormal(VectorDimDistribution): |
176 | 45 | """Multivariate Normal distribution. |
177 | 46 |
|
178 | 47 | Parameters |
@@ -238,7 +107,7 @@ def make_node(self, rng, size, sigma, support_shape): |
238 | 107 | return super().make_node(rng, size, sigma, support_shape) |
239 | 108 |
|
240 | 109 |
|
241 | | -class ZeroSumNormal(MultivariateDimDistribution): |
| 110 | +class ZeroSumNormal(VectorDimDistribution): |
242 | 111 | @classmethod |
243 | 112 | def __new__( |
244 | 113 | cls, *args, core_dims=None, dims=None, default_transform=UNSET, observed=None, **kwargs |
|
0 commit comments