Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ Prior
create_dim_handler
handle_dims
Prior
register_tensor_transform
VariableFactory
sample_prior
Censored
Expand Down
191 changes: 186 additions & 5 deletions pymc_extras/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,23 @@ def handle_dims(x: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorVa

Doesn't check for validity of the dims

Parameters
----------
x : pt.TensorLike
The tensor to align.
dims : Dims
The current dimensions of the tensor.
desired_dims : Dims
The desired dimensions of the tensor.

Returns
-------
pt.TensorVariable
The aligned tensor.

Examples
--------
1D to 2D with new dim
Handle transpose 1D to 2D with new dimension.

.. code-block:: python

Expand Down Expand Up @@ -177,10 +191,40 @@ def handle_dims(x: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorVa


DimHandler = Callable[[pt.TensorLike, Dims], pt.TensorLike]
"""A function that takes a tensor and its current dims and makes it broadcastable to the desired dims."""


def create_dim_handler(desired_dims: Dims) -> DimHandler:
"""Wrap the `handle_dims` function to act like the previous `create_dim_handler` function."""
"""Wrap the :func:`handle_dims` function to always use the same desired_dims.

Parameters
----------
desired_dims : Dims
The desired dimensions to align to.

Returns
-------
DimHandler
A function that takes a tensor and its current dims and aligns it to
the desired dims.


Examples
--------
Create a dim handler to align to ("channel", "group").

.. code-block:: python

import numpy as np

from pymc_extras.prior import create_dim_handler

dim_handler = create_dim_handler(("channel", "group"))

result = dim_handler(np.array([1, 2, 3]), dims="channel")


"""

def func(x: pt.TensorLike, dims: Dims) -> pt.TensorVariable:
return handle_dims(x, dims, desired_dims)
Expand Down Expand Up @@ -268,9 +312,44 @@ def _get_pymc_parameters(distribution: pm.Distribution) -> set[str]:

@runtime_checkable
class VariableFactory(Protocol):
"""Protocol for something that works like a Prior class."""
'''Protocol for something that works like a Prior class.

Sample with :func:`sample_prior`.

Examples
--------
Create a custom variable factory.

.. code-block:: python

import pymc as pm

import pytensor.tensor as pt

from pymc_extras.prior import sample_prior, VariableFactory


class PowerSumDistribution:
"""Create a distribution that is the sum of powers of a base distribution."""
def __init__(self, distribution: VariableFactory, n: int):
self.distribution = distribution
self.n = n

@property
def dims(self):
return self.distribution.dims

def create_variable(self, name: str) -> "TensorVariable":
raw = self.distribution.create_variable(f"{name}_raw")
return pm.Deterministic(name, pt.sum([raw ** n for n in range(1, self.n + 1)], axis=0), dims=self.dims,)

cubic = PowerSumDistribution(Prior("Normal"), n=3)
samples = sample_prior(cubic)

'''

dims: tuple[str, ...]
"""The dimensions of the variable to create."""

def create_variable(self, name: str) -> pt.TensorVariable:
"""Create a TensorVariable."""
Expand Down Expand Up @@ -381,6 +460,80 @@ class Prior:
be registered with `register_tensor_transform` function or
be available in either `pytensor.tensor` or `pymc.math`.

Examples
--------
Create a normal prior.

.. code-block:: python

from pymc_extras.prior import Prior

normal = Prior("Normal")

Create a hierarchical normal prior by using distributions for the parameters
and specifying the dims.

.. code-block:: python

hierarchical_normal = Prior(
"Normal",
mu=Prior("Normal"),
sigma=Prior("HalfNormal"),
dims="channel",
)

Create a non-centered hierarchical normal prior with the `centered` parameter.

.. code-block:: python

non_centered_hierarchical_normal = Prior(
"Normal",
mu=Prior("Normal"),
sigma=Prior("HalfNormal"),
dims="channel",
# Only change needed to make it non-centered
centered=False,
)

Create a hierarchical beta prior by using Beta distribution, distributions for
the parameters, and specifying the dims.

.. code-block:: python

hierarchical_beta = Prior(
"Beta",
alpha=Prior("HalfNormal"),
beta=Prior("HalfNormal"),
dims="channel",
)

Create a transformed hierarchical normal prior by using the `transform`
parameter. Here the "sigmoid" transformation comes from `pm.math`.

.. code-block:: python

transformed_hierarchical_normal = Prior(
"Normal",
mu=Prior("Normal"),
sigma=Prior("HalfNormal"),
transform="sigmoid",
dims="channel",
)

Create a prior with a custom transform function by registering it with
:func:`register_tensor_transform`.

.. code-block:: python

from pymc_extras.prior import register_tensor_transform

def custom_transform(x):
return x ** 2

register_tensor_transform("square", custom_transform)

custom_distribution = Prior("Normal", transform="square")

"""

# Taken from https://en.wikipedia.org/wiki/Location%E2%80%93scale_family
Expand All @@ -389,9 +542,13 @@ class Prior:
"StudentT": {"mu": 0, "sigma": 1},
"ZeroSumNormal": {"sigma": 1},
}
"""Available non-centered distributions and their default parameters."""

pymc_distribution: type[pm.Distribution]
"""The PyMC distribution class."""

pytensor_transform: Callable[[pt.TensorLike], pt.TensorLike] | None
"""The PyTensor transform function."""

@validate_call
def __init__(
Expand Down Expand Up @@ -1317,9 +1474,33 @@ def create_likelihood_variable(


class Scaled:
"""Scaled distribution for numerical stability."""
"""Scaled distribution for numerical stability.

This is the same as multiplying the variable by a constant factor.

Parameters
----------
dist : Prior
The prior distribution to scale.
factor : pt.TensorLike
The scaling factor. This will have to be broadcastable to the
dimensions of the distribution.

Examples
--------
Create a scaled normal distribution.

.. code-block:: python

from pymc_extras.prior import Prior, Scaled

normal = Prior("Normal", mu=0, sigma=1)
# Same as Normal(mu=0, sigma=10)
scaled_normal = Scaled(normal, factor=10)

"""

def __init__(self, dist: Prior, factor: float | pt.TensorVariable) -> None:
def __init__(self, dist: Prior, factor: pt.TensorLike) -> None:
self.dist = dist
self.factor = factor

Expand Down