Skip to content

Commit 43bcfde

Browse files
authored
Remove Typing Protocols (#2125)
* rm protocols * transform type
1 parent e7e1d42 commit 43bcfde

File tree

7 files changed

+21
-67
lines changed

7 files changed

+21
-67
lines changed

numpyro/_typing.py

Lines changed: 0 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,15 @@
66
from collections.abc import Callable
77
from typing import (
88
Any,
9-
Optional,
109
ParamSpec,
11-
Protocol,
1210
TypeAlias,
1311
TypeVar,
1412
Union,
15-
runtime_checkable,
1613
)
17-
import weakref
1814

1915
import numpy as np
2016

2117
import jax
22-
from jax.typing import ArrayLike
2318

2419
P = ParamSpec("P")
2520
ModelT: TypeAlias = Callable[P, Any]
@@ -41,43 +36,3 @@
4136

4237

4338
NumLikeT = TypeVar("NumLikeT", bound=NumLike)
44-
45-
46-
@runtime_checkable
47-
class ConstraintT(Protocol):
48-
"""A protocol for typing constraints."""
49-
50-
@property
51-
def is_discrete(self) -> bool: ...
52-
@property
53-
def event_dim(self) -> int: ...
54-
55-
def __call__(self, x: NumLike) -> ArrayLike: ...
56-
def __repr__(self) -> str: ...
57-
def check(self, value: NumLike) -> ArrayLike: ...
58-
def feasible_like(self, prototype: NumLike) -> NumLike: ...
59-
60-
61-
@runtime_checkable
62-
class TransformT(Protocol):
63-
_inv: Optional[Union["TransformT", weakref.ref]] = ...
64-
65-
@property
66-
def domain(self) -> ConstraintT: ...
67-
@property
68-
def codomain(self) -> ConstraintT: ...
69-
@property
70-
def inv(self) -> "TransformT": ...
71-
@property
72-
def sign(self) -> NumLike: ...
73-
74-
def __call__(self, x: NumLike) -> NumLike: ...
75-
def _inverse(self, y: NumLike) -> NumLike: ...
76-
def log_abs_det_jacobian(
77-
self, x: NumLike, y: NumLike, intermediates: Optional[PyTree] = None
78-
) -> NumLike: ...
79-
def call_with_intermediates(
80-
self, x: NumLike
81-
) -> tuple[NumLike, Optional[PyTree]]: ...
82-
def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: ...
83-
def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: ...

numpyro/distributions/censored.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
import jax.numpy as jnp
1313
from jax.typing import ArrayLike
1414

15-
from numpyro._typing import ConstraintT
1615
from numpyro.distributions import constraints
16+
from numpyro.distributions.constraints import Constraint
1717
from numpyro.distributions.distribution import Distribution
1818
from numpyro.distributions.util import log1mexp, promote_shapes, validate_sample
1919
from numpyro.util import find_stack_level, not_jax_tracer
@@ -116,7 +116,7 @@ def sample(
116116
return self.base_dist.expand(self.batch_shape).sample(key, sample_shape)
117117

118118
@constraints.dependent_property(is_discrete=False, event_dim=0)
119-
def support(self) -> ConstraintT:
119+
def support(self) -> Constraint:
120120
return self._support
121121

122122
@validate_sample
@@ -232,7 +232,7 @@ def sample(
232232
return self.base_dist.expand(self.batch_shape).sample(key, sample_shape)
233233

234234
@constraints.dependent_property(is_discrete=False, event_dim=0)
235-
def support(self) -> ConstraintT:
235+
def support(self) -> Constraint:
236236
return self._support
237237

238238
@validate_sample
@@ -367,7 +367,7 @@ def sample(
367367
return self.base_dist.expand(self.batch_shape).sample(key, sample_shape)
368368

369369
@constraints.dependent_property(is_discrete=False, event_dim=1)
370-
def support(self) -> ConstraintT:
370+
def support(self) -> Constraint:
371371
return self._support
372372

373373
def _get_censoring_masks(self, value):

numpyro/distributions/conjugate.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
from jax.scipy.special import betainc, betaln, gammaln
1111
from jax.typing import ArrayLike
1212

13-
from numpyro._typing import ConstraintT
1413
from numpyro.distributions import constraints
14+
from numpyro.distributions.constraints import Constraint
1515
from numpyro.distributions.continuous import Beta, Dirichlet, Gamma
1616
from numpyro.distributions.discrete import (
1717
BinomialProbs,
@@ -105,7 +105,7 @@ def variance(self) -> ArrayLike:
105105
)
106106

107107
@constraints.dependent_property(is_discrete=True, event_dim=0)
108-
def support(self) -> ConstraintT:
108+
def support(self) -> Constraint:
109109
return constraints.integer_interval(0, self.total_count)
110110

111111

@@ -324,7 +324,7 @@ def variance(self) -> ArrayLike:
324324
return n * alpha_ratio * (1 - alpha_ratio) * (n + alpha_sum) / (1 + alpha_sum)
325325

326326
@constraints.dependent_property(is_discrete=True, event_dim=1)
327-
def support(self) -> ConstraintT:
327+
def support(self) -> Constraint:
328328
return constraints.multinomial(self.total_count)
329329

330330
@staticmethod

numpyro/distributions/copula.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
from jax import Array, lax, numpy as jnp
99
from jax.typing import ArrayLike
1010

11-
from numpyro._typing import ConstraintT
1211
import numpyro.distributions.constraints as constraints
12+
from numpyro.distributions.constraints import Constraint
1313
from numpyro.distributions.continuous import Beta, MultivariateNormal, Normal
1414
from numpyro.distributions.distribution import Distribution
1515
from numpyro.distributions.util import clamp_probs, lazy_property, validate_sample
@@ -100,7 +100,7 @@ def variance(self) -> ArrayLike:
100100
return jnp.broadcast_to(self.marginal_dist.variance, self.shape())
101101

102102
@constraints.dependent_property(is_discrete=False, event_dim=1)
103-
def support(self) -> ConstraintT:
103+
def support(self) -> Constraint:
104104
return constraints.independent(self.marginal_dist.support, 1)
105105

106106
@lazy_property

numpyro/distributions/flows.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import jax.numpy as jnp
77
from jax.typing import ArrayLike
88

9-
from numpyro._typing import TransformT
109
from numpyro.distributions.constraints import real_vector
1110
from numpyro.distributions.transforms import Transform
1211
from numpyro.util import fori_loop
@@ -109,7 +108,7 @@ def tree_flatten(self):
109108
{"arn": self.arn},
110109
)
111110

112-
def __eq__(self, other: TransformT) -> bool:
111+
def __eq__(self, other: Transform) -> bool:
113112
if not isinstance(other, InverseAutoregressiveTransform):
114113
return False
115114
return (
@@ -170,7 +169,7 @@ def log_abs_det_jacobian(
170169
def tree_flatten(self):
171170
return (), ((), {"bn_arn": self.bn_arn})
172171

173-
def __eq__(self, other: TransformT) -> bool:
172+
def __eq__(self, other: Transform) -> bool:
174173
return (
175174
isinstance(other, BlockNeuralAutoregressiveTransform)
176175
and self.bn_arn is other.bn_arn

numpyro/distributions/mixtures.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
import jax.numpy as jnp
1010
from jax.typing import ArrayLike
1111

12-
from numpyro._typing import ConstraintT
1312
from numpyro.distributions import constraints
13+
from numpyro.distributions.constraints import Constraint
1414
from numpyro.distributions.discrete import CategoricalLogits, CategoricalProbs
1515
from numpyro.distributions.distribution import Distribution
1616
from numpyro.distributions.util import validate_sample
@@ -258,7 +258,7 @@ def component_distribution(self) -> Distribution:
258258
return self._component_distribution
259259

260260
@constraints.dependent_property
261-
def support(self) -> ConstraintT:
261+
def support(self) -> Constraint:
262262
return self.component_distribution.support
263263

264264
@property
@@ -353,7 +353,7 @@ def __init__(
353353
mixing_distribution: Union[CategoricalProbs, CategoricalLogits],
354354
component_distributions: list[Distribution],
355355
*,
356-
support: Optional[ConstraintT] = None,
356+
support: Optional[Constraint] = None,
357357
validate_args: Optional[bool] = None,
358358
):
359359
_check_mixing_distribution(mixing_distribution)
@@ -424,7 +424,7 @@ def component_distributions(self) -> list[Distribution]:
424424
return self._component_distributions
425425

426426
@constraints.dependent_property
427-
def support(self) -> ConstraintT:
427+
def support(self) -> Constraint:
428428
if self._support is not None:
429429
return self._support
430430
return self.component_distributions[0].support

numpyro/distributions/truncated.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
from jax.scipy.special import logsumexp
1212
from jax.typing import ArrayLike
1313

14-
from numpyro._typing import ConstraintT
1514
from numpyro.distributions import constraints
15+
from numpyro.distributions.constraints import Constraint
1616
from numpyro.distributions.continuous import (
1717
Cauchy,
1818
Laplace,
@@ -57,7 +57,7 @@ def __init__(
5757
super().__init__(batch_shape, validate_args=validate_args)
5858

5959
@constraints.dependent_property(is_discrete=False, event_dim=0)
60-
def support(self) -> ConstraintT:
60+
def support(self) -> Constraint:
6161
return self._support
6262

6363
@lazy_property
@@ -162,7 +162,7 @@ def __init__(
162162
super().__init__(batch_shape, validate_args=validate_args)
163163

164164
@constraints.dependent_property(is_discrete=False, event_dim=0)
165-
def support(self) -> ConstraintT:
165+
def support(self) -> Constraint:
166166
return self._support
167167

168168
@lazy_property
@@ -259,7 +259,7 @@ def __init__(
259259
super().__init__(batch_shape, validate_args=validate_args)
260260

261261
@constraints.dependent_property(is_discrete=False, event_dim=0)
262-
def support(self) -> ConstraintT:
262+
def support(self) -> Constraint:
263263
return self._support
264264

265265
@lazy_property
@@ -529,7 +529,7 @@ def __init__(
529529
)
530530

531531
@constraints.dependent_property(is_discrete=False, event_dim=0)
532-
def support(self) -> ConstraintT:
532+
def support(self) -> Constraint:
533533
return self._support
534534

535535
@validate_sample
@@ -1010,7 +1010,7 @@ def __init__(
10101010
)
10111011

10121012
@constraints.dependent_property(is_discrete=False, event_dim=0)
1013-
def support(self) -> ConstraintT:
1013+
def support(self) -> Constraint:
10141014
return self._support
10151015

10161016
@validate_sample

0 commit comments

Comments
 (0)