36
36
from aesara .tensor .random .utils import broadcast_params
37
37
from aesara .tensor .slinalg import Cholesky , SolveTriangular
38
38
from aesara .tensor .type import TensorType
39
- from numpy .core .numeric import normalize_axis_tuple
39
+
40
+ # from numpy.core.numeric import normalize_axis_tuple
40
41
from scipy import linalg , stats
41
42
42
43
import pymc as pm
64
65
_change_dist_size ,
65
66
broadcast_dist_samples_to ,
66
67
change_dist_size ,
67
- convert_dims ,
68
+ get_support_shape ,
68
69
rv_size_is_none ,
69
70
to_tuple ,
70
71
)
@@ -2389,11 +2390,7 @@ class ZeroSumNormalRV(SymbolicRandomVariable):
2389
2390
"""ZeroSumNormal random variable"""
2390
2391
2391
2392
_print_name = ("ZeroSumNormal" , "\\ operatorname{ZeroSumNormal}" )
2392
- zerosum_axes = None
2393
-
2394
- def __init__ (self , * args , zerosum_axes , ** kwargs ):
2395
- self .zerosum_axes = zerosum_axes
2396
- super ().__init__ (* args , ** kwargs )
2393
+ default_output = 0
2397
2394
2398
2395
2399
2396
class ZeroSumNormal (Distribution ):
@@ -2447,36 +2444,57 @@ class ZeroSumNormal(Distribution):
2447
2444
"""
2448
2445
rv_type = ZeroSumNormalRV
2449
2446
2450
- def __new__ (cls , * args , zerosum_axes = None , dims = None , ** kwargs ):
2451
- dims = convert_dims (dims )
2452
- if zerosum_axes is None :
2453
- zerosum_axes = [- 1 ]
2454
- if not isinstance (zerosum_axes , (list , tuple )):
2455
- zerosum_axes = [zerosum_axes ]
2447
+ # def __new__(cls, *args, zerosum_axes=None, dims=None, **kwargs):
2448
+ # dims = convert_dims(dims)
2449
+ # if zerosum_axes is None:
2450
+ # zerosum_axes = [-1]
2451
+ # if not isinstance(zerosum_axes, (list, tuple)):
2452
+ # zerosum_axes = [zerosum_axes]
2456
2453
2457
- if isinstance (zerosum_axes [0 ], str ):
2458
- if not dims :
2459
- raise ValueError ("You need to specify dims if zerosum_axes are strings." )
2460
- else :
2461
- zerosum_axes_ = []
2462
- for axis in zerosum_axes :
2463
- zerosum_axes_ .append (dims .index (axis ))
2464
- zerosum_axes = zerosum_axes_
2454
+ # if isinstance(zerosum_axes[0], str):
2455
+ # if not dims:
2456
+ # raise ValueError("You need to specify dims if zerosum_axes are strings.")
2457
+ # else:
2458
+ # zerosum_axes_ = []
2459
+ # for axis in zerosum_axes:
2460
+ # zerosum_axes_.append(dims.index(axis))
2461
+ # zerosum_axes = zerosum_axes_
2465
2462
2466
- return super ().__new__ (cls , * args , zerosum_axes = zerosum_axes , dims = dims , ** kwargs )
2463
+ # return super().__new__(cls, *args, zerosum_axes=zerosum_axes, dims=dims, **kwargs)
2467
2464
2468
2465
# NOTE(review): reconstructed from a scraped diff — this is a method of
# ZeroSumNormal (class header lies outside this hunk); re-indent under the
# class when splicing back into the real file.
@classmethod
def dist(cls, sigma=1, zerosum_axes=None, support_shape=None, **kwargs):
    """Create the RV for a ZeroSumNormal distribution.

    Parameters
    ----------
    sigma : scalar tensor-like, default 1
        Standard deviation of the underlying unconstrained Normal.
        Must be a scalar (``ndim == 0``).
    zerosum_axes : int, default 1
        Number of trailing axes along which draws are constrained to
        sum to zero. Must be a strictly positive integer.
    support_shape : tensor-like of int, optional
        Shape of the support (the trailing ``zerosum_axes`` dimensions).
        Inferred from ``kwargs["shape"]`` via ``get_support_shape`` when
        not given explicitly.

    Raises
    ------
    TypeError
        If ``zerosum_axes`` is not an int.
    ValueError
        If ``zerosum_axes`` is not > 0, ``sigma`` is not scalar, or
        neither ``shape`` nor ``support_shape`` is provided.
    """
    if zerosum_axes is None:
        zerosum_axes = 1
    if not isinstance(zerosum_axes, int):
        raise TypeError("zerosum_axes has to be an integer")
    if not zerosum_axes > 0:
        raise ValueError("zerosum_axes has to be > 0")

    sigma = at.as_tensor_variable(floatX(sigma))
    if sigma.ndim > 0:
        raise ValueError("sigma has to be a scalar")

    support_shape = get_support_shape(
        support_shape=support_shape,
        shape=kwargs.get("shape"),
        ndim_supp=zerosum_axes,
    )
    if support_shape is None:
        # zerosum_axes is always > 0 here (validated above), so the support
        # shape cannot be left unspecified.
        if zerosum_axes > 0:
            raise ValueError("You must specify shape or support_shape parameter")
        # edge case doesn't work for now, because at.stack in get_support_shape fails
        # else:
        #     support_shape = ()  # because it's just a Normal in that case
    support_shape = at.as_tensor_variable(intX(support_shape))

    # Upstream uses a plain assert here; kept as-is so callers see the same
    # failure mode (note: stripped under ``python -O``).
    assert zerosum_axes == at.get_vector_length(
        support_shape
    ), "support_shape has to be as long as zerosum_axes"

    return super().dist(
        [sigma], zerosum_axes=zerosum_axes, support_shape=support_shape, **kwargs
    )
2480
2498
2481
2499
# TODO: This is if we want ZeroSum constraint on other dists than Normal
2482
2500
# def dist(cls, dist, lower, upper, **kwargs):
@@ -2494,39 +2512,55 @@ def dist(cls, sigma=1, zerosum_axes=None, **kwargs):
2494
2512
# return super().dist([dist, lower, upper], **kwargs)
2495
2513
2496
2514
# NOTE(review): reconstructed from a scraped diff — this is a method of
# ZeroSumNormal (class header lies outside this hunk); re-indent under the
# class when splicing back into the real file.
@classmethod
def rv_op(cls, sigma, zerosum_axes, support_shape, size=None):
    """Build the symbolic graph of a ZeroSumNormal random variable.

    Parameters
    ----------
    sigma : scalar tensor variable
        Standard deviation of the underlying Normal.
    zerosum_axes : int
        Number of trailing (support) axes constrained to sum to zero.
    support_shape : tensor variable (int vector)
        Shape of the trailing ``zerosum_axes`` dimensions.
    size : optional
        Batch size prepended to ``support_shape``.

    Returns
    -------
    The ZeroSumNormalRV output (``default_output = 0`` on the Op selects
    the zero-summed draw; ``support_shape`` is carried as a second output).
    """
    # Full shape = batch (size) dims followed by the support dims.
    shape = to_tuple(size) + tuple(support_shape)
    normal_dist = ignore_logprob(pm.Normal.dist(sigma=sigma, shape=shape))
    normal_dist_, sigma_, support_shape_ = (
        normal_dist.type(),
        sigma.type(),
        support_shape.type(),
    )

    # Zero-sum constraint is achieved by subtracting the mean along each of
    # the trailing `zerosum_axes` axes (axis -1, -2, ..., -zerosum_axes).
    zerosum_rv_ = normal_dist_
    for axis in range(zerosum_axes):
        zerosum_rv_ -= zerosum_rv_.mean(axis=-axis - 1, keepdims=True)

    return ZeroSumNormalRV(
        inputs=[normal_dist_, sigma_, support_shape_],
        outputs=[zerosum_rv_, support_shape_],
        ndim_supp=zerosum_axes,
    )(normal_dist, sigma, support_shape)

# TODO:
# write __new__
# refactor ZSN tests
# test get_support_shape with 2D
# test ZSN logp
# test ZSN variance
# fix failing Ubuntu test
2522
2552
2523
2553
2524
2554
@_change_dist_size.register(ZeroSumNormalRV)
def change_zerosum_size(op, normal_dist, new_size, expand=False):
    """Return a resized ZeroSumNormal with the same sigma and support shape.

    Parameters
    ----------
    op : ZeroSumNormalRV
        The Op whose ``ndim_supp`` gives the number of support axes.
    normal_dist : TensorVariable
        An output of the ZeroSumNormalRV node being resized.
    new_size : tuple
        New batch size; prepended to the existing batch dims when ``expand``.
    expand : bool, default False
        Whether to prepend ``new_size`` instead of replacing the batch dims.
    """
    normal_dist, sigma, support_shape = normal_dist.owner.inputs

    if expand:
        original_shape = tuple(normal_dist.shape)
        # Keep only the leading batch dims: the trailing `ndim_supp` entries
        # are the support shape, which is passed separately via support_shape.
        # (The diff sliced `[len(shape) - ndim_supp:]`, i.e. the *support*
        # dims, which would duplicate them into `size` — fixed here.)
        old_size = original_shape[: len(original_shape) - op.ndim_supp]
        new_size = tuple(new_size) + old_size
    return ZeroSumNormal.rv_op(
        sigma=sigma, zerosum_axes=op.ndim_supp, support_shape=support_shape, size=new_size
    )
2530
2564
2531
2565
2532
2566
@_moment .register (ZeroSumNormalRV )
@@ -2540,20 +2574,22 @@ def zerosum_default_transform(op, rv):
2540
2574
2541
2575
2542
2576
@_logprob.register(ZeroSumNormalRV)
def zerosumnormal_logp(op, values, normal_dist, sigma, support_shape, **kwargs):
    """Log-probability of a ZeroSumNormal value.

    The logp of the unconstrained Normal is rescaled by the ratio of
    degrees of freedom to full size (each zero-sum axis loses one degree
    of freedom), then reduced over the support axes. The result is guarded
    by ``check_parameters`` so that values which do not actually sum to
    (approximately) zero along every support axis yield ``-inf``.
    """
    (value,) = values
    shape = value.shape
    zerosum_axes = op.ndim_supp
    # Each of the trailing support axes loses one degree of freedom to the
    # zero-sum constraint.
    _deg_free_support_shape = at.inc_subtensor(shape[-zerosum_axes:], -1)
    _full_size = at.prod(shape)
    _degrees_of_freedom = at.prod(_deg_free_support_shape)
    # Validity conditions: the value must sum to ~0 along each support axis.
    zerosums = [
        at.all(at.isclose(at.mean(value, axis=-axis - 1), 0, atol=1e-9))
        for axis in range(zerosum_axes)
    ]
    # Reduce over the support axes (-zerosum_axes, ..., -1) so the logp has
    # one entry per batch element.
    out = at.sum(
        pm.logp(normal_dist, value) * _degrees_of_freedom / _full_size,
        axis=tuple(np.arange(-zerosum_axes, 0)),
    )
    return check_parameters(out, *zerosums, msg="at.mean(value, axis=zerosum_axes) == 0")
0 commit comments