Skip to content

Commit 222d60c

Browse files
authored
Add to prior module docs (#590)
* provide some more docs for the prior module * add the referenced function * closes #576
1 parent 71d3ac0 commit 222d60c

File tree

2 files changed

+187
-5
lines changed

2 files changed

+187
-5
lines changed

docs/api_reference.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ Prior
5656
create_dim_handler
5757
handle_dims
5858
Prior
59+
register_tensor_transform
5960
VariableFactory
6061
sample_prior
6162
Censored

pymc_extras/prior.py

Lines changed: 186 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -140,9 +140,23 @@ def handle_dims(x: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorVa
140140
141141
Doesn't check for validity of the dims
142142
143+
Parameters
144+
----------
145+
x : pt.TensorLike
146+
The tensor to align.
147+
dims : Dims
148+
The current dimensions of the tensor.
149+
desired_dims : Dims
150+
The desired dimensions of the tensor.
151+
152+
Returns
153+
-------
154+
pt.TensorVariable
155+
The aligned tensor.
156+
143157
Examples
144158
--------
145-
1D to 2D with new dim
159+
Handle transpose 1D to 2D with new dimension.
146160
147161
.. code-block:: python
148162
@@ -179,10 +193,40 @@ def handle_dims(x: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorVa
179193

180194

181195
DimHandler = Callable[[pt.TensorLike, Dims], pt.TensorLike]
196+
"""A function that takes a tensor and its current dims and makes it broadcastable to the desired dims."""
182197

183198

184199
def create_dim_handler(desired_dims: Dims) -> DimHandler:
185-
"""Wrap the `handle_dims` function to act like the previous `create_dim_handler` function."""
200+
"""Wrap the :func:`handle_dims` function to always use the same desired_dims.
201+
202+
Parameters
203+
----------
204+
desired_dims : Dims
205+
The desired dimensions to align to.
206+
207+
Returns
208+
-------
209+
DimHandler
210+
A function that takes a tensor and its current dims and aligns it to
211+
the desired dims.
212+
213+
214+
Examples
215+
--------
216+
Create a dim handler to align to ("channel", "group").
217+
218+
.. code-block:: python
219+
220+
import numpy as np
221+
222+
from pymc_extras.prior import create_dim_handler
223+
224+
dim_handler = create_dim_handler(("channel", "group"))
225+
226+
result = dim_handler(np.array([1, 2, 3]), dims="channel")
227+
228+
229+
"""
186230

187231
def func(x: pt.TensorLike, dims: Dims) -> pt.TensorVariable:
188232
return handle_dims(x, dims, desired_dims)
@@ -272,9 +316,44 @@ def _get_pymc_parameters(distribution: pm.Distribution) -> set[str]:
272316

273317
@runtime_checkable
274318
class VariableFactory(Protocol):
275-
"""Protocol for something that works like a Prior class."""
319+
'''Protocol for something that works like a Prior class.
320+
321+
Sample with :func:`sample_prior`.
322+
323+
Examples
324+
--------
325+
Create a custom variable factory.
326+
327+
.. code-block:: python
328+
329+
import pymc as pm
330+
331+
import pytensor.tensor as pt
332+
333+
from pymc_extras.prior import sample_prior, VariableFactory
334+
335+
336+
class PowerSumDistribution:
337+
"""Create a distribution that is the sum of powers of a base distribution."""
338+
def __init__(self, distribution: VariableFactory, n: int):
339+
self.distribution = distribution
340+
self.n = n
341+
342+
@property
343+
def dims(self):
344+
return self.distribution.dims
345+
346+
def create_variable(self, name: str) -> "TensorVariable":
347+
raw = self.distribution.create_variable(f"{name}_raw")
348+
return pm.Deterministic(name, pt.sum([raw ** n for n in range(1, self.n + 1)], axis=0), dims=self.dims,)
349+
350+
cubic = PowerSumDistribution(Prior("Normal"), n=3)
351+
samples = sample_prior(cubic)
352+
353+
'''
276354

277355
dims: tuple[str, ...]
356+
"""The dimensions of the variable to create."""
278357

279358
def create_variable(self, name: str) -> pt.TensorVariable:
280359
"""Create a TensorVariable."""
@@ -387,6 +466,80 @@ class Prior:
387466
be registered with `register_tensor_transform` function or
388467
be available in either `pytensor.tensor` or `pymc.math`.
389468
469+
Examples
470+
--------
471+
Create a normal prior.
472+
473+
.. code-block:: python
474+
475+
from pymc_extras.prior import Prior
476+
477+
normal = Prior("Normal")
478+
479+
Create a hierarchical normal prior by using distributions for the parameters
480+
and specifying the dims.
481+
482+
.. code-block:: python
483+
484+
hierarchical_normal = Prior(
485+
"Normal",
486+
mu=Prior("Normal"),
487+
sigma=Prior("HalfNormal"),
488+
dims="channel",
489+
)
490+
491+
Create a non-centered hierarchical normal prior with the `centered` parameter.
492+
493+
.. code-block:: python
494+
495+
non_centered_hierarchical_normal = Prior(
496+
"Normal",
497+
mu=Prior("Normal"),
498+
sigma=Prior("HalfNormal"),
499+
dims="channel",
500+
# Only change needed to make it non-centered
501+
centered=False,
502+
)
503+
504+
Create a hierarchical beta prior by using Beta distribution, distributions for
505+
the parameters, and specifying the dims.
506+
507+
.. code-block:: python
508+
509+
hierarchical_beta = Prior(
510+
"Beta",
511+
alpha=Prior("HalfNormal"),
512+
beta=Prior("HalfNormal"),
513+
dims="channel",
514+
)
515+
516+
Create a transformed hierarchical normal prior by using the `transform`
517+
parameter. Here the "sigmoid" transformation comes from `pm.math`.
518+
519+
.. code-block:: python
520+
521+
transformed_hierarchical_normal = Prior(
522+
"Normal",
523+
mu=Prior("Normal"),
524+
sigma=Prior("HalfNormal"),
525+
transform="sigmoid",
526+
dims="channel",
527+
)
528+
529+
Create a prior with a custom transform function by registering it with
530+
:func:`register_tensor_transform`.
531+
532+
.. code-block:: python
533+
534+
from pymc_extras.prior import register_tensor_transform
535+
536+
def custom_transform(x):
537+
return x ** 2
538+
539+
register_tensor_transform("square", custom_transform)
540+
541+
custom_distribution = Prior("Normal", transform="square")
542+
390543
"""
391544

392545
# Taken from https://en.wikipedia.org/wiki/Location%E2%80%93scale_family
@@ -395,9 +548,13 @@ class Prior:
395548
"StudentT": {"mu": 0, "sigma": 1},
396549
"ZeroSumNormal": {"sigma": 1},
397550
}
551+
"""Available non-centered distributions and their default parameters."""
398552

399553
pymc_distribution: type[pm.Distribution]
554+
"""The PyMC distribution class."""
555+
400556
pytensor_transform: Callable[[pt.TensorLike], pt.TensorLike] | None
557+
"""The PyTensor transform function."""
401558

402559
@validate_call
403560
def __init__(
@@ -1323,9 +1480,33 @@ def create_likelihood_variable(
13231480

13241481

13251482
class Scaled:
1326-
"""Scaled distribution for numerical stability."""
1483+
"""Scaled distribution for numerical stability.
1484+
1485+
This is the same as multiplying the variable by a constant factor.
1486+
1487+
Parameters
1488+
----------
1489+
dist : Prior
1490+
The prior distribution to scale.
1491+
factor : pt.TensorLike
1492+
The scaling factor. This will have to be broadcastable to the
1493+
dimensions of the distribution.
1494+
1495+
Examples
1496+
--------
1497+
Create a scaled normal distribution.
1498+
1499+
.. code-block:: python
1500+
1501+
from pymc_extras.prior import Prior, Scaled
1502+
1503+
normal = Prior("Normal", mu=0, sigma=1)
1504+
# Same as Normal(mu=0, sigma=10)
1505+
scaled_normal = Scaled(normal, factor=10)
1506+
1507+
"""
13271508

1328-
def __init__(self, dist: Prior, factor: float | pt.TensorVariable) -> None:
1509+
def __init__(self, dist: Prior, factor: pt.TensorLike) -> None:
13291510
self.dist = dist
13301511
self.factor = factor
13311512

0 commit comments

Comments
 (0)