Skip to content

Commit a3d45c8

Browse files
committed
provide some more docs for the prior module
1 parent 504ef1c commit a3d45c8

File tree

1 file changed

+160
-3
lines changed

1 file changed

+160
-3
lines changed

pymc_extras/prior.py

Lines changed: 160 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,23 @@ def handle_dims(x: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorVa
138138
139139
Doesn't check for validity of the dims
140140
141+
Parameters
142+
----------
143+
x : pt.TensorLike
144+
The tensor to align.
145+
dims : Dims
146+
The current dimensions of the tensor.
147+
desired_dims : Dims
148+
The desired dimensions of the tensor.
149+
150+
Returns
151+
-------
152+
pt.TensorVariable
153+
The aligned tensor.
154+
141155
Examples
142156
--------
143-
1D to 2D with new dim
157+
Handle transpose 1D to 2D with new dimension.
144158
145159
.. code-block:: python
146160
@@ -177,10 +191,40 @@ def handle_dims(x: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorVa
177191

178192

179193
DimHandler = Callable[[pt.TensorLike, Dims], pt.TensorLike]
194+
"""A function that takes a tensor and its current dims and makes it broadcastable to the desired dims."""
180195

181196

182197
def create_dim_handler(desired_dims: Dims) -> DimHandler:
183-
"""Wrap the `handle_dims` function to act like the previous `create_dim_handler` function."""
198+
"""Wrap the :func:`handle_dims` function to always use the same desired_dims.
199+
200+
Parameters
201+
----------
202+
desired_dims : Dims
203+
The desired dimensions to align to.
204+
205+
Returns
206+
-------
207+
DimHandler
208+
A function that takes a tensor and its current dims and aligns it to
209+
the desired dims.
210+
211+
212+
Examples
213+
--------
214+
Create a dim handler to align to ("channel", "group").
215+
216+
.. code-block:: python
217+
218+
import numpy as np
219+
220+
from pymc_extras.prior import create_dim_handler
221+
222+
dim_handler = create_dim_handler(("channel", "group"))
223+
224+
result = dim_handler(np.array([1, 2, 3]), dims="channel")
225+
226+
227+
"""
184228

185229
def func(x: pt.TensorLike, dims: Dims) -> pt.TensorVariable:
186230
return handle_dims(x, dims, desired_dims)
@@ -268,9 +312,44 @@ def _get_pymc_parameters(distribution: pm.Distribution) -> set[str]:
268312

269313
@runtime_checkable
270314
class VariableFactory(Protocol):
271-
"""Protocol for something that works like a Prior class."""
315+
'''Protocol for something that works like a Prior class.
316+
317+
Sample with :func:`sample_prior`.
318+
319+
Examples
320+
--------
321+
Create a custom variable factory.
322+
323+
.. code-block:: python
324+
325+
import pymc as pm
326+
327+
import pytensor.tensor as pt
328+
329+
from pymc_extras.prior import sample_prior, VariableFactory
330+
331+
332+
class PowerSumDistribution:
333+
"""Create a distribution that is the sum of powers of a base distribution."""
334+
def __init__(self, distribution: VariableFactory, n: int):
335+
self.distribution = distribution
336+
self.n = n
337+
338+
@property
339+
def dims(self):
340+
return self.distribution.dims
341+
342+
def create_variable(self, name: str) -> "TensorVariable":
343+
raw = self.distribution.create_variable(f"{name}_raw")
344+
return pm.Deterministic(name, pt.sum([raw ** n for n in range(1, self.n + 1)], axis=0), dims=self.dims,)
345+
346+
cubic = PowerSumDistribution(Prior("Normal"), n=3)
347+
samples = sample_prior(cubic)
348+
349+
'''
272350

273351
dims: tuple[str, ...]
352+
"""The dimensions of the variable to create."""
274353

275354
def create_variable(self, name: str) -> pt.TensorVariable:
276355
"""Create a TensorVariable."""
@@ -381,6 +460,80 @@ class Prior:
381460
be registered with `register_tensor_transform` function or
382461
be available in either `pytensor.tensor` or `pymc.math`.
383462
463+
Examples
464+
--------
465+
Create a normal prior.
466+
467+
.. code-block:: python
468+
469+
from pymc_extras.prior import Prior
470+
471+
normal = Prior("Normal")
472+
473+
Create a hierarchical normal prior by using distributions for the parameters
474+
and specifying the dims.
475+
476+
.. code-block:: python
477+
478+
hierarchical_normal = Prior(
479+
"Normal",
480+
mu=Prior("Normal"),
481+
sigma=Prior("HalfNormal"),
482+
dims="channel",
483+
)
484+
485+
Create a non-centered hierarchical normal prior with the `centered` parameter.
486+
487+
.. code-block:: python
488+
489+
non_centered_hierarchical_normal = Prior(
490+
"Normal",
491+
mu=Prior("Normal"),
492+
sigma=Prior("HalfNormal"),
493+
dims="channel",
494+
# Only change needed to make it non-centered
495+
centered=False,
496+
)
497+
498+
Create a hierarchical beta prior by using Beta distribution, distributions for
499+
the parameters, and specifying the dims.
500+
501+
.. code-block:: python
502+
503+
hierarchical_beta = Prior(
504+
"Beta",
505+
alpha=Prior("HalfNormal"),
506+
beta=Prior("HalfNormal"),
507+
dims="channel",
508+
)
509+
510+
Create a transformed hierarchical normal prior by using the `transform`
511+
parameter. Here the "sigmoid" transformation comes from `pm.math`.
512+
513+
.. code-block:: python
514+
515+
transformed_hierarchical_normal = Prior(
516+
"Normal",
517+
mu=Prior("Normal"),
518+
sigma=Prior("HalfNormal"),
519+
transform="sigmoid",
520+
dims="channel",
521+
)
522+
523+
Create a prior with a custom transform function by registering it with
524+
:func:`register_tensor_transform`.
525+
526+
.. code-block:: python
527+
528+
from pymc_extras.prior import register_tensor_transform
529+
530+
def custom_transform(x):
531+
return x ** 2
532+
533+
register_tensor_transform("square", custom_transform)
534+
535+
custom_distribution = Prior("Normal", transform="square")
536+
384537
"""
385538

386539
# Taken from https://en.wikipedia.org/wiki/Location%E2%80%93scale_family
@@ -389,9 +542,13 @@ class Prior:
389542
"StudentT": {"mu": 0, "sigma": 1},
390543
"ZeroSumNormal": {"sigma": 1},
391544
}
545+
"""Available non-centered distributions and their default parameters."""
392546

393547
pymc_distribution: type[pm.Distribution]
548+
"""The PyMC distribution class."""
549+
394550
pytensor_transform: Callable[[pt.TensorLike], pt.TensorLike] | None
551+
"""The PyTensor transform function."""
395552

396553
@validate_call
397554
def __init__(

0 commit comments

Comments
 (0)