Handles axis=None symbolically instead of within CumOp #1574

Status: Open · wants to merge 7 commits into base: main
Changes from 3 commits
pytensor/link/numba/dispatch/extra_ops.py (23 changes: 11 additions & 12 deletions)

@@ -37,21 +37,20 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs):
     mode = op.mode
     ndim = cast(TensorVariable, node.outputs[0]).ndim

-    if axis is not None:
-        if axis < 0:
-            axis = ndim + axis
-        if axis < 0 or axis >= ndim:
-            raise ValueError(f"Invalid axis {axis} for array with ndim {ndim}")
+    if axis < 0:
+        axis = ndim + axis
+    if axis < 0 or axis >= ndim:
+        raise ValueError(f"Invalid axis {axis} for array with ndim {ndim}")

-        reaxis_first = (axis, *(i for i in range(ndim) if i != axis))
-        reaxis_first_inv = tuple(np.argsort(reaxis_first))
+    reaxis_first = (axis, *(i for i in range(ndim) if i != axis))
+    reaxis_first_inv = tuple(np.argsort(reaxis_first))

     if mode == "add":
-        if axis is None or ndim == 1:
+        if ndim == 1:

             @numba_basic.numba_njit
             def cumop(x):
-                return np.cumsum(x)
+                return np.cumsum(x, axis=axis)

         else:

@@ -71,11 +70,11 @@ def cumop(x):
                 return res.transpose(reaxis_first_inv)

     else:
-        if axis is None or ndim == 1:
+        if ndim == 1:

             @numba_basic.numba_njit
             def cumop(x):
-                return np.cumprod(x)
+                return np.cumprod(x, axis=axis)

         else:

@@ -92,7 +91,7 @@ def cumop(x):
                 for m in range(1, x.shape[axis]):
                     res[m] = res[m - 1] * x_axis_first[m]

-                return res.transpose(reaxis_first)
+                return res.transpose(reaxis_first_inv)

     return cumop
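The last hunk above is a genuine bug fix: after the cumulative product is computed with `axis` moved to the front, the result has to be transposed back with the *inverse* permutation, not the forward one. A standalone NumPy sketch (not part of the PR) showing why the two differ:

```python
import numpy as np

# Standalone check (not PR code): moving `axis` to the front, accumulating
# along axis 0, and undoing the move requires the inverse permutation.
x = np.arange(1, 25, dtype=float).reshape(2, 3, 4)
axis = 2  # chosen so the permutation is not its own inverse
ndim = x.ndim

reaxis_first = (axis, *(i for i in range(ndim) if i != axis))  # (2, 0, 1)
reaxis_first_inv = tuple(np.argsort(reaxis_first))             # (1, 2, 0)

res = np.cumprod(x.transpose(reaxis_first), axis=0)

# The old `transpose(reaxis_first)` even returns the wrong shape here:
assert res.transpose(reaxis_first).shape == (3, 4, 2)
np.testing.assert_array_equal(
    res.transpose(reaxis_first_inv), np.cumprod(x, axis=axis)
)
```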

pytensor/link/pytorch/dispatch/extra_ops.py (9 changes: 2 additions & 7 deletions)

@@ -10,15 +10,10 @@ def pytorch_funcify_Cumop(op, **kwargs):
     mode = op.mode

     def cumop(x):
-        if axis is None:
-            x = x.reshape(-1)
-            dim = 0
-        else:
-            dim = axis
         if mode == "add":
-            return torch.cumsum(x, dim=dim)
+            return torch.cumsum(x, dim=axis)
         else:
-            return torch.cumprod(x, dim=dim)
+            return torch.cumprod(x, dim=axis)

     return cumop
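For context, the deleted branch was only re-deriving what the graph now supplies. A small PyTorch sketch (not PR code) of the equivalence the removal relies on:

```python
import torch

# Standalone check (not PR code): the deleted reshape(-1)/dim=0 branch is
# equivalent to cumsum over an input that was already raveled upstream,
# which is exactly what the graph now provides for axis=None.
x = torch.arange(1.0, 7.0).reshape(2, 3)
old_path = torch.cumsum(x.reshape(-1), dim=0)  # what the branch used to do
new_path = torch.cumsum(x.ravel(), dim=0)      # ravel now happens in the graph
assert torch.equal(old_path, new_path)
```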

pytensor/tensor/extra_ops.py (63 changes: 22 additions & 41 deletions)

@@ -28,7 +28,7 @@
 from pytensor.scalar import upcast
 from pytensor.tensor import TensorLike, as_tensor_variable
 from pytensor.tensor import basic as ptb
-from pytensor.tensor.basic import alloc, join, second
+from pytensor.tensor.basic import alloc, join, second, flatten
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import abs as pt_abs
 from pytensor.tensor.math import all as pt_all
@@ -297,27 +297,25 @@ class CumOp(COp):
         c_axis=int_t, mode=EnumList(("MODE_ADD", "add"), ("MODE_MUL", "mul"))
     )

-    def __init__(self, axis: int | None = None, mode="add"):
+    def __init__(self, axis: int, mode="add"):
         if mode not in ("add", "mul"):
             raise ValueError(f'{type(self).__name__}: Unknown mode "{mode}"')
-        if not (isinstance(axis, int) or axis is None):
-            raise TypeError("axis must be an integer or None.")
+        if not isinstance(axis, int):
+            raise TypeError("axis must be an integer.")
+        if axis < 0:
+            raise ValueError("axis must be non-negative.")
         self.axis = axis
         self.mode = mode

     @property
     def c_axis(self) -> int:
-        if self.axis is None:
-            return numpy_axis_is_none_flag
         return self.axis

     def make_node(self, x):
         x = ptb.as_tensor_variable(x)
         out_type = x.type()

-        if self.axis is None:
-            out_type = vector(dtype=x.dtype)  # Flatten
-        elif self.axis >= x.ndim or self.axis < -x.ndim:
+        if self.axis >= x.ndim:
             raise ValueError(f"axis(={self.axis}) out of bounds")

         return Apply(self, [x], [out_type])
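With the constructor tightened, invalid axes now fail fast at Op construction. A short sketch of the new contract exactly as the diff defines it (the `try`/`except` framing is illustrative):

```python
from pytensor.tensor.extra_ops import CumOp

# Sketch of the tightened contract: axis is now a required int, already
# normalized to be non-negative by the cumsum/cumprod helpers.
CumOp(axis=1, mode="add")        # fine

try:
    CumOp(axis=-1, mode="add")   # negative axes are rejected outright
except ValueError as e:
    print(e)                     # axis must be non-negative.

try:
    CumOp(axis=None, mode="mul") # axis=None is no longer an Op-level concept
except TypeError as e:
    print(e)                     # axis must be an integer.
```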
@@ -334,17 +332,6 @@ def grad(self, inputs, output_gradients):
         (x,) = inputs
         (gi,) = output_gradients

-        if self.axis is None:
-            if self.mode == "add":
-                return [cumsum(gi[::-1])[::-1].reshape(x.shape)]
-            elif self.mode == "mul":
-                fx = cumprod(x, axis=self.axis)
-                return [cumsum((fx * gi)[::-1])[::-1].reshape(x.shape) / x]
-            else:
-                raise NotImplementedError(
-                    f'{type(self).__name__}: unknown gradient for mode "{self.mode}"'
-                )
-
         reverse_slicing = [slice(None, None, None)] * gi.ndim
         reverse_slicing[self.axis] = slice(None, None, -1)
         reverse_slicing = tuple(reverse_slicing)
@@ -361,9 +348,6 @@ def grad(self, inputs, output_gradients):
         )

     def infer_shape(self, fgraph, node, shapes):
-        if self.axis is None and len(shapes[0]) > 1:
-            return [(prod(shapes[0]),)]  # Flatten
-
         return shapes

     def c_support_code_apply(self, node: Apply, name: str) -> str:
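Nothing is lost by deleting the axis=None branches here: since `cumsum(x)` is now `CumOp(axis=0)` applied to a symbolic ravel, autodiff recovers the same gradient by differentiating through the reshape. A NumPy sketch (not PR code) of the reverse-cumsum identity the removed branch hard-coded:

```python
import numpy as np

# NumPy sketch (not PR code): for y = cumsum(x, axis=None), the VJP is
# "reverse, cumsum, reverse" on the flat cotangent, then a reshape --
# exactly what the removed branch computed, and what autodiff now derives
# from the explicit ravel in the graph.
rng = np.random.default_rng(0)
x = rng.normal(size=(3, 5))
g = rng.normal(size=x.size)  # cotangent of the flat output

vjp = np.cumsum(g[::-1])[::-1].reshape(x.shape)

# Finite-difference spot check on one coordinate (cumsum is linear,
# so central differences are essentially exact).
eps = 1e-6
e = np.zeros_like(x)
e[1, 2] = eps
fd = (np.cumsum(x + e).dot(g) - np.cumsum(x - e).dot(g)) / (2 * eps)
assert np.isclose(fd, vjp[1, 2], atol=1e-5)
```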
@@ -376,10 +360,7 @@ def c_code(self, node, name, inames, onames, sub):
         fail = sub["fail"]
         params = sub["params"]

-        if self.axis is None:
-            axis_code = "int axis = NPY_RAVEL_AXIS;\n"
-        else:
-            axis_code = f"int axis = {params}->c_axis;\n"
+        axis_code = f"int axis = {params}->c_axis;\n"

         code = (
             axis_code
@@ -451,7 +432,12 @@ def cumsum(x, axis=None):
     .. versionadded:: 0.7

     """
-    return CumOp(axis=axis, mode="add")(x)
+    if axis is None:
+        return CumOp(axis=0, mode="add")(ptb.as_tensor_variable(x).ravel())
+    else:
+        x = ptb.as_tensor_variable(x)
+        axis = normalize_axis_index(axis, x.ndim)
+        return CumOp(axis=axis, mode="add")(x)


 def cumprod(x, axis=None):
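The axis=None ergonomics now live in these helpers: they ravel the input symbolically, fold negative axes in via `normalize_axis_index` (so axis=-1 on a 3-d input becomes 2), and always hand `CumOp` a concrete axis. A minimal usage sketch, assuming a float64 graph and this PR's behaviour:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

# Minimal sketch: axis=None never reaches CumOp anymore; the helper
# ravels the input and applies CumOp(axis=0) to the resulting vector.
x = pt.dmatrix("x")
flat_cumsum = pt.cumsum(x)             # axis=None
assert flat_cumsum.owner.op.axis == 0  # concrete axis on a raveled input

f = pytensor.function([x], flat_cumsum)
a = np.arange(6.0).reshape(2, 3)
np.testing.assert_allclose(f(a), np.cumsum(a))  # 1-d result of length 6
```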
@@ -471,26 +457,21 @@
     .. versionadded:: 0.7

     """
-    return CumOp(axis=axis, mode="mul")(x)
+    if axis is None:
+        return CumOp(axis=0, mode="mul")(ptb.as_tensor_variable(x).ravel())
+    else:
+        x = ptb.as_tensor_variable(x)
+        axis = normalize_axis_index(axis, x.ndim)
+        return CumOp(axis=axis, mode="mul")(x)


 @_vectorize_node.register(CumOp)
 def vectorize_cum_op(op: CumOp, node: Apply, batch_x):
     """Vectorize the CumOp to work on a batch of inputs."""
     [original_x] = node.inputs
     batch_ndim = batch_x.ndim - original_x.ndim
-    axis = op.axis
-    if axis is None and original_x.ndim == 1:
-        axis = 0
-    elif axis is not None:
-        axis = normalize_axis_index(op.axis, original_x.ndim)
-
-    if axis is None:
-        # Ravel all unbatched dimensions and perform CumOp on the last axis
-        batch_x_raveled = [batch_x.flatten(ndim=batch_ndim + 1) for x in batch_x]
-        return type(op)(axis=-1, mode=op.mode).make_node(batch_x_raveled)
-    else:
-        return type(op)(axis=axis + batch_ndim, mode=op.mode).make_node(batch_x)
+    # op.axis is already normalized and non-negative
+    return type(op)(axis=op.axis + batch_ndim, mode=op.mode).make_node(batch_x)
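The simplification is sound because a batched `CumOp(axis=a)` is just one cumulative op with the axis shifted past the batch dimensions. A NumPy analogy (not PR code):

```python
import numpy as np

# NumPy analogy (not PR code) for the axis shift: batching a core
# CumOp(axis=a) over leading batch dimensions is the same as one
# cumulative op with axis = a + batch_ndim on the stacked input.
batch = np.arange(24.0).reshape(4, 2, 3)  # batch_ndim = 1, core ndim = 2
a = 1                                     # core (already normalized) axis

looped = np.stack([np.cumsum(m, axis=a) for m in batch])
shifted = np.cumsum(batch, axis=a + 1)    # a + batch_ndim
np.testing.assert_array_equal(looped, shifted)
```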


 def diff(x, n=1, axis=-1):
tests/tensor/test_extra_ops.py (11 changes: 6 additions & 5 deletions)

@@ -194,7 +194,7 @@ class TestCumOp(utt.InferShapeTester):
     def setup_method(self):
         super().setup_method()
         self.op_class = CumOp
-        self.op = CumOp()
+        self.op = CumOp(axis=0)  # Use a specific axis since None is no longer supported

     def test_cum_op(self):
         x = tensor3("x")
@@ -225,17 +225,18 @@ def test_infer_shape(self):
         x = tensor3("x")
         a = np.random.random((3, 5, 2)).astype(config.floatX)

-        # Test axis=None
-        self._compile_and_check([x], [self.op(x)], [a], self.op_class)
+        # Test axis=None using cumsum function (which now handles it symbolically)
+        self._compile_and_check([x], [cumsum(x)], [a], self.op_class)

         for axis in range(-len(a.shape), len(a.shape)):
             self._compile_and_check([x], [cumsum(x, axis=axis)], [a], self.op_class)

     def test_grad(self):
         a = np.random.random((3, 5, 2)).astype(config.floatX)

-        utt.verify_grad(self.op_class(mode="add"), [a])  # Test axis=None
-        utt.verify_grad(self.op_class(mode="mul"), [a])  # Test axis=None
+        # Test axis=None using cumsum/cumprod functions (which now handle it symbolically)
+        utt.verify_grad(lambda x: cumsum(x), [a])  # Test axis=None for cumsum
+        utt.verify_grad(lambda x: cumprod(x), [a])  # Test axis=None for cumprod

         for axis in range(-len(a.shape), len(a.shape)):
             utt.verify_grad(self.op_class(axis=axis, mode="add"), [a], eps=4e-4)