|
13 | 13 | # limitations under the License. |
14 | 14 | import warnings |
15 | 15 |
|
| 16 | +from typing import Optional, TypeAlias, Union |
| 17 | + |
16 | 18 | import numpy as np |
| 19 | +import numpy.typing as npt |
17 | 20 | import pytensor.tensor as pt |
18 | 21 |
|
19 | 22 | from pytensor.tensor import TensorConstant |
|
29 | 32 | nbinom, |
30 | 33 | poisson, |
31 | 34 | ) |
| 35 | +from pytensor.tensor.variable import TensorVariable |
32 | 36 | from scipy import stats |
33 | 37 |
|
34 | 38 | import pymc as pm |
|
45 | 49 | normal_lccdf, |
46 | 50 | normal_lcdf, |
47 | 51 | ) |
48 | | -from pymc.distributions.distribution import Discrete |
| 52 | +from pymc.distributions.distribution import DIST_PARAMETER_TYPES, Discrete |
49 | 53 | from pymc.distributions.shape_utils import rv_size_is_none |
50 | 54 | from pymc.logprob.basic import logcdf, logp |
51 | 55 | from pymc.math import sigmoid |
|
65 | 69 | "OrderedProbit", |
66 | 70 | ] |
67 | 71 |
|
| 72 | +# Accepted argument types for integer-valued distribution parameters (trial
| 72 | +# counts, category counts, bounds); real-valued parameters use the broader
| 72 | +# DIST_PARAMETER_TYPES alias imported from pymc.distributions.distribution.
| 72 | +DISCRETE_DIST_PARAMETER_TYPES: TypeAlias = Union[npt.NDArray[np.int_], int, TensorVariable]
| 73 | +
68 | 74 |
|
69 | 75 | class Binomial(Discrete): |
70 | 76 | R""" |
@@ -115,7 +121,14 @@ class Binomial(Discrete): |
115 | 121 | rv_op = binomial |
116 | 122 |
|
117 | 123 | @classmethod |
118 | | - def dist(cls, n, p=None, logit_p=None, *args, **kwargs): |
| 124 | + def dist( |
| 125 | + cls, |
| 126 | + n: DISCRETE_DIST_PARAMETER_TYPES, |
| 127 | + p: Optional[DIST_PARAMETER_TYPES] = None, |
| 128 | + logit_p: Optional[DIST_PARAMETER_TYPES] = None, |
| 129 | + *args, |
| 130 | + **kwargs, |
| 131 | + ): |
119 | 132 | if p is not None and logit_p is not None: |
120 | 133 | raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.") |
121 | 134 | elif p is None and logit_p is None: |
@@ -231,7 +244,14 @@ def BetaBinom(a, b, n, x): |
231 | 244 | rv_op = betabinom |
232 | 245 |
|
233 | 246 | @classmethod |
234 | | - def dist(cls, alpha, beta, n, *args, **kwargs): |
| 247 | + def dist( |
| 248 | + cls, |
| 249 | + alpha: DIST_PARAMETER_TYPES, |
| 250 | + beta: DIST_PARAMETER_TYPES, |
| 251 | + n: DISCRETE_DIST_PARAMETER_TYPES, |
| 252 | + *args, |
| 253 | + **kwargs, |
| 254 | + ): |
235 | 255 | alpha = pt.as_tensor_variable(alpha) |
236 | 256 | beta = pt.as_tensor_variable(beta) |
237 | 257 | n = pt.as_tensor_variable(n, dtype=int) |
@@ -338,7 +358,13 @@ class Bernoulli(Discrete): |
338 | 358 | rv_op = bernoulli |
339 | 359 |
|
340 | 360 | @classmethod |
341 | | - def dist(cls, p=None, logit_p=None, *args, **kwargs): |
| 361 | + def dist( |
| 362 | + cls, |
| 363 | + p: Optional[DIST_PARAMETER_TYPES] = None, |
| 364 | + logit_p: Optional[DIST_PARAMETER_TYPES] = None, |
| 365 | + *args, |
| 366 | + **kwargs, |
| 367 | + ): |
342 | 368 | if p is not None and logit_p is not None: |
343 | 369 | raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.") |
344 | 370 | elif p is None and logit_p is None: |
@@ -455,7 +481,7 @@ def DiscreteWeibull(q, b, x): |
455 | 481 | rv_op = discrete_weibull |
456 | 482 |
|
457 | 483 | @classmethod |
458 | | - def dist(cls, q, beta, *args, **kwargs): |
| 484 | + def dist(cls, q: DIST_PARAMETER_TYPES, beta: DIST_PARAMETER_TYPES, *args, **kwargs): |
459 | 485 | q = pt.as_tensor_variable(q) |
460 | 486 | beta = pt.as_tensor_variable(beta) |
461 | 487 | return super().dist([q, beta], **kwargs) |
@@ -545,7 +571,7 @@ class Poisson(Discrete): |
545 | 571 | rv_op = poisson |
546 | 572 |
|
547 | 573 | @classmethod |
548 | | - def dist(cls, mu, *args, **kwargs): |
| 574 | + def dist(cls, mu: DIST_PARAMETER_TYPES, *args, **kwargs): |
549 | 575 | mu = pt.as_tensor_variable(mu) |
550 | 576 | return super().dist([mu], *args, **kwargs) |
551 | 577 |
|
@@ -669,7 +695,15 @@ def NegBinom(a, m, x): |
669 | 695 | rv_op = nbinom |
670 | 696 |
|
671 | 697 | @classmethod |
672 | | - def dist(cls, mu=None, alpha=None, p=None, n=None, *args, **kwargs): |
| 698 | + def dist( |
| 699 | + cls, |
| 700 | + mu: Optional[DIST_PARAMETER_TYPES] = None, |
| 701 | + alpha: Optional[DIST_PARAMETER_TYPES] = None, |
| 702 | + p: Optional[DIST_PARAMETER_TYPES] = None, |
| 703 | + n: Optional[DIST_PARAMETER_TYPES] = None, |
| 704 | + *args, |
| 705 | + **kwargs, |
| 706 | + ): |
673 | 707 | n, p = cls.get_n_p(mu=mu, alpha=alpha, p=p, n=n) |
674 | 708 | n = pt.as_tensor_variable(n) |
675 | 709 | p = pt.as_tensor_variable(p) |
@@ -782,7 +816,7 @@ class Geometric(Discrete): |
782 | 816 | rv_op = geometric |
783 | 817 |
|
784 | 818 | @classmethod |
785 | | - def dist(cls, p, *args, **kwargs): |
| 819 | + def dist(cls, p: DIST_PARAMETER_TYPES, *args, **kwargs): |
786 | 820 | p = pt.as_tensor_variable(p) |
787 | 821 | return super().dist([p], *args, **kwargs) |
788 | 822 |
|
@@ -883,7 +917,14 @@ class HyperGeometric(Discrete): |
883 | 917 | rv_op = hypergeometric |
884 | 918 |
|
885 | 919 | @classmethod |
886 | | - def dist(cls, N, k, n, *args, **kwargs): |
| 920 | + def dist( |
| 921 | + cls, |
| 922 | +        N: DISCRETE_DIST_PARAMETER_TYPES,
| 923 | +        k: DISCRETE_DIST_PARAMETER_TYPES,
| 924 | +        n: DISCRETE_DIST_PARAMETER_TYPES,
| 925 | + *args, |
| 926 | + **kwargs, |
| 927 | + ): |
887 | 928 | good = pt.as_tensor_variable(k, dtype=int) |
888 | 929 | bad = pt.as_tensor_variable(N - k, dtype=int) |
889 | 930 | n = pt.as_tensor_variable(n, dtype=int) |
@@ -1020,7 +1061,13 @@ class DiscreteUniform(Discrete): |
1020 | 1061 | rv_op = discrete_uniform |
1021 | 1062 |
|
1022 | 1063 | @classmethod |
1023 | | - def dist(cls, lower, upper, *args, **kwargs): |
| 1064 | + def dist( |
| 1065 | + cls, |
| 1066 | + lower: DISCRETE_DIST_PARAMETER_TYPES, |
| 1067 | + upper: DISCRETE_DIST_PARAMETER_TYPES, |
| 1068 | + *args, |
| 1069 | + **kwargs, |
| 1070 | + ): |
1024 | 1071 | lower = pt.floor(lower) |
1025 | 1072 | upper = pt.floor(upper) |
1026 | 1073 | return super().dist([lower, upper], **kwargs) |
@@ -1116,7 +1163,12 @@ class Categorical(Discrete): |
1116 | 1163 | rv_op = categorical |
1117 | 1164 |
|
1118 | 1165 | @classmethod |
1119 | | - def dist(cls, p=None, logit_p=None, **kwargs): |
| 1166 | + def dist( |
| 1167 | + cls, |
| 1168 | +        p: Optional[DIST_PARAMETER_TYPES] = None,
| 1169 | +        logit_p: Optional[DIST_PARAMETER_TYPES] = None,
| 1170 | + **kwargs, |
| 1171 | + ): |
1120 | 1172 | if p is not None and logit_p is not None: |
1121 | 1173 | raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.") |
1122 | 1174 | elif p is None and logit_p is None: |
@@ -1192,7 +1244,7 @@ class _OrderedLogistic(Categorical): |
1192 | 1244 | rv_op = categorical |
1193 | 1245 |
|
1194 | 1246 | @classmethod |
1195 | | - def dist(cls, eta, cutpoints, *args, **kwargs): |
| 1247 | + def dist(cls, eta: DIST_PARAMETER_TYPES, cutpoints: DIST_PARAMETER_TYPES, *args, **kwargs): |
1196 | 1248 | eta = pt.as_tensor_variable(eta) |
1197 | 1249 | cutpoints = pt.as_tensor_variable(cutpoints) |
1198 | 1250 |
|
@@ -1299,7 +1351,14 @@ class _OrderedProbit(Categorical): |
1299 | 1351 | rv_op = categorical |
1300 | 1352 |
|
1301 | 1353 | @classmethod |
1302 | | - def dist(cls, eta, cutpoints, sigma=1, *args, **kwargs): |
| 1354 | + def dist( |
| 1355 | + cls, |
| 1356 | + eta: DIST_PARAMETER_TYPES, |
| 1357 | + cutpoints: DIST_PARAMETER_TYPES, |
| 1358 | + sigma: DIST_PARAMETER_TYPES = 1.0, |
| 1359 | + *args, |
| 1360 | + **kwargs, |
| 1361 | + ): |
1303 | 1362 | eta = pt.as_tensor_variable(eta) |
1304 | 1363 | cutpoints = pt.as_tensor_variable(cutpoints) |
1305 | 1364 |
|
|
0 commit comments