@@ -138,9 +138,23 @@ def handle_dims(x: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorVa
138
138
139
139
Doesn't check for validity of the dims
140
140
141
+ Parameters
142
+ ----------
143
+ x : pt.TensorLike
144
+ The tensor to align.
145
+ dims : Dims
146
+ The current dimensions of the tensor.
147
+ desired_dims : Dims
148
+ The desired dimensions of the tensor.
149
+
150
+ Returns
151
+ -------
152
+ pt.TensorVariable
153
+ The aligned tensor.
154
+
141
155
Examples
142
156
--------
143
- 1D to 2D with new dim
157
+ Handle transpose 1D to 2D with new dimension.
144
158
145
159
.. code-block:: python
146
160
@@ -177,10 +191,40 @@ def handle_dims(x: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorVa
177
191
178
192
179
193
DimHandler = Callable [[pt .TensorLike , Dims ], pt .TensorLike ]
194
+ """A function that takes a tensor and its current dims and makes it broadcastable to the desired dims."""
180
195
181
196
182
197
def create_dim_handler (desired_dims : Dims ) -> DimHandler :
183
- """Wrap the `handle_dims` function to act like the previous `create_dim_handler` function."""
198
+ """Wrap the :func:`handle_dims` function to always use the same desired_dims.
199
+
200
+ Parameters
201
+ ----------
202
+ desired_dims : Dims
203
+ The desired dimensions to align to.
204
+
205
+ Returns
206
+ -------
207
+ DimHandler
208
+ A function that takes a tensor and its current dims and aligns it to
209
+ the desired dims.
210
+
211
+
212
+ Examples
213
+ --------
214
+ Create a dim handler to align to ("channel", "group").
215
+
216
+ .. code-block:: python
217
+
218
+ import numpy as np
219
+
220
+ from pymc_extras.prior import create_dim_handler
221
+
222
+ dim_handler = create_dim_handler(("channel", "group"))
223
+
224
+ result = dim_handler(np.array([1, 2, 3]), dims="channel")
225
+
226
+
227
+ """
184
228
185
229
def func (x : pt .TensorLike , dims : Dims ) -> pt .TensorVariable :
186
230
return handle_dims (x , dims , desired_dims )
@@ -268,9 +312,44 @@ def _get_pymc_parameters(distribution: pm.Distribution) -> set[str]:
268
312
269
313
@runtime_checkable
270
314
class VariableFactory (Protocol ):
271
- """Protocol for something that works like a Prior class."""
315
+ '''Protocol for something that works like a Prior class.
316
+
317
+ Sample with :func:`sample_prior`.
318
+
319
+ Examples
320
+ --------
321
+ Create a custom variable factory.
322
+
323
+ .. code-block:: python
324
+
325
+ import pymc as pm
326
+
327
+ import pytensor.tensor as pt
328
+
329
+ from pymc_extras.prior import sample_prior, VariableFactory
330
+
331
+
332
+ class PowerSumDistribution:
333
+ """Create a distribution that is the sum of powers of a base distribution."""
334
+ def __init__(self, distribution: VariableFactory, n: int):
335
+ self.distribution = distribution
336
+ self.n = n
337
+
338
+ @property
339
+ def dims(self):
340
+ return self.distribution.dims
341
+
342
+ def create_variable(self, name: str) -> "TensorVariable":
343
+ raw = self.distribution.create_variable(f"{name}_raw")
344
+ return pm.Deterministic(name, pt.sum([raw ** n for n in range(1, self.n + 1)], axis=0), dims=self.dims,)
345
+
346
+ cubic = PowerSumDistribution(Prior("Normal"), n=3)
347
+ samples = sample_prior(cubic)
348
+
349
+ '''
272
350
273
351
dims : tuple [str , ...]
352
+ """The dimensions of the variable to create."""
274
353
275
354
def create_variable (self , name : str ) -> pt .TensorVariable :
276
355
"""Create a TensorVariable."""
@@ -381,6 +460,80 @@ class Prior:
381
460
be registered with `register_tensor_transform` function or
382
461
be available in either `pytensor.tensor` or `pymc.math`.
383
462
463
+ Examples
464
+ --------
465
+ Create a normal prior.
466
+
467
+ .. code-block:: python
468
+
469
+ from pymc_extras.prior import Prior
470
+
471
+ normal = Prior("Normal")
472
+
473
+ Create a hierarchical normal prior by using distributions for the parameters
474
+ and specifying the dims.
475
+
476
+ .. code-block:: python
477
+
478
+ hierarchical_normal = Prior(
479
+ "Normal",
480
+ mu=Prior("Normal"),
481
+ sigma=Prior("HalfNormal"),
482
+ dims="channel",
483
+ )
484
+
485
+ Create a non-centered hierarchical normal prior with the `centered` parameter.
486
+
487
+ .. code-block:: python
488
+
489
+ non_centered_hierarchical_normal = Prior(
490
+ "Normal",
491
+ mu=Prior("Normal"),
492
+ sigma=Prior("HalfNormal"),
493
+ dims="channel",
494
+ # Only change needed to make it non-centered
495
+ centered=False,
496
+ )
497
+
498
+ Create a hierarchical beta prior by using Beta distribution, distributions for
499
+ the parameters, and specifying the dims.
500
+
501
+ .. code-block:: python
502
+
503
+ hierarchical_beta = Prior(
504
+ "Beta",
505
+ alpha=Prior("HalfNormal"),
506
+ beta=Prior("HalfNormal"),
507
+ dims="channel",
508
+ )
509
+
510
+ Create a transformed hierarchical normal prior by using the `transform`
511
+ parameter. Here the "sigmoid" transformation comes from `pm.math`.
512
+
513
+ .. code-block:: python
514
+
515
+ transformed_hierarchical_normal = Prior(
516
+ "Normal",
517
+ mu=Prior("Normal"),
518
+ sigma=Prior("HalfNormal"),
519
+ transform="sigmoid",
520
+ dims="channel",
521
+ )
522
+
523
+ Create a prior with a custom transform function by registering it with
524
+ :func:`register_tensor_transform`.
525
+
526
+ .. code-block:: python
527
+
528
+ from pymc_extras.prior import register_tensor_transform
529
+
530
+ def custom_transform(x):
531
+ return x ** 2
532
+
533
+ register_tensor_transform("square", custom_transform)
534
+
535
+ custom_distribution = Prior("Normal", transform="square")
536
+
384
537
"""
385
538
386
539
# Taken from https://en.wikipedia.org/wiki/Location%E2%80%93scale_family
@@ -389,9 +542,13 @@ class Prior:
389
542
"StudentT" : {"mu" : 0 , "sigma" : 1 },
390
543
"ZeroSumNormal" : {"sigma" : 1 },
391
544
}
545
+ """Available non-centered distributions and their default parameters."""
392
546
393
547
pymc_distribution : type [pm .Distribution ]
548
+ """The PyMC distribution class."""
549
+
394
550
pytensor_transform : Callable [[pt .TensorLike ], pt .TensorLike ] | None
551
+ """The PyTensor transform function."""
395
552
396
553
@validate_call
397
554
def __init__ (
0 commit comments