
Commit a5ed1f0

Refactor ZSN dist and logp for rightmost zerosum_axes
1 parent da6eaab commit a5ed1f0

2 files changed: +95 -57 lines

pymc/distributions/multivariate.py

Lines changed: 93 additions & 57 deletions
@@ -36,7 +36,8 @@
 from aesara.tensor.random.utils import broadcast_params
 from aesara.tensor.slinalg import Cholesky, SolveTriangular
 from aesara.tensor.type import TensorType
-from numpy.core.numeric import normalize_axis_tuple
+
+# from numpy.core.numeric import normalize_axis_tuple
 from scipy import linalg, stats
 
 import pymc as pm
@@ -64,7 +65,7 @@
     _change_dist_size,
     broadcast_dist_samples_to,
     change_dist_size,
-    convert_dims,
+    get_support_shape,
     rv_size_is_none,
     to_tuple,
 )
@@ -2389,11 +2390,7 @@ class ZeroSumNormalRV(SymbolicRandomVariable):
     """ZeroSumNormal random variable"""
 
     _print_name = ("ZeroSumNormal", "\\operatorname{ZeroSumNormal}")
-    zerosum_axes = None
-
-    def __init__(self, *args, zerosum_axes, **kwargs):
-        self.zerosum_axes = zerosum_axes
-        super().__init__(*args, **kwargs)
+    default_output = 0
 
 
 class ZeroSumNormal(Distribution):
@@ -2447,36 +2444,57 @@ class ZeroSumNormal(Distribution):
     """
     rv_type = ZeroSumNormalRV
 
-    def __new__(cls, *args, zerosum_axes=None, dims=None, **kwargs):
-        dims = convert_dims(dims)
-        if zerosum_axes is None:
-            zerosum_axes = [-1]
-        if not isinstance(zerosum_axes, (list, tuple)):
-            zerosum_axes = [zerosum_axes]
+    # def __new__(cls, *args, zerosum_axes=None, dims=None, **kwargs):
+    # dims = convert_dims(dims)
+    # if zerosum_axes is None:
+    # zerosum_axes = [-1]
+    # if not isinstance(zerosum_axes, (list, tuple)):
+    # zerosum_axes = [zerosum_axes]
 
-        if isinstance(zerosum_axes[0], str):
-            if not dims:
-                raise ValueError("You need to specify dims if zerosum_axes are strings.")
-            else:
-                zerosum_axes_ = []
-                for axis in zerosum_axes:
-                    zerosum_axes_.append(dims.index(axis))
-                zerosum_axes = zerosum_axes_
+    # if isinstance(zerosum_axes[0], str):
+    # if not dims:
+    # raise ValueError("You need to specify dims if zerosum_axes are strings.")
+    # else:
+    # zerosum_axes_ = []
+    # for axis in zerosum_axes:
+    # zerosum_axes_.append(dims.index(axis))
+    # zerosum_axes = zerosum_axes_
 
-        return super().__new__(cls, *args, zerosum_axes=zerosum_axes, dims=dims, **kwargs)
+    # return super().__new__(cls, *args, zerosum_axes=zerosum_axes, dims=dims, **kwargs)
 
     @classmethod
-    def dist(cls, sigma=1, zerosum_axes=None, **kwargs):
+    def dist(cls, sigma=1, zerosum_axes=None, support_shape=None, **kwargs):
         if zerosum_axes is None:
-            zerosum_axes = [-1]
-        if not isinstance(zerosum_axes, (list, tuple)):
-            zerosum_axes = [zerosum_axes]
+            zerosum_axes = 1
+        if not isinstance(zerosum_axes, int):
+            raise TypeError("zerosum_axes has to be an integer")
+        if not zerosum_axes > 0:
+            raise ValueError("zerosum_axes has to be > 0")
 
         sigma = at.as_tensor_variable(floatX(sigma))
         if sigma.ndim > 0:
             raise ValueError("sigma has to be a scalar")
 
-        return super().dist([sigma], zerosum_axes=zerosum_axes, **kwargs)
+        support_shape = get_support_shape(
+            support_shape=support_shape,
+            shape=kwargs.get("shape"),
+            ndim_supp=zerosum_axes,
+        )
+        if support_shape is None:
+            if zerosum_axes > 0:
+                raise ValueError("You must specify shape or support_shape parameter")
+            # edge case doesn't work for now, because at.stack in get_support_shape fails
+            # else:
+            # support_shape = () # because it's just a Normal in that case
+        support_shape = at.as_tensor_variable(intX(support_shape))
+
+        assert zerosum_axes == at.get_vector_length(
+            support_shape
+        ), "support_shape has to be as long as zerosum_axes"
+
+        return super().dist(
+            [sigma], zerosum_axes=zerosum_axes, support_shape=support_shape, **kwargs
+        )
 
     # TODO: This is if we want ZeroSum constraint on other dists than Normal
     # def dist(cls, dist, lower, upper, **kwargs):
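Under the refactor, zerosum_axes is no longer a list of axis indices (or dim names) but an integer counting how many of the rightmost axes are constrained to sum to zero, and support_shape records the lengths of exactly those axes. A minimal plain-Python sketch of that convention (split_shape is a hypothetical helper for illustration, not part of the commit):

    def split_shape(shape, zerosum_axes):
        """Split a full shape into (batch_shape, support_shape) under the
        rightmost-axes convention used by the refactored ZeroSumNormal."""
        if zerosum_axes < 1:
            raise ValueError("zerosum_axes has to be > 0")
        batch_shape = tuple(shape[: len(shape) - zerosum_axes])
        support_shape = tuple(shape[len(shape) - zerosum_axes :])
        return batch_shape, support_shape

    # shape=(5, 3, 4) with zerosum_axes=2: one batch dim, two constrained support dims
    print(split_shape((5, 3, 4), zerosum_axes=2))  # ((5,), (3, 4))

This is also why dist now asserts that support_shape has as many entries as zerosum_axes.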
@@ -2494,39 +2512,55 @@ def dist(cls, sigma=1, zerosum_axes=None, **kwargs):
     # return super().dist([dist, lower, upper], **kwargs)
 
     @classmethod
-    def rv_op(cls, sigma, zerosum_axes, size=None):
-        if size is None:
-            zerosum_axes_ = np.asarray(zerosum_axes)
-            # just a placeholder size to infer minimum shape
-            size = np.ones(
-                max((max(np.abs(zerosum_axes_) - 1), max(zerosum_axes_))) + 1, dtype=int
-            ).tolist()
+    def rv_op(cls, sigma, zerosum_axes, support_shape, size=None):
+        # if size is None:
+        # zerosum_axes_ = np.asarray(zerosum_axes)
+        # # just a placeholder size to infer minimum shape
+        # size = np.ones(
+        # max((max(np.abs(zerosum_axes_) - 1), max(zerosum_axes_))) + 1, dtype=int
+        # ).tolist()
 
         # check if zerosum_axes is valid
-        normalize_axis_tuple(zerosum_axes, len(size))
-
-        normal_dist = ignore_logprob(pm.Normal.dist(sigma=sigma, size=size))
-        normal_dist_, sigma_ = normal_dist.type(), sigma.type()
+        # normalize_axis_tuple(zerosum_axes, len(size))
+
+        shape = to_tuple(size) + tuple(support_shape)
+        normal_dist = ignore_logprob(pm.Normal.dist(sigma=sigma, shape=shape))
+        normal_dist_, sigma_, support_shape_ = (
+            normal_dist.type(),
+            sigma.type(),
+            support_shape.type(),
+        )
 
         # Zerosum-normaling is achieved by subtracting the mean along the given zerosum_axes
         zerosum_rv_ = normal_dist_
-        for axis in zerosum_axes:
-            zerosum_rv_ -= zerosum_rv_.mean(axis=axis, keepdims=True)
+        for axis in range(zerosum_axes):
+            zerosum_rv_ -= zerosum_rv_.mean(axis=-axis - 1, keepdims=True)
 
         return ZeroSumNormalRV(
-            inputs=[normal_dist_, sigma_],
-            outputs=[zerosum_rv_],
-            zerosum_axes=zerosum_axes,
-            ndim_supp=0,
-        )(normal_dist, sigma)
+            inputs=[normal_dist_, sigma_, support_shape_],
+            outputs=[zerosum_rv_, support_shape_],
+            ndim_supp=zerosum_axes,
+        )(normal_dist, sigma, support_shape)
+
+        # TODO:
+        # write __new__
+        # refactor ZSN tests
+        # test get_support_shape with 2D
+        # test ZSN logp
+        # test ZSN variance
+        # fix failing Ubuntu test
 
 
 @_change_dist_size.register(ZeroSumNormalRV)
 def change_zerosum_size(op, normal_dist, new_size, expand=False):
-    normal_dist, sigma = normal_dist.owner.inputs
+    normal_dist, sigma, support_shape = normal_dist.owner.inputs
     if expand:
-        new_size = tuple(new_size) + tuple(normal_dist.shape)
-    return ZeroSumNormal.rv_op(sigma=sigma, zerosum_axes=op.zerosum_axes, size=new_size)
+        original_shape = tuple(normal_dist.shape)
+        old_size = original_shape[len(original_shape) - op.ndim_supp :]
+        new_size = tuple(new_size) + old_size
+    return ZeroSumNormal.rv_op(
+        sigma=sigma, zerosum_axes=op.ndim_supp, support_shape=support_shape, size=new_size
+    )
 
 
 @_moment.register(ZeroSumNormalRV)
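The mean-subtraction loop in rv_op now walks the rightmost zerosum_axes axes (axis=-1, -2, ...) instead of a user-supplied list of axis indices. A NumPy sketch of the same construction (illustrative only; the commit builds an Aesara graph, not NumPy arrays):

    import numpy as np

    rng = np.random.default_rng(42)
    zerosum_axes = 2
    draw = rng.normal(size=(5, 3, 4))  # one batch dim, support dims (3, 4)

    # Subtract the mean along each of the rightmost `zerosum_axes` axes,
    # mirroring `zerosum_rv_ -= zerosum_rv_.mean(axis=-axis - 1, keepdims=True)`
    for axis in range(zerosum_axes):
        draw = draw - draw.mean(axis=-axis - 1, keepdims=True)

    print(np.allclose(draw.sum(axis=-1), 0))  # True: each row sums to zero
    print(np.allclose(draw.sum(axis=-2), 0))  # True: each column sums to zero

Because the mean subtracted along a later axis is itself already zero-sum along the earlier axes, previously enforced constraints are preserved.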
@@ -2540,20 +2574,22 @@ def zerosum_default_transform(op, rv):
 
 
 @_logprob.register(ZeroSumNormalRV)
-def zerosumnormal_logp(op, values, normal_dist, sigma, **kwargs):
+def zerosumnormal_logp(op, values, normal_dist, sigma, support_shape, **kwargs):
     (value,) = values
     shape = value.shape
-    _deg_free_shape = at.inc_subtensor(shape[at.as_tensor_variable(op.zerosum_axes)], -1)
+    zerosum_axes = op.ndim_supp
+    _deg_free_support_shape = at.inc_subtensor(shape[-zerosum_axes:], -1)
     _full_size = at.prod(shape)
-    _degrees_of_freedom = at.prod(_deg_free_shape)
+    _degrees_of_freedom = at.prod(_deg_free_support_shape)
     zerosums = [
-        at.all(at.isclose(at.mean(value, axis=axis), 0, atol=1e-9)) for axis in op.zerosum_axes
+        at.all(at.isclose(at.mean(value, axis=-axis - 1), 0, atol=1e-9))
+        for axis in range(zerosum_axes)
     ]
-    # out = at.sum(
-    # pm.logp(dist, value) * _degrees_of_freedom / _full_size,
-    # axis=op.zerosum_axes,
-    # )
+    out = at.sum(
+        pm.logp(normal_dist, value) * _degrees_of_freedom / _full_size,
+        axis=tuple(np.arange(-zerosum_axes, 0)),
+    )
     # figure out how dimensionality should be handled for logp
     # for now, we assume ZSN is a scalar distribution, which is not correct
-    out = pm.logp(normal_dist, value) * _degrees_of_freedom / _full_size
+    # out = pm.logp(normal_dist, value) * _degrees_of_freedom / _full_size
     return check_parameters(out, *zerosums, msg="at.mean(value, axis=zerosum_axes) == 0")
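The logp now rescales the pointwise Normal density by the ratio of degrees of freedom to total size and sums over the support axes; each zero-sum constraint removes one degree of freedom along the constrained axis. A small NumPy check of that factor (numbers chosen only for illustration):

    import numpy as np

    shape = np.array([3, 4])  # value shape
    zerosum_axes = 1

    deg_free_support_shape = shape.copy()
    deg_free_support_shape[-zerosum_axes:] -= 1        # mirrors at.inc_subtensor(shape[-zerosum_axes:], -1)

    full_size = shape.prod()                           # 12 entries in total
    degrees_of_freedom = deg_free_support_shape.prod() # 3 * (4 - 1) = 9 free entries

    print(degrees_of_freedom / full_size)              # 0.75, the scaling applied to pm.logp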

pymc/distributions/shape_utils.py

Lines changed: 2 additions & 0 deletions
@@ -706,6 +706,8 @@ def get_support_shape(
     shape / dims / observed. When two sources of support shape information are provided,
     a symbolic Assert is added to ensure they are consistent.
     """
+    if ndim_supp < 1:
+        raise NotImplementedError("ndim_supp must be bigger than 0")
     if support_shape_offset is None:
         support_shape_offset = [0] * ndim_supp
     inferred_support_shape = None
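The new guard makes explicit that get_support_shape currently only handles ndim_supp >= 1, which is what ZeroSumNormal.dist relies on. A plain-Python stand-in for the inference behaviour used there (not the actual pymc implementation, which works symbolically and also checks consistency between sources):

    def infer_support_shape(support_shape, shape, ndim_supp):
        # Explicit support_shape wins; otherwise fall back to the trailing
        # ndim_supp entries of shape; None if neither source is available.
        if ndim_supp < 1:
            raise NotImplementedError("ndim_supp must be bigger than 0")
        if support_shape is not None:
            return tuple(support_shape)
        if shape is None:
            return None
        return tuple(shape[len(shape) - ndim_supp:])

    print(infer_support_shape(None, (5, 3, 4), ndim_supp=2))  # (3, 4)
    print(infer_support_shape((3, 4), None, ndim_supp=2))     # (3, 4)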
