@@ -138,9 +138,23 @@ def handle_dims(x: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorVa
138138
139139 Doesn't check for validity of the dims
140140
141+ Parameters
142+ ----------
143+ x : pt.TensorLike
144+ The tensor to align.
145+ dims : Dims
146+ The current dimensions of the tensor.
147+ desired_dims : Dims
148+ The desired dimensions of the tensor.
149+
150+ Returns
151+ -------
152+ pt.TensorVariable
153+ The aligned tensor.
154+
141155 Examples
142156 --------
143- 1D to 2D with new dim
157+ Handle transpose 1D to 2D with new dimension.
144158
145159 .. code-block:: python
146160
@@ -177,10 +191,40 @@ def handle_dims(x: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorVa
177191
178192
179193DimHandler = Callable [[pt .TensorLike , Dims ], pt .TensorLike ]
194+ """A function that takes a tensor and its current dims and makes it broadcastable to the desired dims."""
180195
181196
182197def create_dim_handler (desired_dims : Dims ) -> DimHandler :
183- """Wrap the `handle_dims` function to act like the previous `create_dim_handler` function."""
198+ """Wrap the :func:`handle_dims` function to always use the same desired_dims.
199+
200+ Parameters
201+ ----------
202+ desired_dims : Dims
203+ The desired dimensions to align to.
204+
205+ Returns
206+ -------
207+ DimHandler
208+ A function that takes a tensor and its current dims and aligns it to
209+ the desired dims.
210+
211+
212+ Examples
213+ --------
214+ Create a dim handler to align to ("channel", "group").
215+
216+ .. code-block:: python
217+
218+ import numpy as np
219+
220+ from pymc_extras.prior import create_dim_handler
221+
222+ dim_handler = create_dim_handler(("channel", "group"))
223+
224+ result = dim_handler(np.array([1, 2, 3]), dims="channel")
225+
226+
227+ """
184228
185229 def func (x : pt .TensorLike , dims : Dims ) -> pt .TensorVariable :
186230 return handle_dims (x , dims , desired_dims )
@@ -268,9 +312,44 @@ def _get_pymc_parameters(distribution: pm.Distribution) -> set[str]:
268312
269313@runtime_checkable
270314class VariableFactory (Protocol ):
271- """Protocol for something that works like a Prior class."""
315+ '''Protocol for something that works like a Prior class.
316+
317+ Sample with :func:`sample_prior`.
318+
319+ Examples
320+ --------
321+ Create a custom variable factory.
322+
323+ .. code-block:: python
324+
325+ import pymc as pm
326+
327+ import pytensor.tensor as pt
328+
329+ from pymc_extras.prior import sample_prior, VariableFactory
330+
331+
332+ class PowerSumDistribution:
333+ """Create a distribution that is the sum of powers of a base distribution."""
334+ def __init__(self, distribution: VariableFactory, n: int):
335+ self.distribution = distribution
336+ self.n = n
337+
338+ @property
339+ def dims(self):
340+ return self.distribution.dims
341+
342+ def create_variable(self, name: str) -> "TensorVariable":
343+ raw = self.distribution.create_variable(f"{name}_raw")
344+ return pm.Deterministic(name, pt.sum([raw ** n for n in range(1, self.n + 1)], axis=0), dims=self.dims,)
345+
346+ cubic = PowerSumDistribution(Prior("Normal"), n=3)
347+ samples = sample_prior(cubic)
348+
349+ '''
272350
273351 dims : tuple [str , ...]
352+ """The dimensions of the variable to create."""
274353
275354 def create_variable (self , name : str ) -> pt .TensorVariable :
276355 """Create a TensorVariable."""
@@ -381,6 +460,80 @@ class Prior:
381460 be registered with `register_tensor_transform` function or
382461 be available in either `pytensor.tensor` or `pymc.math`.
383462
463+ Examples
464+ --------
465+ Create a normal prior.
466+
467+ .. code-block:: python
468+
469+ from pymc_extras.prior import Prior
470+
471+ normal = Prior("Normal")
472+
473+ Create a hierarchical normal prior by using distributions for the parameters
474+ and specifying the dims.
475+
476+ .. code-block:: python
477+
478+ hierarchical_normal = Prior(
479+ "Normal",
480+ mu=Prior("Normal"),
481+ sigma=Prior("HalfNormal"),
482+ dims="channel",
483+ )
484+
485+ Create a non-centered hierarchical normal prior with the `centered` parameter.
486+
487+ .. code-block:: python
488+
489+ non_centered_hierarchical_normal = Prior(
490+ "Normal",
491+ mu=Prior("Normal"),
492+ sigma=Prior("HalfNormal"),
493+ dims="channel",
494+ # Only change needed to make it non-centered
495+ centered=False,
496+ )
497+
498+ Create a hierarchical beta prior by using Beta distribution, distributions for
499+ the parameters, and specifying the dims.
500+
501+ .. code-block:: python
502+
503+ hierarchical_beta = Prior(
504+ "Beta",
505+ alpha=Prior("HalfNormal"),
506+ beta=Prior("HalfNormal"),
507+ dims="channel",
508+ )
509+
510+ Create a transformed hierarchical normal prior by using the `transform`
511+ parameter. Here the "sigmoid" transformation comes from `pm.math`.
512+
513+ .. code-block:: python
514+
515+ transformed_hierarchical_normal = Prior(
516+ "Normal",
517+ mu=Prior("Normal"),
518+ sigma=Prior("HalfNormal"),
519+ transform="sigmoid",
520+ dims="channel",
521+ )
522+
523+ Create a prior with a custom transform function by registering it with
524+ :func:`register_tensor_transform`.
525+
526+ .. code-block:: python
527+
528+ from pymc_extras.prior import register_tensor_transform
529+
530+ def custom_transform(x):
531+ return x ** 2
532+
533+ register_tensor_transform("square", custom_transform)
534+
535+ custom_distribution = Prior("Normal", transform="square")
536+
384537 """
385538
386539 # Taken from https://en.wikipedia.org/wiki/Location%E2%80%93scale_family
@@ -389,9 +542,13 @@ class Prior:
389542 "StudentT" : {"mu" : 0 , "sigma" : 1 },
390543 "ZeroSumNormal" : {"sigma" : 1 },
391544 }
545+ """Available non-centered distributions and their default parameters."""
392546
393547 pymc_distribution : type [pm .Distribution ]
548+ """The PyMC distribution class."""
549+
394550 pytensor_transform : Callable [[pt .TensorLike ], pt .TensorLike ] | None
551+ """The PyTensor transform function."""
395552
396553 @validate_call
397554 def __init__ (
0 commit comments