-
Notifications
You must be signed in to change notification settings - Fork 267
fix(gh-2036): MyPy Errors in Distributions Module #2050
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Qazalbash
wants to merge
3
commits into
pyro-ppl:master
Choose a base branch
from
Qazalbash:issue-2036
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,29 +4,39 @@ | |
|
||
from collections import OrderedDict | ||
from collections.abc import Callable | ||
from typing import Any, Protocol, runtime_checkable | ||
from typing import Any, Optional, Protocol, Union, runtime_checkable | ||
|
||
from typing_extensions import ParamSpec, TypeAlias | ||
|
||
import jax | ||
from jax import Array | ||
from jax.typing import ArrayLike | ||
|
||
from numpyro.distributions import MaskedDistribution | ||
|
||
P = ParamSpec("P") | ||
ModelT: TypeAlias = Callable[P, Any] | ||
|
||
Message: TypeAlias = dict[str, Any] | ||
TraceT: TypeAlias = OrderedDict[str, Message] | ||
PRNGKeyT: TypeAlias = Union[jax.dtypes.prng_key, ArrayLike] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we can just use jax.Array for PRNGKey |
||
|
||
|
||
@runtime_checkable | ||
class ConstraintT(Protocol): | ||
is_discrete: bool = ... | ||
event_dim: int = ... | ||
# is_discrete: bool = ... | ||
# event_dim: int = ... | ||
|
||
def __call__(self, x: ArrayLike) -> ArrayLike: ... | ||
def __init__(self, *args: Any, **kwargs: Any) -> None: ... | ||
def __call__(self, x: Array) -> Array: ... | ||
def __repr__(self) -> str: ... | ||
def check(self, value: ArrayLike) -> ArrayLike: ... | ||
def feasible_like(self, prototype: ArrayLike) -> ArrayLike: ... | ||
def check(self, value: Array) -> Array: ... | ||
def feasible_like(self, prototype: Array) -> Array: ... | ||
|
||
@property | ||
def is_discrete(self) -> bool: ... | ||
@property | ||
def event_dim(self) -> int: ... | ||
|
||
|
||
@runtime_checkable | ||
|
@@ -38,27 +48,35 @@ class DistributionT(Protocol): | |
""" | ||
|
||
arg_constraints: dict[str, ConstraintT] = ... | ||
support: ConstraintT = ... | ||
has_enumerate_support: bool = ... | ||
reparametrized_params: list[str] = ... | ||
_validate_args: bool = ... | ||
pytree_data_fields: tuple = ... | ||
pytree_aux_fields: tuple = ... | ||
|
||
def __call__(self, *args: Any, **kwargs: Any) -> Any: ... | ||
|
||
def rsample( | ||
self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () | ||
) -> ArrayLike: ... | ||
self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () | ||
) -> Array: ... | ||
def sample( | ||
self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () | ||
) -> ArrayLike: ... | ||
def log_prob(self, value: ArrayLike) -> ArrayLike: ... | ||
def cdf(self, value: ArrayLike) -> ArrayLike: ... | ||
def icdf(self, q: ArrayLike) -> ArrayLike: ... | ||
def entropy(self) -> ArrayLike: ... | ||
def enumerate_support(self, expand: bool = True) -> ArrayLike: ... | ||
self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = () | ||
) -> Array: ... | ||
def log_prob(self, value: Array) -> Array: ... | ||
def cdf(self, value: Array) -> Array: ... | ||
def icdf(self, q: Array) -> Array: ... | ||
def entropy(self) -> Array: ... | ||
def enumerate_support(self, expand: bool = True) -> Array: ... | ||
def shape(self, sample_shape: tuple[int, ...] = ()) -> tuple[int, ...]: ... | ||
def to_event( | ||
self, reinterpreted_batch_ndims: Optional[int] = None | ||
) -> "DistributionT": ... | ||
def expand(self, batch_shape: tuple[int, ...]) -> "DistributionT": ... | ||
def expand_by(self, sample_shape: tuple[int, ...]) -> "DistributionT": ... | ||
def mask(self, mask: Array) -> MaskedDistribution: ... | ||
@classmethod | ||
def infer_shapes(cls, *args, **kwargs): ... | ||
|
||
@property | ||
def support(self) -> ConstraintT: ... | ||
|
||
@property | ||
def batch_shape(self) -> tuple[int, ...]: ... | ||
|
@@ -76,6 +94,8 @@ def variance(self) -> ArrayLike: ... | |
|
||
@property | ||
def is_discrete(self) -> bool: ... | ||
@property | ||
def has_enumerate_support(self) -> bool: ... | ||
|
||
|
||
# To avoid breaking changes for user code that uses `DistributionLike` | ||
|
@@ -84,20 +104,18 @@ def is_discrete(self) -> bool: ... | |
|
||
@runtime_checkable | ||
class TransformT(Protocol): | ||
domain = ConstraintT | ||
codomain = ConstraintT | ||
_inv: "TransformT" = None | ||
|
||
def __call__(self, x: ArrayLike) -> ArrayLike: ... | ||
def _inverse(self, y: ArrayLike) -> ArrayLike: ... | ||
def log_abs_det_jacobian( | ||
self, x: ArrayLike, y: ArrayLike, intermediates=None | ||
) -> ArrayLike: ... | ||
def call_with_intermediates(self, x: ArrayLike) -> tuple[ArrayLike, None]: ... | ||
domain: ConstraintT = ... | ||
codomain: ConstraintT = ... | ||
_inv: Optional["TransformT"] = None | ||
|
||
def __call__(self, x: Array) -> Array: ... | ||
def _inverse(self, y: Array) -> Array: ... | ||
def log_abs_det_jacobian(self, x: Array, y: Array, intermediates=None) -> Array: ... | ||
def call_with_intermediates(self, x: Array) -> tuple[Array, Optional[Array]]: ... | ||
def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: ... | ||
def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: ... | ||
|
||
@property | ||
def inv(self) -> "TransformT": ... | ||
def inv(self) -> Optional["TransformT"]: ... | ||
@property | ||
def sign(self) -> ArrayLike: ... | ||
def sign(self) -> Array: ... |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,11 +3,11 @@ | |
|
||
import copy | ||
from functools import singledispatch | ||
from typing import Union | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
|
||
from numpyro._typing import DistributionT | ||
from numpyro.distributions import constraints | ||
from numpyro.distributions.conjugate import ( | ||
BetaBinomial, | ||
|
@@ -17,7 +17,6 @@ | |
NegativeBinomialLogits, | ||
NegativeBinomialProbs, | ||
) | ||
from numpyro.distributions.constraints import Constraint | ||
from numpyro.distributions.continuous import ( | ||
CAR, | ||
LKJ, | ||
|
@@ -59,7 +58,6 @@ | |
AffineTransform, | ||
CorrCholeskyTransform, | ||
PowerTransform, | ||
Transform, | ||
) | ||
from numpyro.distributions.truncated import ( | ||
LeftTruncatedDistribution, | ||
|
@@ -69,7 +67,7 @@ | |
|
||
|
||
@singledispatch | ||
def vmap_over(d: Union[Distribution, Transform, Constraint], **kwargs): | ||
def vmap_over(d: DistributionT, **kwargs): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why removing other objects? |
||
raise NotImplementedError | ||
|
||
|
||
|
@@ -498,12 +496,12 @@ def _vmap_over_half_normal(dist: HalfNormal, scale=None): | |
|
||
|
||
@singledispatch | ||
def promote_batch_shape(d: Distribution): | ||
def promote_batch_shape(d: DistributionT) -> DistributionT: | ||
raise NotImplementedError | ||
|
||
|
||
@promote_batch_shape.register | ||
def _default_promote_batch_shape(d: Distribution): | ||
def _default_promote_batch_shape(d: DistributionT) -> DistributionT: | ||
attr_batch_shapes = [d.batch_shape] | ||
for attr_name, constraint in d.arg_constraints.items(): | ||
try: | ||
|
@@ -515,12 +513,12 @@ def _default_promote_batch_shape(d: Distribution): | |
attr_batch_shapes.append(jnp.shape(attr)[:attr_batch_ndim]) | ||
resolved_batch_shape = jnp.broadcast_shapes(*attr_batch_shapes) | ||
new_self = copy.deepcopy(d) | ||
new_self._batch_shape = resolved_batch_shape | ||
new_self._batch_shape = resolved_batch_shape # type: ignore | ||
return new_self | ||
|
||
|
||
@promote_batch_shape.register | ||
def _promote_batch_shape_expanded(d: ExpandedDistribution): | ||
def _promote_batch_shape_expanded(d: ExpandedDistribution) -> ExpandedDistribution: | ||
orig_delta_batch_shape = d.batch_shape[ | ||
: len(d.batch_shape) - len(d.base_dist.batch_shape) | ||
] | ||
|
@@ -560,7 +558,7 @@ def _promote_batch_shape_expanded(d: ExpandedDistribution): | |
|
||
|
||
@promote_batch_shape.register | ||
def _promote_batch_shape_masked(d: MaskedDistribution): | ||
def _promote_batch_shape_masked(d: MaskedDistribution) -> MaskedDistribution: | ||
new_self = copy.copy(d) | ||
new_base_dist = promote_batch_shape(d.base_dist) | ||
new_self._batch_shape = new_base_dist.batch_shape | ||
|
@@ -569,7 +567,7 @@ def _promote_batch_shape_masked(d: MaskedDistribution): | |
|
||
|
||
@promote_batch_shape.register | ||
def _promote_batch_shape_independent(d: Independent): | ||
def _promote_batch_shape_independent(d: Independent) -> DistributionT: | ||
new_self = copy.copy(d) | ||
new_base_dist = promote_batch_shape(d.base_dist) | ||
new_self._batch_shape = new_base_dist.batch_shape[: -d.event_dim] | ||
|
@@ -578,5 +576,5 @@ def _promote_batch_shape_independent(d: Independent): | |
|
||
|
||
@promote_batch_shape.register | ||
def _promote_batch_shape_unit(d: Unit): | ||
def _promote_batch_shape_unit(d: Unit) -> Unit: | ||
return d |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
would prefer not importing MaskedDistribution here