Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions numpyro/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,27 @@
NumLikeT = TypeVar("NumLikeT", bound=NumLike)


# ConstraintLike represents any constraint object (both Constraint[NumLike] and
# Constraint[NonScalarArray]). We use Any because:
# 1. Constraint[T] is a generic class, and mypy struggles with Protocol subtyping
# of generic classes due to variance issues
# 2. Some constraints only accept NonScalarArray while others accept full NumLike,
# making a single Protocol definition either too restrictive or too permissive
# 3. Creating a structural Protocol causes mypy errors when assigning concrete
# constraint instances to protocol-typed variables
# At runtime, all constraints share the interface (is_discrete, event_dim, __call__, etc.)
# and work correctly. This type alias provides better documentation than raw `Any`
# while acknowledging the limitations of Python's type system for this use case.
ConstraintLike: TypeAlias = Any
Comment on lines +43 to +54
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this defeating the purpose of why we are introducing the types?



@runtime_checkable
class ConstraintT(Protocol):
"""A protocol for typing constraints."""
"""A protocol for typing constraints that accept NumLike inputs.

This is a more specific protocol for constraints that can handle both
arrays and scalars (NumLike type).
"""

@property
def is_discrete(self) -> bool: ...
Expand Down Expand Up @@ -117,9 +135,9 @@ class TransformT(Protocol):
_inv: Optional[Union["TransformT", weakref.ref]] = ...

@property
def domain(self) -> ConstraintT: ...
def domain(self) -> ConstraintLike: ...
@property
def codomain(self) -> ConstraintT: ...
def codomain(self) -> ConstraintLike: ...
@property
def inv(self) -> "TransformT": ...
@property
Expand Down
43 changes: 24 additions & 19 deletions numpyro/distributions/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,11 @@ def is_discrete(self) -> bool:
def event_dim(self) -> int:
return self._event_dim

@is_discrete.setter # type: ignore[attr-defined]
@is_discrete.setter # type: ignore[attr-defined, no-redef]
def is_discrete(self, value: bool):
self._is_discrete = value

@event_dim.setter # type: ignore[attr-defined]
@event_dim.setter # type: ignore[attr-defined, no-redef]
def event_dim(self, value: int):
self._event_dim = value

Expand Down Expand Up @@ -283,7 +283,9 @@ def __init__(
self._is_discrete = is_discrete
self._event_dim = event_dim

def __call__(self, x: NumLikeT) -> ArrayLike:
def __call__( # type: ignore[override]
self, x: NumLikeT
) -> ArrayLike:
if not callable(x):
return super().__call__(x)

Expand Down Expand Up @@ -351,6 +353,9 @@ class _IndependentConstraint(Constraint[NumLikeT]):
independent entries are valid.
"""

base_constraint: ConstraintT
reinterpreted_batch_ndims: int

Comment on lines +356 to +358
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer a factory function over these changes. What are your thoughts?

def __init__(self, base_constraint: ConstraintT, reinterpreted_batch_ndims: int):
assert isinstance(base_constraint, Constraint)
assert isinstance(reinterpreted_batch_ndims, int)
Expand Down Expand Up @@ -394,7 +399,7 @@ def __repr__(self) -> str:
)

def feasible_like(self, prototype: NumLikeT) -> NumLikeT:
return self.base_constraint.feasible_like(prototype)
return self.base_constraint.feasible_like(prototype) # type: ignore[return-value]

def tree_flatten(self):
return (self.base_constraint,), (
Expand Down Expand Up @@ -846,9 +851,9 @@ def tree_flatten(self):
boolean: ConstraintT = _Boolean()
circular: ConstraintT = _Circular()
complex: ConstraintT = _Complex()
corr_cholesky: ConstraintT = _CorrCholesky()
corr_matrix: ConstraintT = _CorrMatrix()
dependent: ConstraintT = _Dependent()
corr_cholesky = _CorrCholesky()
corr_matrix = _CorrMatrix()
dependent: _Dependent = _Dependent()
greater_than = _GreaterThan
greater_than_eq = _GreaterThanEq
less_than = _LessThan
Expand All @@ -858,25 +863,25 @@ def tree_flatten(self):
integer_greater_than = _IntegerGreaterThan
interval = _Interval
l1_ball: ConstraintT = _L1Ball()
lower_cholesky: ConstraintT = _LowerCholesky()
scaled_unit_lower_cholesky: ConstraintT = _ScaledUnitLowerCholesky()
lower_cholesky = _LowerCholesky()
scaled_unit_lower_cholesky = _ScaledUnitLowerCholesky()
multinomial = _Multinomial
nonnegative: ConstraintT = _Nonnegative()
nonnegative_integer: ConstraintT = _IntegerNonnegative()
ordered_vector: ConstraintT = _OrderedVector()
ordered_vector = _OrderedVector()
positive: ConstraintT = _Positive()
positive_definite: ConstraintT = _PositiveDefinite()
positive_definite_circulant_vector: ConstraintT = _PositiveDefiniteCirculantVector()
positive_semidefinite: ConstraintT = _PositiveSemiDefinite()
positive_definite = _PositiveDefinite()
positive_definite_circulant_vector = _PositiveDefiniteCirculantVector()
positive_semidefinite = _PositiveSemiDefinite()
positive_integer: ConstraintT = _IntegerPositive()
positive_ordered_vector: ConstraintT = _PositiveOrderedVector()
positive_ordered_vector = _PositiveOrderedVector()
real: ConstraintT = _Real()
real_vector: ConstraintT = _RealVector()
real_matrix: ConstraintT = _RealMatrix()
simplex: ConstraintT = _Simplex()
softplus_lower_cholesky: ConstraintT = _SoftplusLowerCholesky()
real_vector = _RealVector()
real_matrix = _RealMatrix()
simplex = _Simplex()
softplus_lower_cholesky = _SoftplusLowerCholesky()
softplus_positive: ConstraintT = _SoftplusPositive()
sphere: ConstraintT = _Sphere()
sphere = _Sphere()
unit_interval: ConstraintT = _UnitInterval()
open_interval = _OpenInterval
zero_sum = _ZeroSum
43 changes: 22 additions & 21 deletions numpyro/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from jax.typing import ArrayLike

from numpyro._typing import (
ConstraintLike,
ConstraintT,
NonScalarArray,
NumLike,
Expand Down Expand Up @@ -76,11 +77,11 @@ class Transform(Generic[NumLikeT]):
_inv: Optional[Union[TransformT, weakref.ref]] = None

@property
def domain(self) -> ConstraintT:
def domain(self) -> ConstraintLike:
return constraints.real

@property
def codomain(self) -> ConstraintT:
def codomain(self) -> ConstraintLike:
return constraints.real

def __init_subclass__(cls, **kwargs):
Expand Down Expand Up @@ -170,11 +171,11 @@ def __init__(self, transform: TransformT):
self._inv: TransformT = transform

@property
def domain(self) -> ConstraintT:
def domain(self) -> ConstraintLike:
return self._inv.codomain

@property
def codomain(self) -> ConstraintT:
def codomain(self) -> ConstraintLike:
return self._inv.domain

@property
Expand Down Expand Up @@ -242,11 +243,11 @@ def __init__(
self._domain = domain

@property
def domain(self) -> ConstraintT:
def domain(self) -> ConstraintLike:
return self._domain

@property
def codomain(self) -> ConstraintT:
def codomain(self) -> ConstraintLike:
if self.domain is constraints.real:
return constraints.real
elif isinstance(self.domain, constraints.greater_than):
Expand Down Expand Up @@ -338,7 +339,7 @@ def __init__(self, parts: Sequence[TransformT]) -> None:
self.parts = parts

@property
def domain(self) -> ConstraintT:
def domain(self) -> ConstraintLike:
input_event_dim = _get_compose_transform_input_event_dim(self.parts)
first_input_event_dim = self.parts[0].domain.event_dim
assert input_event_dim >= first_input_event_dim
Expand All @@ -353,7 +354,7 @@ def domain(self) -> ConstraintT:
)

@property
def codomain(self) -> ConstraintT:
def codomain(self) -> ConstraintLike:
output_event_dim = _get_compose_transform_output_event_dim(self.parts)
last_output_event_dim = self.parts[-1].codomain.event_dim
assert output_event_dim >= last_output_event_dim
Expand Down Expand Up @@ -575,8 +576,8 @@ class CorrMatrixCholeskyTransform(CholeskyTransform):
correlation matrix.
"""

domain = constraints.corr_matrix
codomain = constraints.corr_cholesky
domain: ConstraintLike = constraints.corr_matrix
codomain: ConstraintLike = constraints.corr_cholesky

def log_abs_det_jacobian(
self,
Expand All @@ -599,11 +600,11 @@ def __init__(self, domain: ConstraintT = constraints.real):
self._domain = domain

@property
def domain(self) -> ConstraintT:
def domain(self) -> ConstraintLike:
return self._domain

@property
def codomain(self) -> ConstraintT:
def codomain(self) -> ConstraintLike:
if self.domain is constraints.ordered_vector:
return constraints.positive_ordered_vector
elif self.domain is constraints.real:
Expand Down Expand Up @@ -672,7 +673,7 @@ def __init__(
super().__init__()

@property
def domain(self) -> ConstraintT:
def domain(self) -> ConstraintLike:
return cast(
ConstraintT,
constraints.independent(
Expand All @@ -681,7 +682,7 @@ def domain(self) -> ConstraintT:
)

@property
def codomain(self) -> ConstraintT:
def codomain(self) -> ConstraintLike:
return cast(
ConstraintT,
constraints.independent(
Expand Down Expand Up @@ -1361,14 +1362,14 @@ def __init__(
self._inverse_shape = inverse_shape

@property
def domain(self) -> ConstraintT:
def domain(self) -> ConstraintLike:
return cast(
ConstraintT,
constraints.independent(constraints.real, len(self._inverse_shape)),
)

@property
def codomain(self) -> ConstraintT:
def codomain(self) -> ConstraintLike:
return cast(
ConstraintT,
constraints.independent(constraints.real, len(self._forward_shape)),
Expand Down Expand Up @@ -1485,13 +1486,13 @@ def tree_flatten(self):
return (), ((), aux_data)

@property
def domain(self) -> ConstraintT:
def domain(self) -> ConstraintLike:
return cast(
ConstraintT, constraints.independent(constraints.real, self.transform_ndims)
)

@property
def codomain(self) -> ConstraintT:
def codomain(self) -> ConstraintLike:
return cast(
ConstraintT,
constraints.independent(constraints.complex, self.transform_ndims),
Expand All @@ -1513,7 +1514,7 @@ class PackRealFastFourierCoefficientsTransform(Transform[NonScalarArray]):
"""

domain = constraints.real_vector
codomain = cast(ConstraintT, constraints.independent(constraints.complex, 1))
codomain = cast(ConstraintLike, constraints.independent(constraints.complex, 1))

def __init__(self, transform_shape: Optional[tuple[int, ...]] = None) -> None:
assert transform_shape is None or len(transform_shape) == 1, (
Expand Down Expand Up @@ -1729,13 +1730,13 @@ def __init__(self, transform_ndims: int = 1) -> None:
self.transform_ndims = transform_ndims

@property
def domain(self) -> ConstraintT:
def domain(self) -> ConstraintLike:
return cast(
ConstraintT, constraints.independent(constraints.real, self.transform_ndims)
)

@property
def codomain(self) -> ConstraintT:
def codomain(self) -> ConstraintLike:
return cast(ConstraintT, constraints.zero_sum(self.transform_ndims))

def __call__(self, x: NonScalarArray) -> NonScalarArray:
Expand Down