18
18
import warnings
19
19
20
20
from functools import reduce
21
+ from typing import Optional
21
22
22
23
import aesara
23
24
import aesara .tensor as at
36
37
from aesara .tensor .random .utils import broadcast_params
37
38
from aesara .tensor .slinalg import Cholesky , SolveTriangular
38
39
from aesara .tensor .type import TensorType
39
-
40
- # from numpy.core.numeric import normalize_axis_tuple
41
40
from scipy import linalg , stats
42
41
43
42
import pymc as pm
@@ -2412,20 +2411,24 @@ class ZeroSumNormal(Distribution):
2412
2411
It's actually the standard deviation of the underlying, unconstrained Normal distribution.
2413
2412
Defaults to 1 if not specified.
2414
2413
For now, ``sigma`` has to be a scalar, to ensure the zero-sum constraint.
2415
- zerosum_axes: list or tuple of strings or integers
2416
- Axis (or axes) along which the zero-sum constraint is enforced.
2417
- Defaults to [-1], i.e the last axis.
2418
- If strings are passed, then ``dims`` is needed.
2419
- Otherwise, ``shape`` and ``size`` work as they do for other PyMC distributions.
2420
- dims: list or tuple of strings, optional
2421
- The dimension names of the axes.
2422
- Necessary when ``zerosum_axes`` is specified with strings.
2414
+ zerosum_axes: int, defaults to 1
2415
+ Number of axes along which the zero-sum constraint is enforced, starting from the rightmost position.
2416
+ Defaults to 1, i.e the rightmost axis.
2417
+ dims: sequence of strings, optional
2418
+ Dimension names of the distribution. Works the same as for other PyMC distributions.
2419
+ Necessary if ``shape`` is not passed.
2420
+ shape: tuple of integers, optional
2421
+ Shape of the distribution. Works the same as for other PyMC distributions.
2422
+ Necessary if ``dims`` is not passed.
2423
2423
2424
2424
Warnings
2425
2425
--------
2426
2426
``sigma`` has to be a scalar, to ensure the zero-sum constraint.
2427
2427
The ability to specifiy a vector of ``sigma`` may be added in future versions.
2428
2428
2429
+ ``zerosum_axes`` has to be > 0. If you want the behavior of ``zerosum_axes = 0``,
2430
+ just use ``pm.Normal``.
2431
+
2429
2432
Examples
2430
2433
--------
2431
2434
.. code-block:: python
@@ -2444,23 +2447,21 @@ class ZeroSumNormal(Distribution):
2444
2447
"""
2445
2448
rv_type = ZeroSumNormalRV
2446
2449
2447
- # def __new__(cls, *args, zerosum_axes=None, dims=None, **kwargs):
2448
- # dims = convert_dims(dims)
2449
- # if zerosum_axes is None:
2450
- # zerosum_axes = [-1]
2451
- # if not isinstance(zerosum_axes, (list, tuple)):
2452
- # zerosum_axes = [zerosum_axes]
2450
+ def __new__ (cls , * args , zerosum_axes = None , support_shape = None , dims = None , ** kwargs ):
2451
+ if dims is not None or kwargs .get ("observed" ) is not None :
2452
+ zerosum_axes = cls .check_zerosum_axes (zerosum_axes )
2453
2453
2454
- # if isinstance(zerosum_axes[0], str):
2455
- # if not dims:
2456
- # raise ValueError("You need to specify dims if zerosum_axes are strings.")
2457
- # else:
2458
- # zerosum_axes_ = []
2459
- # for axis in zerosum_axes:
2460
- # zerosum_axes_.append(dims.index(axis))
2461
- # zerosum_axes = zerosum_axes_
2454
+ support_shape = get_support_shape (
2455
+ support_shape = support_shape ,
2456
+ shape = None , # Shape will be checked in `cls.dist`
2457
+ dims = dims ,
2458
+ observed = kwargs .get ("observed" , None ),
2459
+ ndim_supp = zerosum_axes ,
2460
+ )
2462
2461
2463
- # return super().__new__(cls, *args, zerosum_axes=zerosum_axes, dims=dims, **kwargs)
2462
+ return super ().__new__ (
2463
+ cls , * args , zerosum_axes = zerosum_axes , support_shape = support_shape , dims = dims , ** kwargs
2464
+ )
2464
2465
2465
2466
@classmethod
2466
2467
def dist (cls , sigma = 1 , zerosum_axes = None , support_shape = None , ** kwargs ):
@@ -2480,10 +2481,13 @@ def dist(cls, sigma=1, zerosum_axes=None, support_shape=None, **kwargs):
2480
2481
shape = kwargs .get ("shape" ),
2481
2482
ndim_supp = zerosum_axes ,
2482
2483
)
2484
+
2485
+ # print(f"{support_shape.eval() = }")
2486
+
2483
2487
if support_shape is None :
2484
2488
if zerosum_axes > 0 :
2485
2489
raise ValueError ("You must specify shape or support_shape parameter" )
2486
- # edge case doesn't work for now, because at.stack in get_support_shape fails
2490
+ # edge- case doesn't work for now, because at.stack in get_support_shape fails
2487
2491
# else:
2488
2492
# support_shape = () # because it's just a Normal in that case
2489
2493
support_shape = at .as_tensor_variable (intX (support_shape ))
@@ -2511,6 +2515,16 @@ def dist(cls, sigma=1, zerosum_axes=None, support_shape=None, **kwargs):
2511
2515
# check_dist_not_registered(dist)
2512
2516
# return super().dist([dist, lower, upper], **kwargs)
2513
2517
2518
+ @classmethod
2519
+ def check_zerosum_axes (cls , zerosum_axes : Optional [int ]) -> int :
2520
+ if zerosum_axes is None :
2521
+ zerosum_axes = 1
2522
+ if not isinstance (zerosum_axes , int ):
2523
+ raise TypeError ("zerosum_axes has to be an integer" )
2524
+ if not zerosum_axes > 0 :
2525
+ raise ValueError ("zerosum_axes has to be > 0" )
2526
+ return zerosum_axes
2527
+
2514
2528
@classmethod
2515
2529
def rv_op (cls , sigma , zerosum_axes , support_shape , size = None ):
2516
2530
# if size is None:
@@ -2553,11 +2567,14 @@ def rv_op(cls, sigma, zerosum_axes, support_shape, size=None):
2553
2567
2554
2568
@_change_dist_size .register (ZeroSumNormalRV )
2555
2569
def change_zerosum_size (op , normal_dist , new_size , expand = False ):
2570
+
2556
2571
normal_dist , sigma , support_shape = normal_dist .owner .inputs
2572
+
2557
2573
if expand :
2558
2574
original_shape = tuple (normal_dist .shape )
2559
2575
old_size = original_shape [len (original_shape ) - op .ndim_supp :]
2560
2576
new_size = tuple (new_size ) + old_size
2577
+
2561
2578
return ZeroSumNormal .rv_op (
2562
2579
sigma = sigma , zerosum_axes = op .ndim_supp , support_shape = support_shape , size = new_size
2563
2580
)
@@ -2570,26 +2587,28 @@ def zerosumnormal_moment(op, rv, *rv_inputs):
2570
2587
2571
2588
@_default_transform .register (ZeroSumNormalRV )
2572
2589
def zerosum_default_transform (op , rv ):
2573
- return ZeroSumTransform (op .zerosum_axes )
2590
+ zerosum_axes = tuple (np .arange (- op .ndim_supp , 0 ))
2591
+ return ZeroSumTransform (zerosum_axes )
2574
2592
2575
2593
2576
2594
@_logprob .register (ZeroSumNormalRV )
2577
2595
def zerosumnormal_logp (op , values , normal_dist , sigma , support_shape , ** kwargs ):
2578
2596
(value ,) = values
2579
2597
shape = value .shape
2580
2598
zerosum_axes = op .ndim_supp
2599
+
2581
2600
_deg_free_support_shape = at .inc_subtensor (shape [- zerosum_axes :], - 1 )
2582
2601
_full_size = at .prod (shape )
2583
2602
_degrees_of_freedom = at .prod (_deg_free_support_shape )
2603
+
2584
2604
zerosums = [
2585
2605
at .all (at .isclose (at .mean (value , axis = - axis - 1 ), 0 , atol = 1e-9 ))
2586
2606
for axis in range (zerosum_axes )
2587
2607
]
2608
+
2588
2609
out = at .sum (
2589
2610
pm .logp (normal_dist , value ) * _degrees_of_freedom / _full_size ,
2590
2611
axis = tuple (np .arange (- zerosum_axes , 0 )),
2591
2612
)
2592
- # figure out how dimensionality should be handled for logp
2593
- # for now, we assume ZSN is a scalar distribut, which is not correct
2594
- # out = pm.logp(normal_dist, value) * _degrees_of_freedom / _full_size
2613
+
2595
2614
return check_parameters (out , * zerosums , msg = "at.mean(value, axis=zerosum_axes) == 0" )
0 commit comments