
Commit 7973967

Add Type Hints to distribution parameters (#6635)
1 parent 09dc9d0 commit 7973967

File tree

1 file changed: +27 -5 lines changed

pymc/distributions/continuous.py

Lines changed: 27 additions & 5 deletions
@@ -816,7 +816,13 @@ class HalfNormal(PositiveContinuous):
     rv_op = halfnormal

     @classmethod
-    def dist(cls, sigma=None, tau=None, *args, **kwargs):
+    def dist(
+        cls,
+        sigma: Optional[DIST_PARAMETER_TYPES] = None,
+        tau: Optional[DIST_PARAMETER_TYPES] = None,
+        *args,
+        **kwargs,
+    ):
         tau, sigma = get_tau_sigma(tau=tau, sigma=sigma)

         return super().dist([0.0, sigma], **kwargs)
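
Both parameters use the Optional[DIST_PARAMETER_TYPES] alias, whose definition lives elsewhere in the codebase and is not part of this diff. A minimal sketch of what such a union alias can look like, assuming it covers plain Python numbers, NumPy arrays, and PyTensor variables:

    # Hypothetical sketch only -- the real DIST_PARAMETER_TYPES in pymc
    # is not shown in this commit and may differ.
    from typing import Union

    import numpy as np
    from pytensor.tensor import TensorVariable

    DIST_PARAMETER_TYPES = Union[int, float, np.ndarray, TensorVariable]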
@@ -948,7 +954,14 @@ class Wald(PositiveContinuous):
     rv_op = wald

     @classmethod
-    def dist(cls, mu=None, lam=None, phi=None, alpha=0.0, **kwargs):
+    def dist(
+        cls,
+        mu: Optional[DIST_PARAMETER_TYPES] = None,
+        lam: Optional[DIST_PARAMETER_TYPES] = None,
+        phi: Optional[DIST_PARAMETER_TYPES] = None,
+        alpha: Optional[DIST_PARAMETER_TYPES] = 0.0,
+        **kwargs,
+    ):
         mu, lam, phi = cls.get_mu_lam_phi(mu, lam, phi)
         alpha = pt.as_tensor_variable(floatX(alpha))
         mu = pt.as_tensor_variable(floatX(mu))
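
The annotated signature makes Wald's alternative parametrizations, resolved by get_mu_lam_phi, visible to type checkers. A quick usage sketch, assuming a recent PyMC release:

    import pymc as pm

    # Two ways to build a standalone Wald variable; plain floats
    # satisfy the DIST_PARAMETER_TYPES hint.
    w1 = pm.Wald.dist(mu=1.0, lam=2.0)
    w2 = pm.Wald.dist(mu=1.0, phi=0.5)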
@@ -1115,7 +1128,16 @@ class Beta(UnitContinuous):
     rv_op = pytensor.tensor.random.beta

     @classmethod
-    def dist(cls, alpha=None, beta=None, mu=None, sigma=None, nu=None, *args, **kwargs):
+    def dist(
+        cls,
+        alpha: Optional[DIST_PARAMETER_TYPES] = None,
+        beta: Optional[DIST_PARAMETER_TYPES] = None,
+        mu: Optional[DIST_PARAMETER_TYPES] = None,
+        sigma: Optional[DIST_PARAMETER_TYPES] = None,
+        nu: Optional[DIST_PARAMETER_TYPES] = None,
+        *args,
+        **kwargs,
+    ):
         alpha, beta = cls.get_alpha_beta(alpha, beta, mu, sigma, nu)
         alpha = pt.as_tensor_variable(floatX(alpha))
         beta = pt.as_tensor_variable(floatX(beta))
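
Beta.dist likewise accepts several parametrizations (alpha/beta, mu/sigma, or mu/nu, as the signature suggests), resolved by get_alpha_beta. A small usage sketch, assuming a recent PyMC release:

    import pymc as pm

    # Shape parametrization and mean/standard-deviation parametrization;
    # both satisfy the Optional[DIST_PARAMETER_TYPES] hints.
    b1 = pm.Beta.dist(alpha=2.0, beta=2.0)
    b2 = pm.Beta.dist(mu=0.5, sigma=0.1)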
@@ -1243,7 +1265,7 @@ class Kumaraswamy(UnitContinuous):
     rv_op = kumaraswamy

     @classmethod
-    def dist(cls, a, b, *args, **kwargs):
+    def dist(cls, a: DIST_PARAMETER_TYPES, b: DIST_PARAMETER_TYPES, *args, **kwargs):
         a = pt.as_tensor_variable(floatX(a))
         b = pt.as_tensor_variable(floatX(b))

@@ -1329,7 +1351,7 @@ class Exponential(PositiveContinuous):
     rv_op = exponential

     @classmethod
-    def dist(cls, lam, *args, **kwargs):
+    def dist(cls, lam: DIST_PARAMETER_TYPES, *args, **kwargs):
         lam = pt.as_tensor_variable(floatX(lam))

         # PyTensor exponential op is parametrized in terms of mu (1/lam)
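
For Kumaraswamy and Exponential the parameters are required, so the hints are plain DIST_PARAMETER_TYPES rather than Optional. A sketch of the kinds of values the hints advertise, assuming the union spans numbers, arrays, and symbolic tensors:

    import numpy as np
    import pymc as pm
    import pytensor.tensor as pt

    e1 = pm.Exponential.dist(lam=1.5)           # plain float
    e2 = pm.Exponential.dist(lam=np.ones(3))    # NumPy array (batched rate)
    e3 = pm.Exponential.dist(lam=pt.scalar())   # symbolic TensorVariable
    k = pm.Kumaraswamy.dist(a=2.0, b=3.0)       # required scalar parameters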
