Skip to content

Conversation

Qazalbash
Copy link
Contributor

This PR contains the partial resolution of mypy errors passed by #2032.

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @Qazalbash! I thought we can merge your changes and iterate on later PRs. Turns out that there are many non-trivial changes so it would be nice to split this into several PRs, each for one script like continuous.py etc. (would prefer simple ones first).

Let's resolve the typing issues when you have bandwidth. Thanks again for this important contribution!

P = ParamSpec("P")
ModelT: TypeAlias = Callable[P, Any]

Message: TypeAlias = dict[str, Any]
TraceT: TypeAlias = OrderedDict[str, Message]
PRNGKeyT: TypeAlias = Union[jax.dtypes.prng_key, ArrayLike]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can just use jax.Array for PRNGKey

from jax.typing import ArrayLike

from numpyro.distributions import MaskedDistribution
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would prefer not importing MaskedDistribution here

@@ -123,6 +125,7 @@ module = [
"numpyro.contrib.hsgp.*",
"numpyro.contrib.stochastic_support.*",
"numpyro.diagnostics.*",
"numpyro.distributions.*",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can comment out this and add TODO link to #2036

@@ -677,7 +676,7 @@ class scale(Messenger):
def __init__(
self,
fn: Optional[Callable] = None,
scale: ArrayLike = 1.0,
scale: Array = 1.0,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: ArrayLike works for float

@@ -583,7 +582,7 @@ class mask(Messenger):
def __init__(
self,
fn: Optional[Callable] = None,
mask: Optional[ArrayLike] = True,
mask: Optional[Array] = True,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: ArrayLike works for boolean

@@ -69,7 +67,7 @@


@singledispatch
def vmap_over(d: Union[Distribution, Transform, Constraint], **kwargs):
def vmap_over(d: DistributionT, **kwargs):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why removing other objects?

@Qazalbash
Copy link
Contributor Author

Thanks @Qazalbash! I thought we can merge your changes and iterate on later PRs. Turns out that there are many non-trivial changes so it would be nice to split this into several PRs, each for one script like continuous.py etc. (would prefer simple ones first).

Let's resolve the typing issues when you have bandwidth. Thanks again for this important contribution!

Thanks, I will make sure these issues get resolved as early as possible.

@fehiepsi fehiepsi added the WIP label Aug 4, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants