@@ -2468,12 +2468,7 @@ def __new__(cls, *args, zerosum_axes=None, support_shape=None, dims=None, **kwar
2468
2468
2469
2469
@classmethod
2470
2470
def dist (cls , sigma = 1 , zerosum_axes = None , support_shape = None , ** kwargs ):
2471
- if zerosum_axes is None :
2472
- zerosum_axes = 1
2473
- if not isinstance (zerosum_axes , int ):
2474
- raise TypeError ("zerosum_axes has to be an integer" )
2475
- if not zerosum_axes > 0 :
2476
- raise ValueError ("zerosum_axes has to be > 0" )
2471
+ zerosum_axes = cls .check_zerosum_axes (zerosum_axes )
2477
2472
2478
2473
sigma = at .as_tensor_variable (floatX (sigma ))
2479
2474
if sigma .ndim > 0 :
@@ -2501,21 +2496,6 @@ def dist(cls, sigma=1, zerosum_axes=None, support_shape=None, **kwargs):
2501
2496
[sigma ], zerosum_axes = zerosum_axes , support_shape = support_shape , ** kwargs
2502
2497
)
2503
2498
2504
- # TODO: This is if we want ZeroSum constraint on other dists than Normal
2505
- # def dist(cls, dist, lower, upper, **kwargs):
2506
- # if not isinstance(dist, TensorVariable) or not isinstance(
2507
- # dist.owner.op, (RandomVariable, SymbolicRandomVariable)
2508
- # ):
2509
- # raise ValueError(
2510
- # f"Censoring dist must be a distribution created via the `.dist()` API, got {type(dist)}"
2511
- # )
2512
- # if dist.owner.op.ndim_supp > 0:
2513
- # raise NotImplementedError(
2514
- # "Censoring of multivariate distributions has not been implemented yet"
2515
- # )
2516
- # check_dist_not_registered(dist)
2517
- # return super().dist([dist, lower, upper], **kwargs)
2518
-
2519
2499
@classmethod
2520
2500
def check_zerosum_axes (cls , zerosum_axes : Optional [int ]) -> int :
2521
2501
if zerosum_axes is None :
0 commit comments