@@ -2591,8 +2591,7 @@ class ZeroSumNormal(Distribution):
2591
2591
sigma : tensor_like of float
2592
2592
Scale parameter (sigma > 0).
2593
2593
It's actually the standard deviation of the underlying, unconstrained Normal distribution.
2594
- Defaults to 1 if not specified.
2595
- For now, ``sigma`` has to be a scalar, to ensure the zero-sum constraint.
2594
+ Defaults to 1 if not specified. ``sigma`` cannot have length > 1 across the zero-sum axes.
2596
2595
n_zerosum_axes: int, defaults to 1
2597
2596
Number of axes along which the zero-sum constraint is enforced, starting from the rightmost position.
2598
2597
Defaults to 1, i.e the rightmost axis.
@@ -2606,8 +2605,7 @@ class ZeroSumNormal(Distribution):
2606
2605
2607
2606
Warnings
2608
2607
--------
2609
- ``sigma`` has to be a scalar, to ensure the zero-sum constraint.
2610
- The ability to specify a vector of ``sigma`` may be added in future versions.
2608
+ Currently, ``sigma``cannot have length > 1 across the zero-sum axes to ensure the zero-sum constraint.
2611
2609
2612
2610
``n_zerosum_axes`` has to be > 0. If you want the behavior of ``n_zerosum_axes = 0``,
2613
2611
just use ``pm.Normal``.
@@ -2669,8 +2667,8 @@ def dist(cls, sigma=1, n_zerosum_axes=None, support_shape=None, **kwargs):
2669
2667
n_zerosum_axes = cls .check_zerosum_axes (n_zerosum_axes )
2670
2668
2671
2669
sigma = pt .as_tensor_variable (floatX (sigma ))
2672
- if sigma .ndim > 0 :
2673
- raise ValueError ("sigma has to be a scalar " )
2670
+ if not all ( sigma .type . broadcastable [ - n_zerosum_axes :]) :
2671
+ raise ValueError ("sigma must have length one across the zero-sum axes " )
2674
2672
2675
2673
support_shape = get_support_shape (
2676
2674
support_shape = support_shape ,
@@ -2681,9 +2679,7 @@ def dist(cls, sigma=1, n_zerosum_axes=None, support_shape=None, **kwargs):
2681
2679
if support_shape is None :
2682
2680
if n_zerosum_axes > 0 :
2683
2681
raise ValueError ("You must specify dims, shape or support_shape parameter" )
2684
- # TODO: edge-case doesn't work for now, because pt.stack in get_support_shape fails
2685
- # else:
2686
- # support_shape = () # because it's just a Normal in that case
2682
+
2687
2683
support_shape = pt .as_tensor_variable (intX (support_shape ))
2688
2684
2689
2685
assert n_zerosum_axes == pt .get_vector_length (
@@ -2706,7 +2702,12 @@ def check_zerosum_axes(cls, n_zerosum_axes: Optional[int]) -> int:
2706
2702
2707
2703
@classmethod
2708
2704
def rv_op (cls , sigma , n_zerosum_axes , support_shape , size = None ):
2709
- shape = to_tuple (size ) + tuple (support_shape )
2705
+ if size is not None :
2706
+ shape = tuple (size ) + tuple (support_shape )
2707
+ else :
2708
+ # Size is implied by shape of sigma
2709
+ shape = tuple (sigma .shape [:- n_zerosum_axes ]) + tuple (support_shape )
2710
+
2710
2711
normal_dist = pm .Normal .dist (sigma = sigma , shape = shape )
2711
2712
2712
2713
if n_zerosum_axes > normal_dist .ndim :
0 commit comments