Skip to content

Commit e7e1d42

Browse files
authored
Add Types to numpyro.distributions.distribution (#2122)
* init * rm custom noqa * fix specific cases * undo primitive change * fix * try fix test * try fix tests * no DistributionT * try fix test * rm distribution type * rm * improve tol test * try to fix test * try to fix test * try fix test
1 parent 4fd6c73 commit e7e1d42

File tree

12 files changed

+203
-365
lines changed

12 files changed

+203
-365
lines changed

numpyro/_typing.py

Lines changed: 0 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -58,59 +58,6 @@ def check(self, value: NumLike) -> ArrayLike: ...
5858
def feasible_like(self, prototype: NumLike) -> NumLike: ...
5959

6060

61-
@runtime_checkable
62-
class DistributionT(Protocol):
63-
"""A protocol for typing distributions.
64-
65-
Used to type object of type numpyro.distributions.Distribution, funsor.Funsor
66-
or tensorflow_probability.distributions.Distribution.
67-
"""
68-
69-
arg_constraints: dict[str, ConstraintT] = ...
70-
support: ConstraintT = ...
71-
has_enumerate_support: bool = ...
72-
reparametrized_params: list[str] = ...
73-
_validate_args: bool = ...
74-
pytree_data_fields: tuple = ...
75-
pytree_aux_fields: tuple = ...
76-
77-
def __call__(self, *args: Any, **kwargs: Any) -> Any: ...
78-
79-
def rsample(
80-
self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = ()
81-
) -> ArrayLike: ...
82-
def sample(
83-
self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = ()
84-
) -> ArrayLike: ...
85-
def log_prob(self, value: ArrayLike) -> ArrayLike: ...
86-
def cdf(self, value: ArrayLike) -> ArrayLike: ...
87-
def icdf(self, q: ArrayLike) -> ArrayLike: ...
88-
def entropy(self) -> ArrayLike: ...
89-
def enumerate_support(self, expand: bool = True) -> ArrayLike: ...
90-
def shape(self, sample_shape: tuple[int, ...] = ()) -> tuple[int, ...]: ...
91-
92-
@property
93-
def batch_shape(self) -> tuple[int, ...]: ...
94-
@property
95-
def event_shape(self) -> tuple[int, ...]: ...
96-
@property
97-
def event_dim(self) -> int: ...
98-
@property
99-
def has_rsample(self) -> bool: ...
100-
101-
@property
102-
def mean(self) -> ArrayLike: ...
103-
@property
104-
def variance(self) -> ArrayLike: ...
105-
106-
@property
107-
def is_discrete(self) -> bool: ...
108-
109-
110-
# To avoid breaking changes for user code that uses `DistributionLike`
111-
DistributionLike = DistributionT
112-
113-
11461
@runtime_checkable
11562
class TransformT(Protocol):
11663
_inv: Optional[Union["TransformT", weakref.ref]] = ...

numpyro/distributions/censored.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import jax.numpy as jnp
1313
from jax.typing import ArrayLike
1414

15-
from numpyro._typing import ConstraintT, DistributionT
15+
from numpyro._typing import ConstraintT
1616
from numpyro.distributions import constraints
1717
from numpyro.distributions.distribution import Distribution
1818
from numpyro.distributions.util import log1mexp, promote_shapes, validate_sample
@@ -78,7 +78,7 @@ class LeftCensoredDistribution(Distribution):
7878

7979
def __init__(
8080
self,
81-
base_dist: DistributionT,
81+
base_dist: Distribution,
8282
censored: ArrayLike = False,
8383
*,
8484
validate_args: bool = False,
@@ -194,7 +194,7 @@ class RightCensoredDistribution(Distribution):
194194

195195
def __init__(
196196
self,
197-
base_dist: DistributionT,
197+
base_dist: Distribution,
198198
censored: ArrayLike = False,
199199
*,
200200
validate_args: bool = False,
@@ -331,7 +331,7 @@ class IntervalCensoredDistribution(Distribution):
331331

332332
def __init__(
333333
self,
334-
base_dist: DistributionT,
334+
base_dist: Distribution,
335335
left_censored: ArrayLike,
336336
right_censored: ArrayLike,
337337
*,

numpyro/distributions/constraints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -825,7 +825,7 @@ def tree_flatten(self):
825825
complex = _Complex()
826826
corr_cholesky = _CorrCholesky()
827827
corr_matrix = _CorrMatrix()
828-
dependent: Constraint = _Dependent()
828+
dependent: _Dependent = _Dependent()
829829
greater_than = _GreaterThan
830830
greater_than_eq = _GreaterThanEq
831831
less_than = _LessThan

numpyro/distributions/continuous.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@
5959
from numpyro.distributions.discrete import _to_logits_bernoulli
6060
from numpyro.distributions.distribution import (
6161
Distribution,
62-
DistributionT,
6362
TransformedDistribution,
6463
)
6564
from numpyro.distributions.transforms import (
@@ -409,7 +408,7 @@ def __init__(
409408
self,
410409
t: Array,
411410
sde_fn: Callable[[Array, Array], tuple[Array, Array]],
412-
init_dist: DistributionT,
411+
init_dist: Distribution,
413412
*,
414413
validate_args: Optional[bool] = None,
415414
) -> None:

numpyro/distributions/copula.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from jax import Array, lax, numpy as jnp
99
from jax.typing import ArrayLike
1010

11-
from numpyro._typing import ConstraintT, DistributionT
11+
from numpyro._typing import ConstraintT
1212
import numpyro.distributions.constraints as constraints
1313
from numpyro.distributions.continuous import Beta, MultivariateNormal, Normal
1414
from numpyro.distributions.distribution import Distribution
@@ -39,7 +39,7 @@ class GaussianCopula(Distribution):
3939

4040
def __init__(
4141
self,
42-
marginal_dist: DistributionT,
42+
marginal_dist: Distribution,
4343
correlation_matrix: Optional[Array] = None,
4444
correlation_cholesky: Optional[Array] = None,
4545
*,

numpyro/distributions/discrete.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
from jax.typing import ArrayLike
4040

4141
from numpyro.distributions import constraints, transforms
42-
from numpyro.distributions.distribution import Distribution, DistributionT
42+
from numpyro.distributions.distribution import Distribution
4343
from numpyro.distributions.util import (
4444
assert_one_of,
4545
binary_cross_entropy_with_logits,
@@ -845,7 +845,7 @@ class ZeroInflatedProbs(Distribution):
845845

846846
def __init__(
847847
self,
848-
base_dist: DistributionT,
848+
base_dist: Distribution,
849849
gate: ArrayLike,
850850
*,
851851
validate_args: Optional[bool] = None,
@@ -908,7 +908,7 @@ class ZeroInflatedLogits(ZeroInflatedProbs):
908908

909909
def __init__(
910910
self,
911-
base_dist: DistributionT,
911+
base_dist: Distribution,
912912
gate_logits: ArrayLike,
913913
*,
914914
validate_args: Optional[bool] = None,
@@ -928,7 +928,7 @@ def log_prob(self, value: ArrayLike) -> ArrayLike:
928928

929929

930930
def ZeroInflatedDistribution(
931-
base_dist: DistributionT,
931+
base_dist: Distribution,
932932
*,
933933
gate: Optional[ArrayLike] = None,
934934
gate_logits: Optional[ArrayLike] = None,

0 commit comments

Comments
 (0)