diff --git a/numpyro/_typing.py b/numpyro/_typing.py index 184063dd2..0c9f490b0 100644 --- a/numpyro/_typing.py +++ b/numpyro/_typing.py @@ -40,9 +40,27 @@ NumLikeT = TypeVar("NumLikeT", bound=NumLike) +# ConstraintLike represents any constraint object (both Constraint[NumLike] and +# Constraint[NonScalarArray]). We use Any because: +# 1. Constraint[T] is a generic class, and mypy struggles with Protocol subtyping +# of generic classes due to variance issues +# 2. Some constraints only accept NonScalarArray while others accept full NumLike, +# making a single Protocol definition either too restrictive or too permissive +# 3. Creating a structural Protocol causes mypy errors when assigning concrete +# constraint instances to protocol-typed variables +# At runtime, all constraints share the interface (is_discrete, event_dim, __call__, etc.) +# and work correctly. This type alias provides better documentation than raw `Any` +# while acknowledging the limitations of Python's type system for this use case. +ConstraintLike: TypeAlias = Any + + @runtime_checkable class ConstraintT(Protocol): - """A protocol for typing constraints.""" + """A protocol for typing constraints that accept NumLike inputs. + + This is a more specific protocol for constraints that can handle both + arrays and scalars (NumLike type). + """ @property def is_discrete(self) -> bool: ... @@ -117,9 +135,9 @@ class TransformT(Protocol): _inv: Optional[Union["TransformT", weakref.ref]] = ... @property - def domain(self) -> ConstraintT: ... + def domain(self) -> ConstraintLike: ... @property - def codomain(self) -> ConstraintT: ... + def codomain(self) -> ConstraintLike: ... @property def inv(self) -> "TransformT": ... @property diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index ef99b9750..afe65eea2 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -118,11 +118,11 @@ def is_discrete(self) -> bool: def event_dim(self) -> int: return self._event_dim - @is_discrete.setter # type: ignore[attr-defined] + @is_discrete.setter # type: ignore[attr-defined, no-redef] def is_discrete(self, value: bool): self._is_discrete = value - @event_dim.setter # type: ignore[attr-defined] + @event_dim.setter # type: ignore[attr-defined, no-redef] def event_dim(self, value: int): self._event_dim = value @@ -283,7 +283,9 @@ def __init__( self._is_discrete = is_discrete self._event_dim = event_dim - def __call__(self, x: NumLikeT) -> ArrayLike: + def __call__( # type: ignore[override] + self, x: NumLikeT + ) -> ArrayLike: if not callable(x): return super().__call__(x) @@ -351,6 +353,9 @@ class _IndependentConstraint(Constraint[NumLikeT]): independent entries are valid. """ + base_constraint: ConstraintT + reinterpreted_batch_ndims: int + def __init__(self, base_constraint: ConstraintT, reinterpreted_batch_ndims: int): assert isinstance(base_constraint, Constraint) assert isinstance(reinterpreted_batch_ndims, int) @@ -394,7 +399,7 @@ def __repr__(self) -> str: ) def feasible_like(self, prototype: NumLikeT) -> NumLikeT: - return self.base_constraint.feasible_like(prototype) + return self.base_constraint.feasible_like(prototype) # type: ignore[return-value] def tree_flatten(self): return (self.base_constraint,), ( @@ -846,9 +851,9 @@ def tree_flatten(self): boolean: ConstraintT = _Boolean() circular: ConstraintT = _Circular() complex: ConstraintT = _Complex() -corr_cholesky: ConstraintT = _CorrCholesky() -corr_matrix: ConstraintT = _CorrMatrix() -dependent: ConstraintT = _Dependent() +corr_cholesky = _CorrCholesky() +corr_matrix = _CorrMatrix() +dependent: _Dependent = _Dependent() greater_than = _GreaterThan greater_than_eq = _GreaterThanEq less_than = _LessThan @@ -858,25 +863,25 @@ def tree_flatten(self): integer_greater_than = _IntegerGreaterThan interval = _Interval l1_ball: ConstraintT = _L1Ball() -lower_cholesky: ConstraintT = _LowerCholesky() -scaled_unit_lower_cholesky: ConstraintT = _ScaledUnitLowerCholesky() +lower_cholesky = _LowerCholesky() +scaled_unit_lower_cholesky = _ScaledUnitLowerCholesky() multinomial = _Multinomial nonnegative: ConstraintT = _Nonnegative() nonnegative_integer: ConstraintT = _IntegerNonnegative() -ordered_vector: ConstraintT = _OrderedVector() +ordered_vector = _OrderedVector() positive: ConstraintT = _Positive() -positive_definite: ConstraintT = _PositiveDefinite() -positive_definite_circulant_vector: ConstraintT = _PositiveDefiniteCirculantVector() -positive_semidefinite: ConstraintT = _PositiveSemiDefinite() +positive_definite = _PositiveDefinite() +positive_definite_circulant_vector = _PositiveDefiniteCirculantVector() +positive_semidefinite = _PositiveSemiDefinite() positive_integer: ConstraintT = _IntegerPositive() -positive_ordered_vector: ConstraintT = _PositiveOrderedVector() +positive_ordered_vector = _PositiveOrderedVector() real: ConstraintT = _Real() -real_vector: ConstraintT = _RealVector() -real_matrix: ConstraintT = _RealMatrix() -simplex: ConstraintT = _Simplex() -softplus_lower_cholesky: ConstraintT = _SoftplusLowerCholesky() +real_vector = _RealVector() +real_matrix = _RealMatrix() +simplex = _Simplex() +softplus_lower_cholesky = _SoftplusLowerCholesky() softplus_positive: ConstraintT = _SoftplusPositive() -sphere: ConstraintT = _Sphere() +sphere = _Sphere() unit_interval: ConstraintT = _UnitInterval() open_interval = _OpenInterval zero_sum = _ZeroSum diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 11cf179e0..e7121b45a 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -19,6 +19,7 @@ from jax.typing import ArrayLike from numpyro._typing import ( + ConstraintLike, ConstraintT, NonScalarArray, NumLike, @@ -76,11 +77,11 @@ class Transform(Generic[NumLikeT]): _inv: Optional[Union[TransformT, weakref.ref]] = None @property - def domain(self) -> ConstraintT: + def domain(self) -> ConstraintLike: return constraints.real @property - def codomain(self) -> ConstraintT: + def codomain(self) -> ConstraintLike: return constraints.real def __init_subclass__(cls, **kwargs): @@ -170,11 +171,11 @@ def __init__(self, transform: TransformT): self._inv: TransformT = transform @property - def domain(self) -> ConstraintT: + def domain(self) -> ConstraintLike: return self._inv.codomain @property - def codomain(self) -> ConstraintT: + def codomain(self) -> ConstraintLike: return self._inv.domain @property @@ -242,11 +243,11 @@ def __init__( self._domain = domain @property - def domain(self) -> ConstraintT: + def domain(self) -> ConstraintLike: return self._domain @property - def codomain(self) -> ConstraintT: + def codomain(self) -> ConstraintLike: if self.domain is constraints.real: return constraints.real elif isinstance(self.domain, constraints.greater_than): @@ -338,7 +339,7 @@ def __init__(self, parts: Sequence[TransformT]) -> None: self.parts = parts @property - def domain(self) -> ConstraintT: + def domain(self) -> ConstraintLike: input_event_dim = _get_compose_transform_input_event_dim(self.parts) first_input_event_dim = self.parts[0].domain.event_dim assert input_event_dim >= first_input_event_dim @@ -353,7 +354,7 @@ def domain(self) -> ConstraintT: ) @property - def codomain(self) -> ConstraintT: + def codomain(self) -> ConstraintLike: output_event_dim = _get_compose_transform_output_event_dim(self.parts) last_output_event_dim = self.parts[-1].codomain.event_dim assert output_event_dim >= last_output_event_dim @@ -575,8 +576,8 @@ class CorrMatrixCholeskyTransform(CholeskyTransform): correlation matrix. """ - domain = constraints.corr_matrix - codomain = constraints.corr_cholesky + domain: ConstraintLike = constraints.corr_matrix + codomain: ConstraintLike = constraints.corr_cholesky def log_abs_det_jacobian( self, @@ -599,11 +600,11 @@ def __init__(self, domain: ConstraintT = constraints.real): self._domain = domain @property - def domain(self) -> ConstraintT: + def domain(self) -> ConstraintLike: return self._domain @property - def codomain(self) -> ConstraintT: + def codomain(self) -> ConstraintLike: if self.domain is constraints.ordered_vector: return constraints.positive_ordered_vector elif self.domain is constraints.real: @@ -672,7 +673,7 @@ def __init__( super().__init__() @property - def domain(self) -> ConstraintT: + def domain(self) -> ConstraintLike: return cast( ConstraintT, constraints.independent( @@ -681,7 +682,7 @@ def domain(self) -> ConstraintT: ) @property - def codomain(self) -> ConstraintT: + def codomain(self) -> ConstraintLike: return cast( ConstraintT, constraints.independent( @@ -1361,14 +1362,14 @@ def __init__( self._inverse_shape = inverse_shape @property - def domain(self) -> ConstraintT: + def domain(self) -> ConstraintLike: return cast( ConstraintT, constraints.independent(constraints.real, len(self._inverse_shape)), ) @property - def codomain(self) -> ConstraintT: + def codomain(self) -> ConstraintLike: return cast( ConstraintT, constraints.independent(constraints.real, len(self._forward_shape)), @@ -1485,13 +1486,13 @@ def tree_flatten(self): return (), ((), aux_data) @property - def domain(self) -> ConstraintT: + def domain(self) -> ConstraintLike: return cast( ConstraintT, constraints.independent(constraints.real, self.transform_ndims) ) @property - def codomain(self) -> ConstraintT: + def codomain(self) -> ConstraintLike: return cast( ConstraintT, constraints.independent(constraints.complex, self.transform_ndims), @@ -1513,7 +1514,7 @@ class PackRealFastFourierCoefficientsTransform(Transform[NonScalarArray]): """ domain = constraints.real_vector - codomain = cast(ConstraintT, constraints.independent(constraints.complex, 1)) + codomain = cast(ConstraintLike, constraints.independent(constraints.complex, 1)) def __init__(self, transform_shape: Optional[tuple[int, ...]] = None) -> None: assert transform_shape is None or len(transform_shape) == 1, ( @@ -1729,13 +1730,13 @@ def __init__(self, transform_ndims: int = 1) -> None: self.transform_ndims = transform_ndims @property - def domain(self) -> ConstraintT: + def domain(self) -> ConstraintLike: return cast( ConstraintT, constraints.independent(constraints.real, self.transform_ndims) ) @property - def codomain(self) -> ConstraintT: + def codomain(self) -> ConstraintLike: return cast(ConstraintT, constraints.zero_sum(self.transform_ndims)) def __call__(self, x: NonScalarArray) -> NonScalarArray: