|
13 | 13 | # limitations under the License. |
14 | 14 | import warnings |
15 | 15 |
|
| 16 | +from typing import Optional |
| 17 | + |
16 | 18 | import numpy as np |
17 | 19 | import pytensor.tensor as pt |
18 | 20 |
|
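The signatures added below annotate every distribution parameter with a `DIST_PARAMETER_TYPES` alias. Its import and definition are not part of the hunks shown here, so the sketch below is only an assumption about what such an alias would need to cover for the annotations to make sense (scalars, NumPy arrays, and PyTensor variables, since every value is later passed through `pt.as_tensor_variable`):

```python
from typing import Union

import numpy as np
from pytensor.tensor.variable import TensorVariable

# Hypothetical sketch of a compatible alias -- the project's actual
# definition lives elsewhere and may differ in its exact members.
DIST_PARAMETER_TYPES = Union[int, float, np.ndarray, TensorVariable]
```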
@@ -118,7 +120,14 @@ class Binomial(Discrete): |
118 | 120 | rv_op = binomial |
119 | 121 |
|
120 | 122 | @classmethod |
121 | | - def dist(cls, n, p=None, logit_p=None, *args, **kwargs): |
| 123 | + def dist( |
| 124 | + cls, |
| 125 | + n: DIST_PARAMETER_TYPES, |
| 126 | + p: Optional[DIST_PARAMETER_TYPES] = None, |
| 127 | + logit_p: Optional[DIST_PARAMETER_TYPES] = None, |
| 128 | + *args, |
| 129 | + **kwargs, |
| 130 | + ): |
122 | 131 | if p is not None and logit_p is not None: |
123 | 132 | raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.") |
124 | 133 | elif p is None and logit_p is None: |
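The guard clauses retained above make `p` and `logit_p` mutually exclusive parametrizations; the new annotations only document what each accepts. A usage sketch (assuming the usual `import pymc as pm`):

```python
import pymc as pm

# Two ways to parametrize the same Binomial: p = 0.5 corresponds to logit_p = 0.0.
x = pm.Binomial.dist(n=10, p=0.5)
y = pm.Binomial.dist(n=10, logit_p=0.0)

# Passing both triggers the ValueError raised in the code above.
# pm.Binomial.dist(n=10, p=0.5, logit_p=0.0)
```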
@@ -234,7 +243,14 @@ def BetaBinom(a, b, n, x): |
234 | 243 | rv_op = betabinom |
235 | 244 |
|
236 | 245 | @classmethod |
237 | | - def dist(cls, alpha, beta, n, *args, **kwargs): |
| 246 | + def dist( |
| 247 | + cls, |
| 248 | + alpha: DIST_PARAMETER_TYPES, |
| 249 | + beta: DIST_PARAMETER_TYPES, |
| 250 | + n: DIST_PARAMETER_TYPES, |
| 251 | + *args, |
| 252 | + **kwargs, |
| 253 | + ): |
238 | 254 | alpha = pt.as_tensor_variable(alpha) |
239 | 255 | beta = pt.as_tensor_variable(beta) |
240 | 256 | n = pt.as_tensor_variable(n, dtype=int) |
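For `BetaBinomial` all three parameters remain required; the annotations simply make explicit that each accepts anything convertible by `pt.as_tensor_variable`. A brief sketch:

```python
import pymc as pm

# alpha/beta shape the Beta prior on the success probability; n is the
# number of trials and is coerced to an integer tensor above.
bb = pm.BetaBinomial.dist(alpha=2.0, beta=2.0, n=10)
```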
@@ -341,7 +357,13 @@ class Bernoulli(Discrete): |
341 | 357 | rv_op = bernoulli |
342 | 358 |
|
343 | 359 | @classmethod |
344 | | - def dist(cls, p=None, logit_p=None, *args, **kwargs): |
| 360 | + def dist( |
| 361 | + cls, |
| 362 | + p: Optional[DIST_PARAMETER_TYPES] = None, |
| 363 | + logit_p: Optional[DIST_PARAMETER_TYPES] = None, |
| 364 | + *args, |
| 365 | + **kwargs, |
| 366 | + ): |
345 | 367 | if p is not None and logit_p is not None: |
346 | 368 | raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.") |
347 | 369 | elif p is None and logit_p is None: |
@@ -465,7 +487,8 @@ def DiscreteWeibull(q, b, x): |
465 | 487 | rv_op = DiscreteWeibullRV.rv_op |
466 | 488 |
|
467 | 489 | @classmethod |
468 | | - def dist(cls, q, beta, *args, **kwargs): |
| 490 | + def dist(cls, q: DIST_PARAMETER_TYPES, beta: DIST_PARAMETER_TYPES, *args, **kwargs): |
| 491 | + |
469 | 492 | return super().dist([q, beta], **kwargs) |
470 | 493 |
|
471 | 494 | def support_point(rv, size, q, beta): |
@@ -553,7 +576,8 @@ class Poisson(Discrete): |
553 | 576 | rv_op = poisson |
554 | 577 |
|
555 | 578 | @classmethod |
556 | | - def dist(cls, mu, *args, **kwargs): |
| 579 | + def dist(cls, mu: DIST_PARAMETER_TYPES, *args, **kwargs): |
| 580 | + |
557 | 581 | mu = pt.as_tensor_variable(mu) |
558 | 582 | return super().dist([mu], *args, **kwargs) |
559 | 583 |
|
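The same pattern applies to the simpler signatures (`DiscreteWeibull` above, `Geometric` and `DiscreteUniform` below): the annotation is informational, and the value is still converted with `pt.as_tensor_variable`, so scalars, arrays, and tensors all remain valid at runtime. For example:

```python
import numpy as np
import pymc as pm

# Scalar and array rates are both fine; conversion to a tensor happens
# inside dist() via pt.as_tensor_variable.
pm.Poisson.dist(mu=3.0)
pm.Poisson.dist(mu=np.array([1.0, 2.0, 5.0]))
```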
@@ -677,7 +701,16 @@ def NegBinom(a, m, x): |
677 | 701 | rv_op = nbinom |
678 | 702 |
|
679 | 703 | @classmethod |
680 | | - def dist(cls, mu=None, alpha=None, p=None, n=None, *args, **kwargs): |
| 704 | + def dist( |
| 705 | + cls, |
| 706 | + mu: Optional[DIST_PARAMETER_TYPES] = None, |
| 707 | + alpha: Optional[DIST_PARAMETER_TYPES] = None, |
| 708 | + p: Optional[DIST_PARAMETER_TYPES] = None, |
| 709 | + n: Optional[DIST_PARAMETER_TYPES] = None, |
| 710 | + *args, |
| 711 | + **kwargs, |
| 712 | + ): |
| 713 | + |
681 | 714 | n, p = cls.get_n_p(mu=mu, alpha=alpha, p=p, n=n) |
682 | 715 | n = pt.as_tensor_variable(n) |
683 | 716 | p = pt.as_tensor_variable(p) |
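`NegativeBinomial` keeps its two alternative parametrizations, with `get_n_p` resolving whichever pair was supplied into `n` and `p`. A hedged sketch of the two call styles:

```python
import pymc as pm

# Mean / over-dispersion parametrization ...
a = pm.NegativeBinomial.dist(mu=4.0, alpha=2.0)

# ... or the classical number-of-successes / success-probability form.
b = pm.NegativeBinomial.dist(n=2.0, p=0.5)
```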
@@ -790,7 +823,8 @@ class Geometric(Discrete): |
790 | 823 | rv_op = geometric |
791 | 824 |
|
792 | 825 | @classmethod |
793 | | - def dist(cls, p, *args, **kwargs): |
| 826 | + def dist(cls, p: DIST_PARAMETER_TYPES, *args, **kwargs): |
| 827 | + |
794 | 828 | p = pt.as_tensor_variable(p) |
795 | 829 | return super().dist([p], *args, **kwargs) |
796 | 830 |
|
@@ -1027,7 +1061,8 @@ class DiscreteUniform(Discrete): |
1027 | 1061 | rv_op = discrete_uniform |
1028 | 1062 |
|
1029 | 1063 | @classmethod |
1030 | | - def dist(cls, lower, upper, *args, **kwargs): |
| 1064 | + def dist(cls, lower: DIST_PARAMETER_TYPES, upper: DIST_PARAMETER_TYPES, *args, **kwargs): |
| 1065 | + |
1031 | 1066 | lower = pt.floor(lower) |
1032 | 1067 | upper = pt.floor(upper) |
1033 | 1068 | return super().dist([lower, upper], **kwargs) |
@@ -1123,7 +1158,12 @@ class Categorical(Discrete): |
1123 | 1158 | rv_op = categorical |
1124 | 1159 |
|
1125 | 1160 | @classmethod |
1126 | | - def dist(cls, p=None, logit_p=None, **kwargs): |
| 1161 | + def dist( |
| 1162 | + cls, |
| 1163 | + p: Optional[DIST_PARAMETER_TYPES] = None, |
| 1164 | + logit_p: Optional[DIST_PARAMETER_TYPES] = None, |
| 1165 | + **kwargs, |
| 1166 | + ): |
1127 | 1167 | if p is not None and logit_p is not None: |
1128 | 1168 | raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.") |
1129 | 1169 | elif p is None and logit_p is None: |
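`Categorical` mirrors the `Bernoulli`/`Binomial` pattern, except that `p` (or `logit_p`) is a vector over categories rather than a scalar, with the last axis indexing the categories. A usage sketch:

```python
import pymc as pm

# Mutually exclusive, as the guard above enforces: either category
# probabilities or unnormalized logits, never both.
c1 = pm.Categorical.dist(p=[0.1, 0.3, 0.6])
c2 = pm.Categorical.dist(logit_p=[0.0, 1.0, 2.0])
```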