Skip to content

Commit 126e76b

Browse files
committed
Start writing __new__ method
1 parent a5ed1f0 commit 126e76b

File tree

2 files changed

+54
-33
lines changed

2 files changed

+54
-33
lines changed

pymc/distributions/multivariate.py

Lines changed: 49 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import warnings
1919

2020
from functools import reduce
21+
from typing import Optional
2122

2223
import aesara
2324
import aesara.tensor as at
@@ -36,8 +37,6 @@
3637
from aesara.tensor.random.utils import broadcast_params
3738
from aesara.tensor.slinalg import Cholesky, SolveTriangular
3839
from aesara.tensor.type import TensorType
39-
40-
# from numpy.core.numeric import normalize_axis_tuple
4140
from scipy import linalg, stats
4241

4342
import pymc as pm
@@ -2412,20 +2411,24 @@ class ZeroSumNormal(Distribution):
24122411
It's actually the standard deviation of the underlying, unconstrained Normal distribution.
24132412
Defaults to 1 if not specified.
24142413
For now, ``sigma`` has to be a scalar, to ensure the zero-sum constraint.
2415-
zerosum_axes: list or tuple of strings or integers
2416-
Axis (or axes) along which the zero-sum constraint is enforced.
2417-
Defaults to [-1], i.e the last axis.
2418-
If strings are passed, then ``dims`` is needed.
2419-
Otherwise, ``shape`` and ``size`` work as they do for other PyMC distributions.
2420-
dims: list or tuple of strings, optional
2421-
The dimension names of the axes.
2422-
Necessary when ``zerosum_axes`` is specified with strings.
2414+
zerosum_axes: int, defaults to 1
2415+
Number of axes along which the zero-sum constraint is enforced, starting from the rightmost position.
2416+
Defaults to 1, i.e the rightmost axis.
2417+
dims: sequence of strings, optional
2418+
Dimension names of the distribution. Works the same as for other PyMC distributions.
2419+
Necessary if ``shape`` is not passed.
2420+
shape: tuple of integers, optional
2421+
Shape of the distribution. Works the same as for other PyMC distributions.
2422+
Necessary if ``dims`` is not passed.
24232423
24242424
Warnings
24252425
--------
24262426
``sigma`` has to be a scalar, to ensure the zero-sum constraint.
24272427
The ability to specifiy a vector of ``sigma`` may be added in future versions.
24282428
2429+
``zerosum_axes`` has to be > 0. If you want the behavior of ``zerosum_axes = 0``,
2430+
just use ``pm.Normal``.
2431+
24292432
Examples
24302433
--------
24312434
.. code-block:: python
@@ -2444,23 +2447,21 @@ class ZeroSumNormal(Distribution):
24442447
"""
24452448
rv_type = ZeroSumNormalRV
24462449

2447-
# def __new__(cls, *args, zerosum_axes=None, dims=None, **kwargs):
2448-
# dims = convert_dims(dims)
2449-
# if zerosum_axes is None:
2450-
# zerosum_axes = [-1]
2451-
# if not isinstance(zerosum_axes, (list, tuple)):
2452-
# zerosum_axes = [zerosum_axes]
2450+
def __new__(cls, *args, zerosum_axes=None, support_shape=None, dims=None, **kwargs):
2451+
if dims is not None or kwargs.get("observed") is not None:
2452+
zerosum_axes = cls.check_zerosum_axes(zerosum_axes)
24532453

2454-
# if isinstance(zerosum_axes[0], str):
2455-
# if not dims:
2456-
# raise ValueError("You need to specify dims if zerosum_axes are strings.")
2457-
# else:
2458-
# zerosum_axes_ = []
2459-
# for axis in zerosum_axes:
2460-
# zerosum_axes_.append(dims.index(axis))
2461-
# zerosum_axes = zerosum_axes_
2454+
support_shape = get_support_shape(
2455+
support_shape=support_shape,
2456+
shape=None, # Shape will be checked in `cls.dist`
2457+
dims=dims,
2458+
observed=kwargs.get("observed", None),
2459+
ndim_supp=zerosum_axes,
2460+
)
24622461

2463-
# return super().__new__(cls, *args, zerosum_axes=zerosum_axes, dims=dims, **kwargs)
2462+
return super().__new__(
2463+
cls, *args, zerosum_axes=zerosum_axes, support_shape=support_shape, dims=dims, **kwargs
2464+
)
24642465

24652466
@classmethod
24662467
def dist(cls, sigma=1, zerosum_axes=None, support_shape=None, **kwargs):
@@ -2480,10 +2481,13 @@ def dist(cls, sigma=1, zerosum_axes=None, support_shape=None, **kwargs):
24802481
shape=kwargs.get("shape"),
24812482
ndim_supp=zerosum_axes,
24822483
)
2484+
2485+
# print(f"{support_shape.eval() = }")
2486+
24832487
if support_shape is None:
24842488
if zerosum_axes > 0:
24852489
raise ValueError("You must specify shape or support_shape parameter")
2486-
# edge case doesn't work for now, because at.stack in get_support_shape fails
2490+
# edge-case doesn't work for now, because at.stack in get_support_shape fails
24872491
# else:
24882492
# support_shape = () # because it's just a Normal in that case
24892493
support_shape = at.as_tensor_variable(intX(support_shape))
@@ -2511,6 +2515,16 @@ def dist(cls, sigma=1, zerosum_axes=None, support_shape=None, **kwargs):
25112515
# check_dist_not_registered(dist)
25122516
# return super().dist([dist, lower, upper], **kwargs)
25132517

2518+
@classmethod
2519+
def check_zerosum_axes(cls, zerosum_axes: Optional[int]) -> int:
2520+
if zerosum_axes is None:
2521+
zerosum_axes = 1
2522+
if not isinstance(zerosum_axes, int):
2523+
raise TypeError("zerosum_axes has to be an integer")
2524+
if not zerosum_axes > 0:
2525+
raise ValueError("zerosum_axes has to be > 0")
2526+
return zerosum_axes
2527+
25142528
@classmethod
25152529
def rv_op(cls, sigma, zerosum_axes, support_shape, size=None):
25162530
# if size is None:
@@ -2553,11 +2567,14 @@ def rv_op(cls, sigma, zerosum_axes, support_shape, size=None):
25532567

25542568
@_change_dist_size.register(ZeroSumNormalRV)
25552569
def change_zerosum_size(op, normal_dist, new_size, expand=False):
2570+
25562571
normal_dist, sigma, support_shape = normal_dist.owner.inputs
2572+
25572573
if expand:
25582574
original_shape = tuple(normal_dist.shape)
25592575
old_size = original_shape[len(original_shape) - op.ndim_supp :]
25602576
new_size = tuple(new_size) + old_size
2577+
25612578
return ZeroSumNormal.rv_op(
25622579
sigma=sigma, zerosum_axes=op.ndim_supp, support_shape=support_shape, size=new_size
25632580
)
@@ -2570,26 +2587,28 @@ def zerosumnormal_moment(op, rv, *rv_inputs):
25702587

25712588
@_default_transform.register(ZeroSumNormalRV)
25722589
def zerosum_default_transform(op, rv):
2573-
return ZeroSumTransform(op.zerosum_axes)
2590+
zerosum_axes = tuple(np.arange(-op.ndim_supp, 0))
2591+
return ZeroSumTransform(zerosum_axes)
25742592

25752593

25762594
@_logprob.register(ZeroSumNormalRV)
25772595
def zerosumnormal_logp(op, values, normal_dist, sigma, support_shape, **kwargs):
25782596
(value,) = values
25792597
shape = value.shape
25802598
zerosum_axes = op.ndim_supp
2599+
25812600
_deg_free_support_shape = at.inc_subtensor(shape[-zerosum_axes:], -1)
25822601
_full_size = at.prod(shape)
25832602
_degrees_of_freedom = at.prod(_deg_free_support_shape)
2603+
25842604
zerosums = [
25852605
at.all(at.isclose(at.mean(value, axis=-axis - 1), 0, atol=1e-9))
25862606
for axis in range(zerosum_axes)
25872607
]
2608+
25882609
out = at.sum(
25892610
pm.logp(normal_dist, value) * _degrees_of_freedom / _full_size,
25902611
axis=tuple(np.arange(-zerosum_axes, 0)),
25912612
)
2592-
# figure out how dimensionality should be handled for logp
2593-
# for now, we assume ZSN is a scalar distribut, which is not correct
2594-
# out = pm.logp(normal_dist, value) * _degrees_of_freedom / _full_size
2613+
25952614
return check_parameters(out, *zerosums, msg="at.mean(value, axis=zerosum_axes) == 0")

pymc/distributions/shape_utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -681,7 +681,7 @@ def get_support_shape(
681681
support_shape_offset: Sequence[int] = None,
682682
ndim_supp: int = 1,
683683
):
684-
"""Extract length of support shapes from shape / dims / observed information
684+
"""Extract the support shapes from shape / dims / observed information
685685
686686
Parameters
687687
----------
@@ -694,7 +694,8 @@ def get_support_shape(
694694
observed:
695695
User-specified observed data from multivariate distribution
696696
support_shape_offset:
697-
Difference between last shape dimensions and the length of explicit support shapes in multivariate distribution, defaults to 0.
697+
Difference between last shape dimensions and the length of
698+
explicit support shapes in multivariate distribution, defaults to 0.
698699
For timeseries, this is shape[-1] = support_shape[-1] + 1
699700
ndim_supp:
700701
Number of support dimensions of the given multivariate distribution, defaults to 1
@@ -740,9 +741,10 @@ def get_support_shape(
740741
inferred_support_shape = support_shape
741742
# If there are two sources of information for the support shapes, assert they are consistent:
742743
elif support_shape is not None:
743-
inferred_support_shape = Assert(msg="Steps do not match last shape dimension")(
744+
inferred_support_shape = Assert(msg="support_shape does not match last shape dimension")(
744745
inferred_support_shape, at.all(at.eq(inferred_support_shape, support_shape))
745746
)
747+
746748
return inferred_support_shape
747749

748750

0 commit comments

Comments
 (0)