Handles axis=None symbolically instead of within CumOp #1574

Open · wants to merge 7 commits into base: main
20 changes: 7 additions & 13 deletions pytensor/link/numba/dispatch/extra_ops.py

@@ -37,21 +37,15 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs):
     mode = op.mode
     ndim = cast(TensorVariable, node.outputs[0]).ndim

-    if axis is not None:
-        if axis < 0:
-            axis = ndim + axis
-        if axis < 0 or axis >= ndim:
-            raise ValueError(f"Invalid axis {axis} for array with ndim {ndim}")
-
-        reaxis_first = (axis, *(i for i in range(ndim) if i != axis))
-        reaxis_first_inv = tuple(np.argsort(reaxis_first))
+    reaxis_first = (axis, *(i for i in range(ndim) if i != axis))
+    reaxis_first_inv = tuple(np.argsort(reaxis_first))

     if mode == "add":
-        if axis is None or ndim == 1:
+        if ndim == 1:

             @numba_basic.numba_njit
             def cumop(x):
-                return np.cumsum(x)
+                return np.cumsum(x, axis=axis)

         else:
@@ -71,11 +65,11 @@ def cumop(x):
                 return res.transpose(reaxis_first_inv)

     else:
-        if axis is None or ndim == 1:
+        if ndim == 1:

             @numba_basic.numba_njit
             def cumop(x):
-                return np.cumprod(x)
+                return np.cumprod(x, axis=axis)

         else:
@@ -92,7 +86,7 @@ def cumop(x):
                 for m in range(1, x.shape[axis]):
                     res[m] = res[m - 1] * x_axis_first[m]

-                return res.transpose(reaxis_first)
+                return res.transpose(reaxis_first_inv)

     return cumop
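Reviewer note: the multi-dimensional branch (elided above) uses a move-axis-first pattern, and the last hunk fixes a genuine bug in it: the accumulated result must be permuted back with the inverse permutation (`reaxis_first_inv`), not the forward one. A minimal pure-NumPy sketch of that pattern; the helper name `cumsum_axis_first` is made up for illustration:

```python
import numpy as np

def cumsum_axis_first(x, axis):
    # Bring `axis` to the front, accumulate along axis 0, then undo the
    # permutation; np.argsort of the forward permutation is its inverse.
    ndim = x.ndim
    reaxis_first = (axis, *(i for i in range(ndim) if i != axis))
    reaxis_first_inv = tuple(np.argsort(reaxis_first))
    x_axis_first = x.transpose(reaxis_first)
    res = np.empty_like(x_axis_first)
    res[0] = x_axis_first[0]
    for m in range(1, x.shape[axis]):
        res[m] = res[m - 1] + x_axis_first[m]
    return res.transpose(reaxis_first_inv)

x = np.arange(24.0).reshape(2, 3, 4)
for axis in range(x.ndim):
    assert np.array_equal(cumsum_axis_first(x, axis), np.cumsum(x, axis=axis))
```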
9 changes: 2 additions & 7 deletions pytensor/link/pytorch/dispatch/extra_ops.py

@@ -10,15 +10,10 @@ def pytorch_funcify_Cumop(op, **kwargs):
     mode = op.mode

     def cumop(x):
-        if axis is None:
-            x = x.reshape(-1)
-            dim = 0
-        else:
-            dim = axis
         if mode == "add":
-            return torch.cumsum(x, dim=dim)
+            return torch.cumsum(x, dim=axis)
         else:
-            return torch.cumprod(x, dim=dim)
+            return torch.cumprod(x, dim=axis)

     return cumop
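With the flatten/`dim` bookkeeping hoisted out of the backend, the closure passes `axis` straight through. A hedged illustration in plain PyTorch (outside PyTensor) of what the generated closure now computes:

```python
import torch

x = torch.arange(6.0).reshape(2, 3)   # [[0., 1., 2.], [3., 4., 5.]]
torch.cumsum(x, dim=1)                # [[0., 1., 3.], [3., 7., 12.]]
torch.cumprod(x, dim=0)               # [[0., 1., 2.], [0., 4., 10.]]
```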
138 changes: 51 additions & 87 deletions pytensor/tensor/extra_ops.py

@@ -1,5 +1,6 @@
 import warnings
 from collections.abc import Collection, Iterable
+from textwrap import dedent

 import numpy as np
@@ -20,7 +21,6 @@
 from pytensor.npy_2_compat import (
     normalize_axis_index,
     npy_2_compat_header,
-    numpy_axis_is_none_flag,
     old_np_unique,
 )
 from pytensor.raise_op import Assert
@@ -48,7 +48,7 @@
 from pytensor.tensor.math import sum as pt_sum
 from pytensor.tensor.shape import Shape_i
 from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
-from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector
+from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes
 from pytensor.tensor.utils import normalize_reduce_axis
 from pytensor.tensor.variable import TensorVariable
 from pytensor.utils import LOCAL_BITWIDTH, PYTHON_INT_BITWIDTH
@@ -294,30 +294,24 @@ class CumOp(COp):
     __props__ = ("axis", "mode")
     check_input = False
     params_type = ParamsType(
-        c_axis=int_t, mode=EnumList(("MODE_ADD", "add"), ("MODE_MUL", "mul"))
+        axis=int_t, mode=EnumList(("MODE_ADD", "add"), ("MODE_MUL", "mul"))
     )

-    def __init__(self, axis: int | None = None, mode="add"):
+    def __init__(self, axis: int, mode="add"):
         if mode not in ("add", "mul"):
             raise ValueError(f'{type(self).__name__}: Unknown mode "{mode}"')
-        if not (isinstance(axis, int) or axis is None):
-            raise TypeError("axis must be an integer or None.")
+        if not isinstance(axis, int):
+            raise TypeError("axis must be an integer.")
+        if axis < 0:
+            raise ValueError("axis must be non-negative.")
         self.axis = axis
         self.mode = mode

-    @property
-    def c_axis(self) -> int:
-        if self.axis is None:
-            return numpy_axis_is_none_flag
-        return self.axis
-
     def make_node(self, x):
         x = ptb.as_tensor_variable(x)
         out_type = x.type()

-        if self.axis is None:
-            out_type = vector(dtype=x.dtype)  # Flatten
-        elif self.axis >= x.ndim or self.axis < -x.ndim:
+        if self.axis >= x.ndim:
             raise ValueError(f"axis(={self.axis}) out of bounds")

         return Apply(self, [x], [out_type])
@@ -330,21 +324,10 @@ def perform(self, node, inputs, output_storage):
         else:
             z[0] = np.cumprod(x, axis=self.axis)

-    def grad(self, inputs, output_gradients):
+    def L_op(self, inputs, outputs, output_gradients):
         (x,) = inputs
         (gi,) = output_gradients

-        if self.axis is None:
-            if self.mode == "add":
-                return [cumsum(gi[::-1])[::-1].reshape(x.shape)]
-            elif self.mode == "mul":
-                fx = cumprod(x, axis=self.axis)
-                return [cumsum((fx * gi)[::-1])[::-1].reshape(x.shape) / x]
-            else:
-                raise NotImplementedError(
-                    f'{type(self).__name__}: unknown gradient for mode "{self.mode}"'
-                )
-
         reverse_slicing = [slice(None, None, None)] * gi.ndim
         reverse_slicing[self.axis] = slice(None, None, -1)
         reverse_slicing = tuple(reverse_slicing)
@@ -361,9 +344,6 @@ def grad(self, inputs, output_gradients):
         )

     def infer_shape(self, fgraph, node, shapes):
-        if self.axis is None and len(shapes[0]) > 1:
-            return [(prod(shapes[0]),)]  # Flatten
-
         return shapes

     def c_support_code_apply(self, node: Apply, name: str) -> str:
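Reviewer note on the surviving gradient path: since d cumsum(x)_i / d x_j is 1 exactly when i >= j along the axis, the vector-Jacobian product is the output gradient cumulatively summed in reverse along that same axis, which is what the `reverse_slicing` code above computes for mode="add". A hedged finite-difference check of that identity in plain NumPy:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(3, 4))
g = rng.normal(size=(3, 4))
axis = 1

# Analytic VJP: cumsum of g taken in reverse along `axis`.
reverse = [slice(None)] * g.ndim
reverse[axis] = slice(None, None, -1)
reverse = tuple(reverse)
analytic = np.cumsum(g[reverse], axis=axis)[reverse]

# Central finite differences of f(x) = sum(g * cumsum(x, axis)).
eps = 1e-6
numeric = np.zeros_like(x)
for idx in np.ndindex(x.shape):
    dx = np.zeros_like(x)
    dx[idx] = eps
    numeric[idx] = (np.sum(g * np.cumsum(x + dx, axis=axis))
                    - np.sum(g * np.cumsum(x - dx, axis=axis))) / (2 * eps)

assert np.allclose(analytic, numeric)
```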
@@ -376,61 +356,43 @@ def c_code(self, node, name, inames, onames, sub):
         fail = sub["fail"]
         params = sub["params"]

-        if self.axis is None:
-            axis_code = "int axis = NPY_RAVEL_AXIS;\n"
-        else:
-            axis_code = f"int axis = {params}->c_axis;\n"
-
-        code = (
-            axis_code
-            + f"""
-                #undef NPY_UF_DBG_TRACING
-                #define NPY_UF_DBG_TRACING 1
-
-                if (axis == 0 && PyArray_NDIM({x}) == 1)
-                    axis = NPY_RAVEL_AXIS;
-                npy_intp shape[1] = {{ PyArray_SIZE({x}) }};
-                if(axis == NPY_RAVEL_AXIS && !({z} && PyArray_DIMS({z})[0] == shape[0]))
-                {{
-                    Py_XDECREF({z});
-                    {z} = (PyArrayObject*) PyArray_SimpleNew(1, shape, PyArray_TYPE({x}));
-                }}
-
-                else if(axis != NPY_RAVEL_AXIS && !({z} && PyArray_CompareLists(PyArray_DIMS({z}), PyArray_DIMS({x}), PyArray_NDIM({x}))))
-                {{
-                    Py_XDECREF({z});
-                    {z} = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM({x}), PyArray_DIMS({x}), PyArray_TYPE({x}));
-                }}
-
-                if (!{z})
-                    {fail};
-                {{
-
-                    PyObject * t = NULL;
-                    if({params}->mode == MODE_ADD)
-                        t = PyArray_CumSum(
-                            {x}, axis,
-                            PyArray_TYPE({x}), {z});
-                    else if({params}->mode == MODE_MUL)
-                        t = PyArray_CumProd(
-                            {x}, axis,
-                            PyArray_TYPE({x}), {z});
-
-                    if (!t){{
-                        {fail};
-                    }}
-                    // Because PyArray_CumSum/CumProd returns a newly created reference on t.
-                    Py_XDECREF(t);
-                }}
-            """
-        )
-
-        return code
+        return dedent(
+            f"""
+            int axis = {params}->axis;
+
+            if (!({z} && PyArray_CompareLists(PyArray_DIMS({z}), PyArray_DIMS({x}), PyArray_NDIM({x}))))
+            {{
+                Py_XDECREF({z});
+                {z} = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM({x}), PyArray_DIMS({x}), PyArray_TYPE({x}));
+                if (!{z}){{ {fail} }};
+            }}
+
+            {{
+                PyObject * t = NULL;
+                if({params}->mode == MODE_ADD)
+                    t = PyArray_CumSum({x}, axis, PyArray_TYPE({x}), {z});
+                else if({params}->mode == MODE_MUL)
+                    t = PyArray_CumProd({x}, axis, PyArray_TYPE({x}), {z});
+
+                if (!t){{
+                    {fail};
+                }}
+                // Because PyArray_CumSum/CumProd returns a newly created reference on t.
+                Py_XDECREF(t);
+            }}
+            """
+        )

     def c_code_cache_version(self):
-        return (9,)
+        return (10,)

     def __str__(self):
-        if self.mode == "add":
-            return f"Cumsum{{axis={self.axis}}}"
-        elif self.mode == "mul":
-            return f"Cumprod{{axis={self.axis}}}"
+        return f"{self.__class__.__name__}{{{self.axis}, {self.mode}}}"


@@ -451,6 +413,12 @@ def cumsum(x, axis=None):
     .. versionadded:: 0.7

     """
+    x = ptb.as_tensor_variable(x)
+    if axis is None:
+        x = x.ravel()
+        axis = 0
+    else:
+        axis = normalize_axis_index(axis, x.ndim)
     return CumOp(axis=axis, mode="add")(x)
@@ -471,6 +439,12 @@ def cumprod(x, axis=None):
     .. versionadded:: 0.7

     """
+    x = ptb.as_tensor_variable(x)
+    if axis is None:
+        x = x.ravel()
+        axis = 0
+    else:
+        axis = normalize_axis_index(axis, x.ndim)
     return CumOp(axis=axis, mode="mul")(x)


@@ -479,18 +453,8 @@ def vectorize_cum_op(op: CumOp, node: Apply, batch_x):
     """Vectorize the CumOp to work on a batch of inputs."""
     [original_x] = node.inputs
     batch_ndim = batch_x.ndim - original_x.ndim
-    axis = op.axis
-    if axis is None and original_x.ndim == 1:
-        axis = 0
-    elif axis is not None:
-        axis = normalize_axis_index(op.axis, original_x.ndim)
-
-    if axis is None:
-        # Ravel all unbatched dimensions and perform CumOp on the last axis
-        batch_x_raveled = [batch_x.flatten(ndim=batch_ndim + 1) for x in batch_x]
-        return type(op)(axis=-1, mode=op.mode).make_node(batch_x_raveled)
-    else:
-        return type(op)(axis=axis + batch_ndim, mode=op.mode).make_node(batch_x)
+    # op.axis is already normalized and non-negative
+    return type(op)(axis=op.axis + batch_ndim, mode=op.mode).make_node(batch_x)


 def diff(x, n=1, axis=-1):
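Because `op.axis` is guaranteed normalized and non-negative by construction, vectorization reduces to shifting the axis by the number of batch dimensions. A hedged NumPy illustration of why that shift is the whole story:

```python
import numpy as np

x = np.arange(24.0).reshape(2, 3, 4)  # batch of two (3, 4) core inputs
core_axis = 0                          # CumOp axis on the (3, 4) core
batch_ndim = 1

# Cumulating the batched input on core_axis + batch_ndim matches applying
# the core Op to each batch element separately.
expected = np.stack([np.cumsum(xi, axis=core_axis) for xi in x])
assert np.array_equal(np.cumsum(x, axis=core_axis + batch_ndim), expected)
```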
11 changes: 6 additions & 5 deletions tests/tensor/test_extra_ops.py

@@ -194,7 +194,7 @@ class TestCumOp(utt.InferShapeTester):
     def setup_method(self):
         super().setup_method()
         self.op_class = CumOp
-        self.op = CumOp()
+        self.op = CumOp(axis=0)

     def test_cum_op(self):
         x = tensor3("x")
@@ -225,17 +225,18 @@ def test_infer_shape(self):
         x = tensor3("x")
        a = np.random.random((3, 5, 2)).astype(config.floatX)

-        # Test axis=None
-        self._compile_and_check([x], [self.op(x)], [a], self.op_class)
+        # Test default axis=None
+        self._compile_and_check([x], [cumsum(x)], [a], self.op_class)

         for axis in range(-len(a.shape), len(a.shape)):
             self._compile_and_check([x], [cumsum(x, axis=axis)], [a], self.op_class)

     def test_grad(self):
         a = np.random.random((3, 5, 2)).astype(config.floatX)

-        utt.verify_grad(self.op_class(mode="add"), [a])  # Test axis=None
-        utt.verify_grad(self.op_class(mode="mul"), [a])  # Test axis=None
+        # Test default axis=None using cumsum/cumprod functions
+        utt.verify_grad(lambda x: cumsum(x), [a])  # Test axis=None for cumsum
+        utt.verify_grad(lambda x: cumprod(x), [a])  # Test axis=None for cumprod

         for axis in range(-len(a.shape), len(a.shape)):
             utt.verify_grad(self.op_class(axis=axis, mode="add"), [a], eps=4e-4)
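Since `CumOp` no longer accepts `axis=None`, the default-axis tests now go through the `cumsum`/`cumprod` helpers. A quick NumPy sanity check of the shapes those tests expect:

```python
import numpy as np

a = np.random.random((3, 5, 2))
assert np.cumsum(a).shape == (3 * 5 * 2,)     # axis=None flattens
assert np.cumsum(a, axis=1).shape == a.shape  # integer axis keeps the shape
```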