Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 0 additions & 53 deletions numpyro/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,59 +58,6 @@ def check(self, value: NumLike) -> ArrayLike: ...
def feasible_like(self, prototype: NumLike) -> NumLike: ...


@runtime_checkable
class DistributionT(Protocol):
"""A protocol for typing distributions.

Used to type object of type numpyro.distributions.Distribution, funsor.Funsor
or tensorflow_probability.distributions.Distribution.
"""

arg_constraints: dict[str, ConstraintT] = ...
support: ConstraintT = ...
has_enumerate_support: bool = ...
reparametrized_params: list[str] = ...
_validate_args: bool = ...
pytree_data_fields: tuple = ...
pytree_aux_fields: tuple = ...

def __call__(self, *args: Any, **kwargs: Any) -> Any: ...

def rsample(
self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = ()
) -> ArrayLike: ...
def sample(
self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = ()
) -> ArrayLike: ...
def log_prob(self, value: ArrayLike) -> ArrayLike: ...
def cdf(self, value: ArrayLike) -> ArrayLike: ...
def icdf(self, q: ArrayLike) -> ArrayLike: ...
def entropy(self) -> ArrayLike: ...
def enumerate_support(self, expand: bool = True) -> ArrayLike: ...
def shape(self, sample_shape: tuple[int, ...] = ()) -> tuple[int, ...]: ...

@property
def batch_shape(self) -> tuple[int, ...]: ...
@property
def event_shape(self) -> tuple[int, ...]: ...
@property
def event_dim(self) -> int: ...
@property
def has_rsample(self) -> bool: ...

@property
def mean(self) -> ArrayLike: ...
@property
def variance(self) -> ArrayLike: ...

@property
def is_discrete(self) -> bool: ...


# To avoid breaking changes for user code that uses `DistributionLike`
DistributionLike = DistributionT


@runtime_checkable
class TransformT(Protocol):
_inv: Optional[Union["TransformT", weakref.ref]] = ...
Expand Down
8 changes: 4 additions & 4 deletions numpyro/distributions/censored.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import jax.numpy as jnp
from jax.typing import ArrayLike

from numpyro._typing import ConstraintT, DistributionT
from numpyro._typing import ConstraintT
from numpyro.distributions import constraints
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import log1mexp, promote_shapes, validate_sample
Expand Down Expand Up @@ -78,7 +78,7 @@ class LeftCensoredDistribution(Distribution):

def __init__(
self,
base_dist: DistributionT,
base_dist: Distribution,
censored: ArrayLike = False,
*,
validate_args: bool = False,
Expand Down Expand Up @@ -194,7 +194,7 @@ class RightCensoredDistribution(Distribution):

def __init__(
self,
base_dist: DistributionT,
base_dist: Distribution,
censored: ArrayLike = False,
*,
validate_args: bool = False,
Expand Down Expand Up @@ -331,7 +331,7 @@ class IntervalCensoredDistribution(Distribution):

def __init__(
self,
base_dist: DistributionT,
base_dist: Distribution,
left_censored: ArrayLike,
right_censored: ArrayLike,
*,
Expand Down
2 changes: 1 addition & 1 deletion numpyro/distributions/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,7 +825,7 @@ def tree_flatten(self):
complex = _Complex()
corr_cholesky = _CorrCholesky()
corr_matrix = _CorrMatrix()
dependent: Constraint = _Dependent()
dependent: _Dependent = _Dependent()
greater_than = _GreaterThan
greater_than_eq = _GreaterThanEq
less_than = _LessThan
Expand Down
3 changes: 1 addition & 2 deletions numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@
from numpyro.distributions.discrete import _to_logits_bernoulli
from numpyro.distributions.distribution import (
Distribution,
DistributionT,
TransformedDistribution,
)
from numpyro.distributions.transforms import (
Expand Down Expand Up @@ -409,7 +408,7 @@ def __init__(
self,
t: Array,
sde_fn: Callable[[Array, Array], tuple[Array, Array]],
init_dist: DistributionT,
init_dist: Distribution,
*,
validate_args: Optional[bool] = None,
) -> None:
Expand Down
4 changes: 2 additions & 2 deletions numpyro/distributions/copula.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from jax import Array, lax, numpy as jnp
from jax.typing import ArrayLike

from numpyro._typing import ConstraintT, DistributionT
from numpyro._typing import ConstraintT
import numpyro.distributions.constraints as constraints
from numpyro.distributions.continuous import Beta, MultivariateNormal, Normal
from numpyro.distributions.distribution import Distribution
Expand Down Expand Up @@ -39,7 +39,7 @@ class GaussianCopula(Distribution):

def __init__(
self,
marginal_dist: DistributionT,
marginal_dist: Distribution,
correlation_matrix: Optional[Array] = None,
correlation_cholesky: Optional[Array] = None,
*,
Expand Down
8 changes: 4 additions & 4 deletions numpyro/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from jax.typing import ArrayLike

from numpyro.distributions import constraints, transforms
from numpyro.distributions.distribution import Distribution, DistributionT
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import (
assert_one_of,
binary_cross_entropy_with_logits,
Expand Down Expand Up @@ -845,7 +845,7 @@ class ZeroInflatedProbs(Distribution):

def __init__(
self,
base_dist: DistributionT,
base_dist: Distribution,
gate: ArrayLike,
*,
validate_args: Optional[bool] = None,
Expand Down Expand Up @@ -908,7 +908,7 @@ class ZeroInflatedLogits(ZeroInflatedProbs):

def __init__(
self,
base_dist: DistributionT,
base_dist: Distribution,
gate_logits: ArrayLike,
*,
validate_args: Optional[bool] = None,
Expand All @@ -928,7 +928,7 @@ def log_prob(self, value: ArrayLike) -> ArrayLike:


def ZeroInflatedDistribution(
base_dist: DistributionT,
base_dist: Distribution,
*,
gate: Optional[ArrayLike] = None,
gate_logits: Optional[ArrayLike] = None,
Expand Down
Loading