Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 48 additions & 30 deletions numpyro/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,39 @@

from collections import OrderedDict
from collections.abc import Callable
from typing import Any, Protocol, runtime_checkable
from typing import Any, Optional, Protocol, Union, runtime_checkable

from typing_extensions import ParamSpec, TypeAlias

import jax
from jax import Array
from jax.typing import ArrayLike

from numpyro.distributions import MaskedDistribution
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would prefer not importing MaskedDistribution here


P = ParamSpec("P")
ModelT: TypeAlias = Callable[P, Any]

Message: TypeAlias = dict[str, Any]
TraceT: TypeAlias = OrderedDict[str, Message]
PRNGKeyT: TypeAlias = Union[jax.dtypes.prng_key, ArrayLike]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can just use jax.Array for PRNGKey



@runtime_checkable
class ConstraintT(Protocol):
is_discrete: bool = ...
event_dim: int = ...
# is_discrete: bool = ...
# event_dim: int = ...

def __call__(self, x: ArrayLike) -> ArrayLike: ...
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
def __call__(self, x: Array) -> Array: ...
def __repr__(self) -> str: ...
def check(self, value: ArrayLike) -> ArrayLike: ...
def feasible_like(self, prototype: ArrayLike) -> ArrayLike: ...
def check(self, value: Array) -> Array: ...
def feasible_like(self, prototype: Array) -> Array: ...

@property
def is_discrete(self) -> bool: ...
@property
def event_dim(self) -> int: ...


@runtime_checkable
Expand All @@ -38,27 +48,35 @@ class DistributionT(Protocol):
"""

arg_constraints: dict[str, ConstraintT] = ...
support: ConstraintT = ...
has_enumerate_support: bool = ...
reparametrized_params: list[str] = ...
_validate_args: bool = ...
pytree_data_fields: tuple = ...
pytree_aux_fields: tuple = ...

def __call__(self, *args: Any, **kwargs: Any) -> Any: ...

def rsample(
self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = ()
) -> ArrayLike: ...
self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = ()
) -> Array: ...
def sample(
self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = ()
) -> ArrayLike: ...
def log_prob(self, value: ArrayLike) -> ArrayLike: ...
def cdf(self, value: ArrayLike) -> ArrayLike: ...
def icdf(self, q: ArrayLike) -> ArrayLike: ...
def entropy(self) -> ArrayLike: ...
def enumerate_support(self, expand: bool = True) -> ArrayLike: ...
self, key: Optional[PRNGKeyT], sample_shape: tuple[int, ...] = ()
) -> Array: ...
def log_prob(self, value: Array) -> Array: ...
def cdf(self, value: Array) -> Array: ...
def icdf(self, q: Array) -> Array: ...
def entropy(self) -> Array: ...
def enumerate_support(self, expand: bool = True) -> Array: ...
def shape(self, sample_shape: tuple[int, ...] = ()) -> tuple[int, ...]: ...
def to_event(
self, reinterpreted_batch_ndims: Optional[int] = None
) -> "DistributionT": ...
def expand(self, batch_shape: tuple[int, ...]) -> "DistributionT": ...
def expand_by(self, sample_shape: tuple[int, ...]) -> "DistributionT": ...
def mask(self, mask: Array) -> MaskedDistribution: ...
@classmethod
def infer_shapes(cls, *args, **kwargs): ...

@property
def support(self) -> ConstraintT: ...

@property
def batch_shape(self) -> tuple[int, ...]: ...
Expand All @@ -76,6 +94,8 @@ def variance(self) -> ArrayLike: ...

@property
def is_discrete(self) -> bool: ...
@property
def has_enumerate_support(self) -> bool: ...


# To avoid breaking changes for user code that uses `DistributionLike`
Expand All @@ -84,20 +104,18 @@ def is_discrete(self) -> bool: ...

@runtime_checkable
class TransformT(Protocol):
domain = ConstraintT
codomain = ConstraintT
_inv: "TransformT" = None

def __call__(self, x: ArrayLike) -> ArrayLike: ...
def _inverse(self, y: ArrayLike) -> ArrayLike: ...
def log_abs_det_jacobian(
self, x: ArrayLike, y: ArrayLike, intermediates=None
) -> ArrayLike: ...
def call_with_intermediates(self, x: ArrayLike) -> tuple[ArrayLike, None]: ...
domain: ConstraintT = ...
codomain: ConstraintT = ...
_inv: Optional["TransformT"] = None

def __call__(self, x: Array) -> Array: ...
def _inverse(self, y: Array) -> Array: ...
def log_abs_det_jacobian(self, x: Array, y: Array, intermediates=None) -> Array: ...
def call_with_intermediates(self, x: Array) -> tuple[Array, Optional[Array]]: ...
def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: ...
def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: ...

@property
def inv(self) -> "TransformT": ...
def inv(self) -> Optional["TransformT"]: ...
@property
def sign(self) -> ArrayLike: ...
def sign(self) -> Array: ...
13 changes: 6 additions & 7 deletions numpyro/contrib/hsgp/approximation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

from jax import Array
import jax.numpy as jnp
from jax.typing import ArrayLike

import numpyro
from numpyro.contrib.hsgp.laplacian import eigenfunctions, eigenfunctions_periodic
Expand Down Expand Up @@ -62,7 +61,7 @@ def linear_approximation(


def hsgp_squared_exponential(
x: ArrayLike,
x: Array,
alpha: float,
length: float,
ell: float | int | list[float | int],
Expand All @@ -84,7 +83,7 @@ def hsgp_squared_exponential(
2. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space
approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).

:param ArrayLike x: input data
:param Array x: input data
:param float alpha: amplitude of the squared exponential kernel
:param float length: length scale of the squared exponential kernel
:param float | int | list[float | int] ell: positive value that parametrizes the length of the D-dimensional box so
Expand All @@ -110,7 +109,7 @@ def hsgp_squared_exponential(


def hsgp_matern(
x: ArrayLike,
x: Array,
nu: float,
alpha: float,
length: float,
Expand All @@ -133,7 +132,7 @@ def hsgp_matern(
2. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space
approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).

:param ArrayLike x: input data
:param Array x: input data
:param float nu: smoothness parameter
:param float alpha: amplitude of the squared exponential kernel
:param float length: length scale of the squared exponential kernel
Expand All @@ -160,7 +159,7 @@ def hsgp_matern(


def hsgp_periodic_non_centered(
x: ArrayLike, alpha: float, length: float, w0: float, m: int
x: Array, alpha: float, length: float, w0: float, m: int
) -> Array:
"""
Low rank approximation for the periodic squared exponential kernel in the non-centered parametrization.
Expand All @@ -172,7 +171,7 @@ def hsgp_periodic_non_centered(
1. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space
approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).

:param ArrayLike x: input data
:param Array x: input data
:param float alpha: amplitude
:param float length: length scale
:param float w0: frequency of the periodic kernel
Expand Down
15 changes: 7 additions & 8 deletions numpyro/contrib/hsgp/laplacian.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

from jax import Array
import jax.numpy as jnp
from jax.typing import ArrayLike


def eigenindices(m: list[int] | int, dim: int) -> Array:
Expand Down Expand Up @@ -76,7 +75,7 @@ def eigenindices(m: list[int] | int, dim: int) -> Array:


def sqrt_eigenvalues(
ell: ArrayLike | list[int | float], m: list[int] | int, dim: int
ell: Array | list[int | float], m: list[int] | int, dim: int
) -> Array:
"""
The first :math:`m^\\star \\times D` square root of eigenvalues of the laplacian operator in
Expand All @@ -101,7 +100,7 @@ def sqrt_eigenvalues(
return S * jnp.pi / 2 / ell_ # dim x prod(m) array of eigenvalues


def eigenfunctions(x: ArrayLike, ell: float | list[float], m: int | list[int]) -> Array:
def eigenfunctions(x: Array, ell: float | list[float], m: int | list[int]) -> Array:
"""
The first :math:`m^\\star` eigenfunctions of the laplacian operator in
:math:`[-L_1, L_1] \\times ... \\times [-L_D, L_D]`
Expand Down Expand Up @@ -137,7 +136,7 @@ def eigenfunctions(x: ArrayLike, ell: float | list[float], m: int | list[int]) -
1. Solin, A., Särkkä, S. Hilbert space methods for reduced-rank Gaussian process regression.
Stat Comput 30, 419-446 (2020)

:param ArrayLike x: The points at which to evaluate the eigenfunctions.
:param Array x: The points at which to evaluate the eigenfunctions.
If `x` is 1D the problem is assumed unidimensional.
Otherwise, the dimension of the input space is inferred as the last dimension of `x`.
Other dimensions are treated as batch dimensions.
Expand All @@ -162,11 +161,11 @@ def eigenfunctions(x: ArrayLike, ell: float | list[float], m: int | list[int]) -
)


def eigenfunctions_periodic(x: ArrayLike, w0: float, m: int) -> tuple[Array, Array]:
def eigenfunctions_periodic(x: Array, w0: float, m: int) -> tuple[Array, Array]:
"""
Basis functions for the approximation of the periodic kernel.

:param ArrayLike x: The points at which to evaluate the eigenfunctions.
:param Array x: The points at which to evaluate the eigenfunctions.
:param float w0: The frequency of the periodic kernel.
:param int m: The number of eigenfunctions to compute.

Expand All @@ -188,13 +187,13 @@ def eigenfunctions_periodic(x: ArrayLike, w0: float, m: int) -> tuple[Array, Arr
return cosines, sines


def _convert_ell(ell: float | int | list[float | int] | ArrayLike, dim: int) -> Array:
def _convert_ell(ell: float | int | list[float | int] | Array, dim: int) -> Array:
"""
Process the half-length of the approximation interval and return a `D \\times 1` array.

If `ell` is a scalar, it is converted to a list of length dim, then transformed into an Array.

:param float | int | list[float | int] | ArrayLike ell: The length of the interval in each dimension divided by 2.
:param float | int | list[float | int] | Array ell: The length of the interval in each dimension divided by 2.
If a float or int, the same length is used in each dimension.
:param int dim: The dimension of the space.

Expand Down
9 changes: 4 additions & 5 deletions numpyro/contrib/hsgp/spectral_densities.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from jax import Array, vmap
import jax.numpy as jnp
from jax.scipy import special
from jax.typing import ArrayLike

from numpyro.contrib.hsgp.laplacian import sqrt_eigenvalues

Expand All @@ -20,7 +19,7 @@ def align_param(dim, param):


def spectral_density_squared_exponential(
dim: int, w: ArrayLike, alpha: float, length: float | ArrayLike
dim: int, w: Array, alpha: float, length: float | Array
) -> Array:
"""
Spectral density of the squared exponential kernel.
Expand All @@ -41,7 +40,7 @@ def spectral_density_squared_exponential(
approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).

:param int dim: dimension
:param ArrayLike w: frequency
:param Array w: frequency
:param float alpha: amplitude
:param float length: length scale
:return: spectral density value
Expand All @@ -54,7 +53,7 @@ def spectral_density_squared_exponential(


def spectral_density_matern(
dim: int, nu: float, w: ArrayLike, alpha: float, length: float | ArrayLike
dim: int, nu: float, w: Array, alpha: float, length: float | Array
) -> float:
"""
Spectral density of the Matérn kernel.
Expand All @@ -77,7 +76,7 @@ def spectral_density_matern(

:param int dim: dimension
:param float nu: smoothness
:param ArrayLike w: frequency
:param Array w: frequency
:param float alpha: amplitude
:param float length: length scale
:return: spectral density value
Expand Down
20 changes: 9 additions & 11 deletions numpyro/distributions/batch_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@

import copy
from functools import singledispatch
from typing import Union

import jax
import jax.numpy as jnp

from numpyro._typing import DistributionT
from numpyro.distributions import constraints
from numpyro.distributions.conjugate import (
BetaBinomial,
Expand All @@ -17,7 +17,6 @@
NegativeBinomialLogits,
NegativeBinomialProbs,
)
from numpyro.distributions.constraints import Constraint
from numpyro.distributions.continuous import (
CAR,
LKJ,
Expand Down Expand Up @@ -59,7 +58,6 @@
AffineTransform,
CorrCholeskyTransform,
PowerTransform,
Transform,
)
from numpyro.distributions.truncated import (
LeftTruncatedDistribution,
Expand All @@ -69,7 +67,7 @@


@singledispatch
def vmap_over(d: Union[Distribution, Transform, Constraint], **kwargs):
def vmap_over(d: DistributionT, **kwargs):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why removing other objects?

raise NotImplementedError


Expand Down Expand Up @@ -498,12 +496,12 @@ def _vmap_over_half_normal(dist: HalfNormal, scale=None):


@singledispatch
def promote_batch_shape(d: Distribution):
def promote_batch_shape(d: DistributionT) -> DistributionT:
raise NotImplementedError


@promote_batch_shape.register
def _default_promote_batch_shape(d: Distribution):
def _default_promote_batch_shape(d: DistributionT) -> DistributionT:
attr_batch_shapes = [d.batch_shape]
for attr_name, constraint in d.arg_constraints.items():
try:
Expand All @@ -515,12 +513,12 @@ def _default_promote_batch_shape(d: Distribution):
attr_batch_shapes.append(jnp.shape(attr)[:attr_batch_ndim])
resolved_batch_shape = jnp.broadcast_shapes(*attr_batch_shapes)
new_self = copy.deepcopy(d)
new_self._batch_shape = resolved_batch_shape
new_self._batch_shape = resolved_batch_shape # type: ignore
return new_self


@promote_batch_shape.register
def _promote_batch_shape_expanded(d: ExpandedDistribution):
def _promote_batch_shape_expanded(d: ExpandedDistribution) -> ExpandedDistribution:
orig_delta_batch_shape = d.batch_shape[
: len(d.batch_shape) - len(d.base_dist.batch_shape)
]
Expand Down Expand Up @@ -560,7 +558,7 @@ def _promote_batch_shape_expanded(d: ExpandedDistribution):


@promote_batch_shape.register
def _promote_batch_shape_masked(d: MaskedDistribution):
def _promote_batch_shape_masked(d: MaskedDistribution) -> MaskedDistribution:
new_self = copy.copy(d)
new_base_dist = promote_batch_shape(d.base_dist)
new_self._batch_shape = new_base_dist.batch_shape
Expand All @@ -569,7 +567,7 @@ def _promote_batch_shape_masked(d: MaskedDistribution):


@promote_batch_shape.register
def _promote_batch_shape_independent(d: Independent):
def _promote_batch_shape_independent(d: Independent) -> DistributionT:
new_self = copy.copy(d)
new_base_dist = promote_batch_shape(d.base_dist)
new_self._batch_shape = new_base_dist.batch_shape[: -d.event_dim]
Expand All @@ -578,5 +576,5 @@ def _promote_batch_shape_independent(d: Independent):


@promote_batch_shape.register
def _promote_batch_shape_unit(d: Unit):
def _promote_batch_shape_unit(d: Unit) -> Unit:
return d
Loading
Loading