diff --git a/docs/api_reference.rst b/docs/api_reference.rst index 6e09c4ab4..e21772b8b 100644 --- a/docs/api_reference.rst +++ b/docs/api_reference.rst @@ -56,6 +56,7 @@ Prior create_dim_handler handle_dims Prior + register_tensor_transform VariableFactory sample_prior Censored diff --git a/pymc_extras/prior.py b/pymc_extras/prior.py index 0f368f986..d9d94604f 100644 --- a/pymc_extras/prior.py +++ b/pymc_extras/prior.py @@ -138,9 +138,23 @@ def handle_dims(x: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorVa Doesn't check for validity of the dims + Parameters + ---------- + x : pt.TensorLike + The tensor to align. + dims : Dims + The current dimensions of the tensor. + desired_dims : Dims + The desired dimensions of the tensor. + + Returns + ------- + pt.TensorVariable + The aligned tensor. + Examples -------- - 1D to 2D with new dim + Handle transpose 1D to 2D with new dimension. .. code-block:: python @@ -177,10 +191,40 @@ def handle_dims(x: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorVa DimHandler = Callable[[pt.TensorLike, Dims], pt.TensorLike] +"""A function that takes a tensor and its current dims and makes it broadcastable to the desired dims.""" def create_dim_handler(desired_dims: Dims) -> DimHandler: - """Wrap the `handle_dims` function to act like the previous `create_dim_handler` function.""" + """Wrap the :func:`handle_dims` function to always use the same desired_dims. + + Parameters + ---------- + desired_dims : Dims + The desired dimensions to align to. + + Returns + ------- + DimHandler + A function that takes a tensor and its current dims and aligns it to + the desired dims. + + + Examples + -------- + Create a dim handler to align to ("channel", "group"). + + .. code-block:: python + + import numpy as np + + from pymc_extras.prior import create_dim_handler + + dim_handler = create_dim_handler(("channel", "group")) + + result = dim_handler(np.array([1, 2, 3]), dims="channel") + + + """ def func(x: pt.TensorLike, dims: Dims) -> pt.TensorVariable: return handle_dims(x, dims, desired_dims) @@ -268,9 +312,44 @@ def _get_pymc_parameters(distribution: pm.Distribution) -> set[str]: @runtime_checkable class VariableFactory(Protocol): - """Protocol for something that works like a Prior class.""" + '''Protocol for something that works like a Prior class. + + Sample with :func:`sample_prior`. + + Examples + -------- + Create a custom variable factory. + + .. code-block:: python + + import pymc as pm + + import pytensor.tensor as pt + + from pymc_extras.prior import sample_prior, VariableFactory + + + class PowerSumDistribution: + """Create a distribution that is the sum of powers of a base distribution.""" + def __init__(self, distribution: VariableFactory, n: int): + self.distribution = distribution + self.n = n + + @property + def dims(self): + return self.distribution.dims + + def create_variable(self, name: str) -> "TensorVariable": + raw = self.distribution.create_variable(f"{name}_raw") + return pm.Deterministic(name, pt.sum([raw ** n for n in range(1, self.n + 1)], axis=0), dims=self.dims,) + + cubic = PowerSumDistribution(Prior("Normal"), n=3) + samples = sample_prior(cubic) + + ''' dims: tuple[str, ...] + """The dimensions of the variable to create.""" def create_variable(self, name: str) -> pt.TensorVariable: """Create a TensorVariable.""" @@ -381,6 +460,80 @@ class Prior: be registered with `register_tensor_transform` function or be available in either `pytensor.tensor` or `pymc.math`. + Examples + -------- + Create a normal prior. + + .. code-block:: python + + from pymc_extras.prior import Prior + + normal = Prior("Normal") + + Create a hierarchical normal prior by using distributions for the parameters + and specifying the dims. + + .. code-block:: python + + hierarchical_normal = Prior( + "Normal", + mu=Prior("Normal"), + sigma=Prior("HalfNormal"), + dims="channel", + ) + + Create a non-centered hierarchical normal prior with the `centered` parameter. + + .. code-block:: python + + non_centered_hierarchical_normal = Prior( + "Normal", + mu=Prior("Normal"), + sigma=Prior("HalfNormal"), + dims="channel", + # Only change needed to make it non-centered + centered=False, + ) + + Create a hierarchical beta prior by using Beta distribution, distributions for + the parameters, and specifying the dims. + + .. code-block:: python + + hierarchical_beta = Prior( + "Beta", + alpha=Prior("HalfNormal"), + beta=Prior("HalfNormal"), + dims="channel", + ) + + Create a transformed hierarchical normal prior by using the `transform` + parameter. Here the "sigmoid" transformation comes from `pm.math`. + + .. code-block:: python + + transformed_hierarchical_normal = Prior( + "Normal", + mu=Prior("Normal"), + sigma=Prior("HalfNormal"), + transform="sigmoid", + dims="channel", + ) + + Create a prior with a custom transform function by registering it with + :func:`register_tensor_transform`. + + .. code-block:: python + + from pymc_extras.prior import register_tensor_transform + + def custom_transform(x): + return x ** 2 + + register_tensor_transform("square", custom_transform) + + custom_distribution = Prior("Normal", transform="square") + """ # Taken from https://en.wikipedia.org/wiki/Location%E2%80%93scale_family @@ -389,9 +542,13 @@ class Prior: "StudentT": {"mu": 0, "sigma": 1}, "ZeroSumNormal": {"sigma": 1}, } + """Available non-centered distributions and their default parameters.""" pymc_distribution: type[pm.Distribution] + """The PyMC distribution class.""" + pytensor_transform: Callable[[pt.TensorLike], pt.TensorLike] | None + """The PyTensor transform function.""" @validate_call def __init__( @@ -1317,9 +1474,33 @@ def create_likelihood_variable( class Scaled: - """Scaled distribution for numerical stability.""" + """Scaled distribution for numerical stability. + + This is the same as multiplying the variable by a constant factor. + + Parameters + ---------- + dist : Prior + The prior distribution to scale. + factor : pt.TensorLike + The scaling factor. This will have to be broadcastable to the + dimensions of the distribution. + + Examples + -------- + Create a scaled normal distribution. + + .. code-block:: python + + from pymc_extras.prior import Prior, Scaled + + normal = Prior("Normal", mu=0, sigma=1) + # Same as Normal(mu=0, sigma=10) + scaled_normal = Scaled(normal, factor=10) + + """ - def __init__(self, dist: Prior, factor: float | pt.TensorVariable) -> None: + def __init__(self, dist: Prior, factor: pt.TensorLike) -> None: self.dist = dist self.factor = factor