@@ -140,9 +140,23 @@ def handle_dims(x: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorVa
140
140
141
141
Doesn't check for validity of the dims
142
142
143
+ Parameters
144
+ ----------
145
+ x : pt.TensorLike
146
+ The tensor to align.
147
+ dims : Dims
148
+ The current dimensions of the tensor.
149
+ desired_dims : Dims
150
+ The desired dimensions of the tensor.
151
+
152
+ Returns
153
+ -------
154
+ pt.TensorVariable
155
+ The aligned tensor.
156
+
143
157
Examples
144
158
--------
145
- 1D to 2D with new dim
159
+ Handle transpose 1D to 2D with new dimension.
146
160
147
161
.. code-block:: python
148
162
@@ -179,10 +193,40 @@ def handle_dims(x: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorVa
179
193
180
194
181
195
DimHandler = Callable [[pt .TensorLike , Dims ], pt .TensorLike ]
196
+ """A function that takes a tensor and its current dims and makes it broadcastable to the desired dims."""
182
197
183
198
184
199
def create_dim_handler (desired_dims : Dims ) -> DimHandler :
185
- """Wrap the `handle_dims` function to act like the previous `create_dim_handler` function."""
200
+ """Wrap the :func:`handle_dims` function to always use the same desired_dims.
201
+
202
+ Parameters
203
+ ----------
204
+ desired_dims : Dims
205
+ The desired dimensions to align to.
206
+
207
+ Returns
208
+ -------
209
+ DimHandler
210
+ A function that takes a tensor and its current dims and aligns it to
211
+ the desired dims.
212
+
213
+
214
+ Examples
215
+ --------
216
+ Create a dim handler to align to ("channel", "group").
217
+
218
+ .. code-block:: python
219
+
220
+ import numpy as np
221
+
222
+ from pymc_extras.prior import create_dim_handler
223
+
224
+ dim_handler = create_dim_handler(("channel", "group"))
225
+
226
+ result = dim_handler(np.array([1, 2, 3]), dims="channel")
227
+
228
+
229
+ """
186
230
187
231
def func (x : pt .TensorLike , dims : Dims ) -> pt .TensorVariable :
188
232
return handle_dims (x , dims , desired_dims )
@@ -272,9 +316,44 @@ def _get_pymc_parameters(distribution: pm.Distribution) -> set[str]:
272
316
273
317
@runtime_checkable
274
318
class VariableFactory (Protocol ):
275
- """Protocol for something that works like a Prior class."""
319
+ '''Protocol for something that works like a Prior class.
320
+
321
+ Sample with :func:`sample_prior`.
322
+
323
+ Examples
324
+ --------
325
+ Create a custom variable factory.
326
+
327
+ .. code-block:: python
328
+
329
+ import pymc as pm
330
+
331
+ import pytensor.tensor as pt
332
+
333
+ from pymc_extras.prior import sample_prior, VariableFactory
334
+
335
+
336
+ class PowerSumDistribution:
337
+ """Create a distribution that is the sum of powers of a base distribution."""
338
+ def __init__(self, distribution: VariableFactory, n: int):
339
+ self.distribution = distribution
340
+ self.n = n
341
+
342
+ @property
343
+ def dims(self):
344
+ return self.distribution.dims
345
+
346
+ def create_variable(self, name: str) -> "TensorVariable":
347
+ raw = self.distribution.create_variable(f"{name}_raw")
348
+ return pm.Deterministic(name, pt.sum([raw ** n for n in range(1, self.n + 1)], axis=0), dims=self.dims,)
349
+
350
+ cubic = PowerSumDistribution(Prior("Normal"), n=3)
351
+ samples = sample_prior(cubic)
352
+
353
+ '''
276
354
277
355
dims : tuple [str , ...]
356
+ """The dimensions of the variable to create."""
278
357
279
358
def create_variable (self , name : str ) -> pt .TensorVariable :
280
359
"""Create a TensorVariable."""
@@ -387,6 +466,80 @@ class Prior:
387
466
be registered with `register_tensor_transform` function or
388
467
be available in either `pytensor.tensor` or `pymc.math`.
389
468
469
+ Examples
470
+ --------
471
+ Create a normal prior.
472
+
473
+ .. code-block:: python
474
+
475
+ from pymc_extras.prior import Prior
476
+
477
+ normal = Prior("Normal")
478
+
479
+ Create a hierarchical normal prior by using distributions for the parameters
480
+ and specifying the dims.
481
+
482
+ .. code-block:: python
483
+
484
+ hierarchical_normal = Prior(
485
+ "Normal",
486
+ mu=Prior("Normal"),
487
+ sigma=Prior("HalfNormal"),
488
+ dims="channel",
489
+ )
490
+
491
+ Create a non-centered hierarchical normal prior with the `centered` parameter.
492
+
493
+ .. code-block:: python
494
+
495
+ non_centered_hierarchical_normal = Prior(
496
+ "Normal",
497
+ mu=Prior("Normal"),
498
+ sigma=Prior("HalfNormal"),
499
+ dims="channel",
500
+ # Only change needed to make it non-centered
501
+ centered=False,
502
+ )
503
+
504
+ Create a hierarchical beta prior by using Beta distribution, distributions for
505
+ the parameters, and specifying the dims.
506
+
507
+ .. code-block:: python
508
+
509
+ hierarchical_beta = Prior(
510
+ "Beta",
511
+ alpha=Prior("HalfNormal"),
512
+ beta=Prior("HalfNormal"),
513
+ dims="channel",
514
+ )
515
+
516
+ Create a transformed hierarchical normal prior by using the `transform`
517
+ parameter. Here the "sigmoid" transformation comes from `pm.math`.
518
+
519
+ .. code-block:: python
520
+
521
+ transformed_hierarchical_normal = Prior(
522
+ "Normal",
523
+ mu=Prior("Normal"),
524
+ sigma=Prior("HalfNormal"),
525
+ transform="sigmoid",
526
+ dims="channel",
527
+ )
528
+
529
+ Create a prior with a custom transform function by registering it with
530
+ :func:`register_tensor_transform`.
531
+
532
+ .. code-block:: python
533
+
534
+ from pymc_extras.prior import register_tensor_transform
535
+
536
+ def custom_transform(x):
537
+ return x ** 2
538
+
539
+ register_tensor_transform("square", custom_transform)
540
+
541
+ custom_distribution = Prior("Normal", transform="square")
542
+
390
543
"""
391
544
392
545
# Taken from https://en.wikipedia.org/wiki/Location%E2%80%93scale_family
@@ -395,9 +548,13 @@ class Prior:
395
548
"StudentT" : {"mu" : 0 , "sigma" : 1 },
396
549
"ZeroSumNormal" : {"sigma" : 1 },
397
550
}
551
+ """Available non-centered distributions and their default parameters."""
398
552
399
553
pymc_distribution : type [pm .Distribution ]
554
+ """The PyMC distribution class."""
555
+
400
556
pytensor_transform : Callable [[pt .TensorLike ], pt .TensorLike ] | None
557
+ """The PyTensor transform function."""
401
558
402
559
@validate_call
403
560
def __init__ (
@@ -1323,9 +1480,33 @@ def create_likelihood_variable(
1323
1480
1324
1481
1325
1482
class Scaled :
1326
- """Scaled distribution for numerical stability."""
1483
+ """Scaled distribution for numerical stability.
1484
+
1485
+ This is the same as multiplying the variable by a constant factor.
1486
+
1487
+ Parameters
1488
+ ----------
1489
+ dist : Prior
1490
+ The prior distribution to scale.
1491
+ factor : pt.TensorLike
1492
+ The scaling factor. This will have to be broadcastable to the
1493
+ dimensions of the distribution.
1494
+
1495
+ Examples
1496
+ --------
1497
+ Create a scaled normal distribution.
1498
+
1499
+ .. code-block:: python
1500
+
1501
+ from pymc_extras.prior import Prior, Scaled
1502
+
1503
+ normal = Prior("Normal", mu=0, sigma=1)
1504
+ # Same as Normal(mu=0, sigma=10)
1505
+ scaled_normal = Scaled(normal, factor=10)
1506
+
1507
+ """
1327
1508
1328
- def __init__ (self , dist : Prior , factor : float | pt .TensorVariable ) -> None :
1509
+ def __init__ (self , dist : Prior , factor : pt .TensorLike ) -> None :
1329
1510
self .dist = dist
1330
1511
self .factor = factor
1331
1512
0 commit comments