Skip to content

Commit 10cc350

Browse files
committed
Add type hints to discrete distributions
1 parent d1aff0b commit 10cc350

File tree

2 files changed

+50
-10
lines changed

2 files changed

+50
-10
lines changed

pymc/distributions/discrete.py

Lines changed: 49 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# limitations under the License.
1414
import warnings
1515

16+
from typing import Optional
17+
1618
import numpy as np
1719
import pytensor.tensor as pt
1820

@@ -118,7 +120,14 @@ class Binomial(Discrete):
118120
rv_op = binomial
119121

120122
@classmethod
121-
def dist(cls, n, p=None, logit_p=None, *args, **kwargs):
123+
def dist(
124+
cls,
125+
n: DIST_PARAMETER_TYPES,
126+
p: Optional[DIST_PARAMETER_TYPES] = None,
127+
logit_p: Optional[DIST_PARAMETER_TYPES] = None,
128+
*args,
129+
**kwargs,
130+
):
122131
if p is not None and logit_p is not None:
123132
raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.")
124133
elif p is None and logit_p is None:
@@ -234,7 +243,14 @@ def BetaBinom(a, b, n, x):
234243
rv_op = betabinom
235244

236245
@classmethod
237-
def dist(cls, alpha, beta, n, *args, **kwargs):
246+
def dist(
247+
cls,
248+
alpha: DIST_PARAMETER_TYPES,
249+
beta: DIST_PARAMETER_TYPES,
250+
n: DIST_PARAMETER_TYPES,
251+
*args,
252+
**kwargs,
253+
):
238254
alpha = pt.as_tensor_variable(alpha)
239255
beta = pt.as_tensor_variable(beta)
240256
n = pt.as_tensor_variable(n, dtype=int)
@@ -341,7 +357,13 @@ class Bernoulli(Discrete):
341357
rv_op = bernoulli
342358

343359
@classmethod
344-
def dist(cls, p=None, logit_p=None, *args, **kwargs):
360+
def dist(
361+
cls,
362+
p: Optional[DIST_PARAMETER_TYPES] = None,
363+
logit_p: Optional[DIST_PARAMETER_TYPES] = None,
364+
*args,
365+
**kwargs,
366+
):
345367
if p is not None and logit_p is not None:
346368
raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.")
347369
elif p is None and logit_p is None:
@@ -465,7 +487,8 @@ def DiscreteWeibull(q, b, x):
465487
rv_op = DiscreteWeibullRV.rv_op
466488

467489
@classmethod
468-
def dist(cls, q, beta, *args, **kwargs):
490+
def dist(cls, q: DIST_PARAMETER_TYPES, beta: DIST_PARAMETER_TYPES, *args, **kwargs):
491+
469492
return super().dist([q, beta], **kwargs)
470493

471494
def support_point(rv, size, q, beta):
@@ -553,7 +576,8 @@ class Poisson(Discrete):
553576
rv_op = poisson
554577

555578
@classmethod
556-
def dist(cls, mu, *args, **kwargs):
579+
def dist(cls, mu: DIST_PARAMETER_TYPES, *args, **kwargs):
580+
557581
mu = pt.as_tensor_variable(mu)
558582
return super().dist([mu], *args, **kwargs)
559583

@@ -677,7 +701,16 @@ def NegBinom(a, m, x):
677701
rv_op = nbinom
678702

679703
@classmethod
680-
def dist(cls, mu=None, alpha=None, p=None, n=None, *args, **kwargs):
704+
def dist(
705+
cls,
706+
mu: Optional[DIST_PARAMETER_TYPES] = None,
707+
alpha: Optional[DIST_PARAMETER_TYPES] = None,
708+
p: Optional[DIST_PARAMETER_TYPES] = None,
709+
n: Optional[DIST_PARAMETER_TYPES] = None,
710+
*args,
711+
**kwargs,
712+
):
713+
681714
n, p = cls.get_n_p(mu=mu, alpha=alpha, p=p, n=n)
682715
n = pt.as_tensor_variable(n)
683716
p = pt.as_tensor_variable(p)
@@ -790,7 +823,8 @@ class Geometric(Discrete):
790823
rv_op = geometric
791824

792825
@classmethod
793-
def dist(cls, p, *args, **kwargs):
826+
def dist(cls, p: DIST_PARAMETER_TYPES, *args, **kwargs):
827+
794828
p = pt.as_tensor_variable(p)
795829
return super().dist([p], *args, **kwargs)
796830

@@ -1027,7 +1061,8 @@ class DiscreteUniform(Discrete):
10271061
rv_op = discrete_uniform
10281062

10291063
@classmethod
1030-
def dist(cls, lower, upper, *args, **kwargs):
1064+
def dist(cls, lower: DIST_PARAMETER_TYPES, upper: DIST_PARAMETER_TYPES, *args, **kwargs):
1065+
10311066
lower = pt.floor(lower)
10321067
upper = pt.floor(upper)
10331068
return super().dist([lower, upper], **kwargs)
@@ -1123,7 +1158,12 @@ class Categorical(Discrete):
11231158
rv_op = categorical
11241159

11251160
@classmethod
1126-
def dist(cls, p=None, logit_p=None, **kwargs):
1161+
def dist(
1162+
cls,
1163+
p: Optional[DIST_PARAMETER_TYPES] = None,
1164+
logit_p: Optional[DIST_PARAMETER_TYPES] = None,
1165+
**kwargs,
1166+
):
11271167
if p is not None and logit_p is not None:
11281168
raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.")
11291169
elif p is None and logit_p is None:

pymc/distributions/distribution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from functools import singledispatch
2424
from typing import Any, TypeAlias
2525

26-
import numpy as np
26+
import numpy as np # type: ignore
2727

2828
from pytensor import tensor as pt
2929
from pytensor.compile.builders import OpFromGraph

0 commit comments

Comments
 (0)