Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 66 additions & 13 deletions pymc/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
# limitations under the License.
import warnings

from typing import TypeAlias

import numpy as np
import numpy.typing as npt
import pytensor.tensor as pt

from pytensor.tensor import TensorConstant
Expand All @@ -30,6 +33,7 @@
uniform,
)
from pytensor.tensor.random.utils import normalize_size_param
from pytensor.tensor.variable import TensorVariable
from scipy import stats

import pymc as pm
Expand All @@ -46,10 +50,11 @@
normal_lccdf,
normal_lcdf,
)
from pymc.distributions.distribution import Discrete, SymbolicRandomVariable
from pymc.distributions.distribution import DIST_PARAMETER_TYPES, Discrete, SymbolicRandomVariable
from pymc.distributions.shape_utils import implicit_size_from_params, rv_size_is_none
from pymc.logprob.basic import logcdf, logp
from pymc.math import sigmoid
from pymc.pytensorf import normalize_rng_param

__all__ = [
"Binomial",
Expand All @@ -66,7 +71,7 @@
"OrderedProbit",
]

from pymc.pytensorf import normalize_rng_param
# Accepted types for integer-valued distribution parameters (e.g. trial counts):
# plain Python ints, integer NumPy arrays, or symbolic PyTensor variables.
DISCRETE_DIST_PARAMETER_TYPES: TypeAlias = npt.NDArray[np.int_] | int | TensorVariable


class Binomial(Discrete):
Expand Down Expand Up @@ -118,7 +123,14 @@ class Binomial(Discrete):
rv_op = binomial

@classmethod
def dist(cls, n, p=None, logit_p=None, *args, **kwargs):
def dist(
cls,
n: DISCRETE_DIST_PARAMETER_TYPES,
p: DIST_PARAMETER_TYPES | None = None,
logit_p: DIST_PARAMETER_TYPES | None = None,
*args,
**kwargs,
):
if p is not None and logit_p is not None:
raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.")
elif p is None and logit_p is None:
Expand Down Expand Up @@ -234,7 +246,14 @@ def BetaBinom(a, b, n, x):
rv_op = betabinom

@classmethod
def dist(cls, alpha, beta, n, *args, **kwargs):
def dist(
cls,
alpha: DIST_PARAMETER_TYPES,
beta: DIST_PARAMETER_TYPES,
n: DISCRETE_DIST_PARAMETER_TYPES,
*args,
**kwargs,
):
alpha = pt.as_tensor_variable(alpha)
beta = pt.as_tensor_variable(beta)
n = pt.as_tensor_variable(n, dtype=int)
Expand Down Expand Up @@ -341,7 +360,13 @@ class Bernoulli(Discrete):
rv_op = bernoulli

@classmethod
def dist(cls, p=None, logit_p=None, *args, **kwargs):
def dist(
cls,
p: DIST_PARAMETER_TYPES | None = None,
logit_p: DIST_PARAMETER_TYPES | None = None,
*args,
**kwargs,
):
if p is not None and logit_p is not None:
raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.")
elif p is None and logit_p is None:
Expand Down Expand Up @@ -465,7 +490,9 @@ def DiscreteWeibull(q, b, x):
rv_op = DiscreteWeibullRV.rv_op

@classmethod
def dist(cls, q, beta, *args, **kwargs):
def dist(cls, q: DIST_PARAMETER_TYPES, beta: DIST_PARAMETER_TYPES, *args, **kwargs):
    """Create a DiscreteWeibull RV from the shape parameters ``q`` and ``beta``.

    Both parameters are promoted to symbolic tensors before being handed to
    the base class.
    """
    # Promote raw numeric inputs to PyTensor variables up front.
    q_tensor = pt.as_tensor_variable(q)
    beta_tensor = pt.as_tensor_variable(beta)
    # NOTE(review): positional *args are accepted but not forwarded here,
    # matching the original behavior.
    return super().dist([q_tensor, beta_tensor], **kwargs)

def support_point(rv, size, q, beta):
Expand Down Expand Up @@ -553,7 +580,7 @@ class Poisson(Discrete):
rv_op = poisson

@classmethod
def dist(cls, mu, *args, **kwargs):
def dist(cls, mu: DIST_PARAMETER_TYPES, *args, **kwargs):
    """Create a Poisson RV with expected-count parameter ``mu``.

    ``mu`` is promoted to a symbolic tensor; all other arguments are
    forwarded unchanged to the base class.
    """
    rate = pt.as_tensor_variable(mu)
    return super().dist([rate], *args, **kwargs)

Expand Down Expand Up @@ -677,7 +704,15 @@ def NegBinom(a, m, x):
rv_op = nbinom

@classmethod
def dist(cls, mu=None, alpha=None, p=None, n=None, *args, **kwargs):
def dist(
cls,
mu: DIST_PARAMETER_TYPES | None = None,
alpha: DIST_PARAMETER_TYPES | None = None,
p: DIST_PARAMETER_TYPES | None = None,
n: DIST_PARAMETER_TYPES | None = None,
*args,
**kwargs,
):
n, p = cls.get_n_p(mu=mu, alpha=alpha, p=p, n=n)
n = pt.as_tensor_variable(n)
p = pt.as_tensor_variable(p)
Expand Down Expand Up @@ -790,7 +825,7 @@ class Geometric(Discrete):
rv_op = geometric

@classmethod
def dist(cls, p, *args, **kwargs):
def dist(cls, p: DIST_PARAMETER_TYPES, *args, **kwargs):
    """Create a Geometric RV with success-probability parameter ``p``.

    ``p`` is promoted to a symbolic tensor; remaining arguments pass
    straight through to the base class.
    """
    prob = pt.as_tensor_variable(p)
    return super().dist([prob], *args, **kwargs)

Expand Down Expand Up @@ -891,7 +926,14 @@ class HyperGeometric(Discrete):
rv_op = hypergeometric

@classmethod
def dist(cls, N, k, n, *args, **kwargs):
def dist(
cls,
N: DISCRETE_DIST_PARAMETER_TYPES,
k: DISCRETE_DIST_PARAMETER_TYPES,
n: DISCRETE_DIST_PARAMETER_TYPES,
*args,
**kwargs,
):
good = pt.as_tensor_variable(k, dtype=int)
bad = pt.as_tensor_variable(N - k, dtype=int)
n = pt.as_tensor_variable(n, dtype=int)
Expand Down Expand Up @@ -1027,7 +1069,13 @@ class DiscreteUniform(Discrete):
rv_op = discrete_uniform

@classmethod
def dist(cls, lower, upper, *args, **kwargs):
def dist(
    cls,
    lower: DISCRETE_DIST_PARAMETER_TYPES,
    upper: DISCRETE_DIST_PARAMETER_TYPES,
    *args,
    **kwargs,
):
    """Create a DiscreteUniform RV on the integer interval [lower, upper].

    Non-integer endpoints are floored so the support stays on integers.
    """
    # Floor both bounds; pt.floor also promotes plain numbers to tensors.
    lo = pt.floor(lower)
    hi = pt.floor(upper)
    # NOTE(review): positional *args are accepted but not forwarded here,
    # matching the original behavior.
    return super().dist([lo, hi], **kwargs)
Expand Down Expand Up @@ -1123,7 +1171,12 @@ class Categorical(Discrete):
rv_op = categorical

@classmethod
def dist(cls, p=None, logit_p=None, **kwargs):
def dist(
cls,
p: np.ndarray | None = None,
logit_p: float | None = None,
**kwargs,
):
if p is not None and logit_p is not None:
raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.")
elif p is None and logit_p is None:
Expand Down Expand Up @@ -1261,7 +1314,7 @@ def __new__(cls, name, eta, cutpoints, compute_p=True, **kwargs):
return out_rv

@classmethod
def dist(cls, eta, cutpoints, **kwargs):
def dist(cls, eta: DIST_PARAMETER_TYPES, cutpoints: DIST_PARAMETER_TYPES, **kwargs):
    """Build the underlying Categorical RV from ``eta`` and ``cutpoints``.

    The class-specific ``compute_p`` turns the latent score and cutpoints
    into category probabilities, which are then delegated to
    ``Categorical.dist``.
    """
    probs = cls.compute_p(eta, cutpoints)
    return Categorical.dist(p=probs, **kwargs)

Expand Down
Loading