Skip to content

Commit a11d903

Browse files
committed
Add type hints to dist() in discrete distributions
1 parent 46f1675 commit a11d903

File tree

1 file changed

+72
-13
lines changed

1 file changed

+72
-13
lines changed

pymc/distributions/discrete.py

Lines changed: 72 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@
1313
# limitations under the License.
1414
import warnings
1515

16+
from typing import Optional, TypeAlias, Union
17+
1618
import numpy as np
19+
import numpy.typing as npt
1720
import pytensor.tensor as pt
1821

1922
from pytensor.tensor import TensorConstant
@@ -29,6 +32,7 @@
2932
nbinom,
3033
poisson,
3134
)
35+
from pytensor.tensor.variable import TensorVariable
3236
from scipy import stats
3337

3438
import pymc as pm
@@ -45,7 +49,7 @@
4549
normal_lccdf,
4650
normal_lcdf,
4751
)
48-
from pymc.distributions.distribution import Discrete
52+
from pymc.distributions.distribution import DIST_PARAMETER_TYPES, Discrete
4953
from pymc.distributions.shape_utils import rv_size_is_none
5054
from pymc.logprob.basic import logcdf, logp
5155
from pymc.math import sigmoid
@@ -65,6 +69,8 @@
6569
"OrderedProbit",
6670
]
6771

72+
DISCRETE_DIST_PARAMETER_TYPES: TypeAlias = Union[npt.NDArray[np.int_], int, TensorVariable]
73+
6874

6975
class Binomial(Discrete):
7076
R"""
@@ -115,7 +121,14 @@ class Binomial(Discrete):
115121
rv_op = binomial
116122

117123
@classmethod
118-
def dist(cls, n, p=None, logit_p=None, *args, **kwargs):
124+
def dist(
125+
cls,
126+
n: DISCRETE_DIST_PARAMETER_TYPES,
127+
p: Optional[DIST_PARAMETER_TYPES] = None,
128+
logit_p: Optional[DIST_PARAMETER_TYPES] = None,
129+
*args,
130+
**kwargs,
131+
):
119132
if p is not None and logit_p is not None:
120133
raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.")
121134
elif p is None and logit_p is None:
@@ -231,7 +244,14 @@ def BetaBinom(a, b, n, x):
231244
rv_op = betabinom
232245

233246
@classmethod
234-
def dist(cls, alpha, beta, n, *args, **kwargs):
247+
def dist(
248+
cls,
249+
alpha: DIST_PARAMETER_TYPES,
250+
beta: DIST_PARAMETER_TYPES,
251+
n: DISCRETE_DIST_PARAMETER_TYPES,
252+
*args,
253+
**kwargs,
254+
):
235255
alpha = pt.as_tensor_variable(alpha)
236256
beta = pt.as_tensor_variable(beta)
237257
n = pt.as_tensor_variable(n, dtype=int)
@@ -338,7 +358,13 @@ class Bernoulli(Discrete):
338358
rv_op = bernoulli
339359

340360
@classmethod
341-
def dist(cls, p=None, logit_p=None, *args, **kwargs):
361+
def dist(
362+
cls,
363+
p: Optional[DIST_PARAMETER_TYPES] = None,
364+
logit_p: Optional[DIST_PARAMETER_TYPES] = None,
365+
*args,
366+
**kwargs,
367+
):
342368
if p is not None and logit_p is not None:
343369
raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.")
344370
elif p is None and logit_p is None:
@@ -455,7 +481,7 @@ def DiscreteWeibull(q, b, x):
455481
rv_op = discrete_weibull
456482

457483
@classmethod
458-
def dist(cls, q, beta, *args, **kwargs):
484+
def dist(cls, q: DIST_PARAMETER_TYPES, beta: DIST_PARAMETER_TYPES, *args, **kwargs):
459485
q = pt.as_tensor_variable(q)
460486
beta = pt.as_tensor_variable(beta)
461487
return super().dist([q, beta], **kwargs)
@@ -545,7 +571,7 @@ class Poisson(Discrete):
545571
rv_op = poisson
546572

547573
@classmethod
548-
def dist(cls, mu, *args, **kwargs):
574+
def dist(cls, mu: DIST_PARAMETER_TYPES, *args, **kwargs):
549575
mu = pt.as_tensor_variable(mu)
550576
return super().dist([mu], *args, **kwargs)
551577

@@ -669,7 +695,15 @@ def NegBinom(a, m, x):
669695
rv_op = nbinom
670696

671697
@classmethod
672-
def dist(cls, mu=None, alpha=None, p=None, n=None, *args, **kwargs):
698+
def dist(
699+
cls,
700+
mu: Optional[DIST_PARAMETER_TYPES] = None,
701+
alpha: Optional[DIST_PARAMETER_TYPES] = None,
702+
p: Optional[DIST_PARAMETER_TYPES] = None,
703+
n: Optional[DIST_PARAMETER_TYPES] = None,
704+
*args,
705+
**kwargs,
706+
):
673707
n, p = cls.get_n_p(mu=mu, alpha=alpha, p=p, n=n)
674708
n = pt.as_tensor_variable(n)
675709
p = pt.as_tensor_variable(p)
@@ -782,7 +816,7 @@ class Geometric(Discrete):
782816
rv_op = geometric
783817

784818
@classmethod
785-
def dist(cls, p, *args, **kwargs):
819+
def dist(cls, p: DIST_PARAMETER_TYPES, *args, **kwargs):
786820
p = pt.as_tensor_variable(p)
787821
return super().dist([p], *args, **kwargs)
788822

@@ -883,7 +917,14 @@ class HyperGeometric(Discrete):
883917
rv_op = hypergeometric
884918

885919
@classmethod
886-
def dist(cls, N, k, n, *args, **kwargs):
920+
def dist(
921+
cls,
922+
N: Optional[DISCRETE_DIST_PARAMETER_TYPES],
923+
k: Optional[DISCRETE_DIST_PARAMETER_TYPES],
924+
n: Optional[DISCRETE_DIST_PARAMETER_TYPES],
925+
*args,
926+
**kwargs,
927+
):
887928
good = pt.as_tensor_variable(k, dtype=int)
888929
bad = pt.as_tensor_variable(N - k, dtype=int)
889930
n = pt.as_tensor_variable(n, dtype=int)
@@ -1020,7 +1061,13 @@ class DiscreteUniform(Discrete):
10201061
rv_op = discrete_uniform
10211062

10221063
@classmethod
1023-
def dist(cls, lower, upper, *args, **kwargs):
1064+
def dist(
1065+
cls,
1066+
lower: DISCRETE_DIST_PARAMETER_TYPES,
1067+
upper: DISCRETE_DIST_PARAMETER_TYPES,
1068+
*args,
1069+
**kwargs,
1070+
):
10241071
lower = pt.floor(lower)
10251072
upper = pt.floor(upper)
10261073
return super().dist([lower, upper], **kwargs)
@@ -1116,7 +1163,12 @@ class Categorical(Discrete):
11161163
rv_op = categorical
11171164

11181165
@classmethod
1119-
def dist(cls, p=None, logit_p=None, **kwargs):
1166+
def dist(
1167+
cls,
1168+
p: Optional[np.ndarray] = None,
1169+
logit_p: Optional[float] = None,
1170+
**kwargs,
1171+
):
11201172
if p is not None and logit_p is not None:
11211173
raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.")
11221174
elif p is None and logit_p is None:
@@ -1192,7 +1244,7 @@ class _OrderedLogistic(Categorical):
11921244
rv_op = categorical
11931245

11941246
@classmethod
1195-
def dist(cls, eta, cutpoints, *args, **kwargs):
1247+
def dist(cls, eta: DIST_PARAMETER_TYPES, cutpoints: DIST_PARAMETER_TYPES, *args, **kwargs):
11961248
eta = pt.as_tensor_variable(eta)
11971249
cutpoints = pt.as_tensor_variable(cutpoints)
11981250

@@ -1299,7 +1351,14 @@ class _OrderedProbit(Categorical):
12991351
rv_op = categorical
13001352

13011353
@classmethod
1302-
def dist(cls, eta, cutpoints, sigma=1, *args, **kwargs):
1354+
def dist(
1355+
cls,
1356+
eta: DIST_PARAMETER_TYPES,
1357+
cutpoints: DIST_PARAMETER_TYPES,
1358+
sigma: DIST_PARAMETER_TYPES = 1.0,
1359+
*args,
1360+
**kwargs,
1361+
):
13031362
eta = pt.as_tensor_variable(eta)
13041363
cutpoints = pt.as_tensor_variable(cutpoints)
13051364

0 commit comments

Comments
 (0)