diff --git a/pytensor/link/numba/dispatch/extra_ops.py b/pytensor/link/numba/dispatch/extra_ops.py
index f7700acf47..318683ddab 100644
--- a/pytensor/link/numba/dispatch/extra_ops.py
+++ b/pytensor/link/numba/dispatch/extra_ops.py
@@ -37,21 +37,15 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs):
     mode = op.mode
     ndim = cast(TensorVariable, node.outputs[0]).ndim
 
-    if axis is not None:
-        if axis < 0:
-            axis = ndim + axis
-        if axis < 0 or axis >= ndim:
-            raise ValueError(f"Invalid axis {axis} for array with ndim {ndim}")
-
-        reaxis_first = (axis, *(i for i in range(ndim) if i != axis))
-        reaxis_first_inv = tuple(np.argsort(reaxis_first))
+    reaxis_first = (axis, *(i for i in range(ndim) if i != axis))
+    reaxis_first_inv = tuple(np.argsort(reaxis_first))
 
     if mode == "add":
-        if axis is None or ndim == 1:
+        if ndim == 1:
 
             @numba_basic.numba_njit
             def cumop(x):
-                return np.cumsum(x)
+                return np.cumsum(x, axis=axis)
 
         else:
 
@@ -71,11 +65,11 @@ def cumop(x):
             return res.transpose(reaxis_first_inv)
 
     else:
-        if axis is None or ndim == 1:
+        if ndim == 1:
 
             @numba_basic.numba_njit
             def cumop(x):
-                return np.cumprod(x)
+                return np.cumprod(x, axis=axis)
 
         else:
 
@@ -92,7 +86,7 @@ def cumop(x):
                 for m in range(1, x.shape[axis]):
                     res[m] = res[m - 1] * x_axis_first[m]
 
-                return res.transpose(reaxis_first)
+                return res.transpose(reaxis_first_inv)
 
     return cumop
 
diff --git a/pytensor/link/pytorch/dispatch/extra_ops.py b/pytensor/link/pytorch/dispatch/extra_ops.py
index 74284d651d..912083f9e3 100644
--- a/pytensor/link/pytorch/dispatch/extra_ops.py
+++ b/pytensor/link/pytorch/dispatch/extra_ops.py
@@ -10,15 +10,10 @@ def pytorch_funcify_Cumop(op, **kwargs):
     mode = op.mode
 
     def cumop(x):
-        if axis is None:
-            x = x.reshape(-1)
-            dim = 0
-        else:
-            dim = axis
         if mode == "add":
-            return torch.cumsum(x, dim=dim)
+            return torch.cumsum(x, dim=axis)
         else:
-            return torch.cumprod(x, dim=dim)
+            return torch.cumprod(x, dim=axis)
 
     return cumop
 
diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py
index a6eafcf485..4bed73e25a 100644
--- a/pytensor/tensor/extra_ops.py
+++ b/pytensor/tensor/extra_ops.py
@@ -1,5 +1,6 @@
 import warnings
 from collections.abc import Collection, Iterable
+from textwrap import dedent
 
 import numpy as np
 
@@ -20,7 +21,6 @@
 from pytensor.npy_2_compat import (
     normalize_axis_index,
     npy_2_compat_header,
-    numpy_axis_is_none_flag,
     old_np_unique,
 )
 from pytensor.raise_op import Assert
@@ -48,7 +48,7 @@
 from pytensor.tensor.math import sum as pt_sum
 from pytensor.tensor.shape import Shape_i
 from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
-from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector
+from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes
 from pytensor.tensor.utils import normalize_reduce_axis
 from pytensor.tensor.variable import TensorVariable
 from pytensor.utils import LOCAL_BITWIDTH, PYTHON_INT_BITWIDTH
@@ -294,30 +294,24 @@ class CumOp(COp):
     __props__ = ("axis", "mode")
     check_input = False
     params_type = ParamsType(
-        c_axis=int_t, mode=EnumList(("MODE_ADD", "add"), ("MODE_MUL", "mul"))
+        axis=int_t, mode=EnumList(("MODE_ADD", "add"), ("MODE_MUL", "mul"))
     )
 
-    def __init__(self, axis: int | None = None, mode="add"):
+    def __init__(self, axis: int, mode="add"):
         if mode not in ("add", "mul"):
             raise ValueError(f'{type(self).__name__}: Unknown mode "{mode}"')
-        if not (isinstance(axis, int) or axis is None):
-            raise TypeError("axis must be an integer or None.")
+        if not isinstance(axis, int):
+            raise TypeError("axis must be an integer.")
+        if axis < 0:
+            raise ValueError("axis must be non-negative.")
         self.axis = axis
         self.mode = mode
 
-    @property
-    def c_axis(self) -> int:
-        if self.axis is None:
-            return numpy_axis_is_none_flag
-        return self.axis
-
     def make_node(self, x):
         x = ptb.as_tensor_variable(x)
         out_type = x.type()
 
-        if self.axis is None:
-            out_type = vector(dtype=x.dtype)  # Flatten
-        elif self.axis >= x.ndim or self.axis < -x.ndim:
+        if self.axis >= x.ndim:
             raise ValueError(f"axis(={self.axis}) out of bounds")
 
         return Apply(self, [x], [out_type])
@@ -330,21 +324,10 @@ def perform(self, node, inputs, output_storage):
         else:
             z[0] = np.cumprod(x, axis=self.axis)
 
-    def grad(self, inputs, output_gradients):
+    def L_op(self, inputs, outputs, output_gradients):
         (x,) = inputs
         (gi,) = output_gradients
 
-        if self.axis is None:
-            if self.mode == "add":
-                return [cumsum(gi[::-1])[::-1].reshape(x.shape)]
-            elif self.mode == "mul":
-                fx = cumprod(x, axis=self.axis)
-                return [cumsum((fx * gi)[::-1])[::-1].reshape(x.shape) / x]
-            else:
-                raise NotImplementedError(
-                    f'{type(self).__name__}: unknown gradient for mode "{self.mode}"'
-                )
-
         reverse_slicing = [slice(None, None, None)] * gi.ndim
         reverse_slicing[self.axis] = slice(None, None, -1)
         reverse_slicing = tuple(reverse_slicing)
@@ -361,9 +344,6 @@ def grad(self, inputs, output_gradients):
         )
 
     def infer_shape(self, fgraph, node, shapes):
-        if self.axis is None and len(shapes[0]) > 1:
-            return [(prod(shapes[0]),)]  # Flatten
-
         return shapes
 
     def c_support_code_apply(self, node: Apply, name: str) -> str:
@@ -376,61 +356,43 @@ def c_code(self, node, name, inames, onames, sub):
         fail = sub["fail"]
         params = sub["params"]
 
-        if self.axis is None:
-            axis_code = "int axis = NPY_RAVEL_AXIS;\n"
-        else:
-            axis_code = f"int axis = {params}->c_axis;\n"
-
-        code = (
-            axis_code
-            + f"""
-                #undef NPY_UF_DBG_TRACING
-                #define NPY_UF_DBG_TRACING 1
-
-                if (axis == 0 && PyArray_NDIM({x}) == 1)
-                    axis = NPY_RAVEL_AXIS;
-                npy_intp shape[1] = {{ PyArray_SIZE({x}) }};
-                if(axis == NPY_RAVEL_AXIS && !({z} && PyArray_DIMS({z})[0] == shape[0]))
-                {{
-                    Py_XDECREF({z});
-                    {z} = (PyArrayObject*) PyArray_SimpleNew(1, shape, PyArray_TYPE({x}));
-                }}
+        return dedent(
+            f"""
+            int axis = {params}->axis;
 
-                else if(axis != NPY_RAVEL_AXIS && !({z} && PyArray_CompareLists(PyArray_DIMS({z}), PyArray_DIMS({x}), PyArray_NDIM({x}))))
-                {{
-                    Py_XDECREF({z});
-                    {z} = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM({x}), PyArray_DIMS({x}), PyArray_TYPE({x}));
-                }}
+            if (!({z} && PyArray_CompareLists(PyArray_DIMS({z}), PyArray_DIMS({x}), PyArray_NDIM({x}))))
+            {{
+                Py_XDECREF({z});
+                {z} = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM({x}), PyArray_DIMS({x}), PyArray_TYPE({x}));
+                if (!{z}){{ {fail} }};
+            }}
+
+            {{
+
+                PyObject * t = NULL;
+                if({params}->mode == MODE_ADD)
+                    t = PyArray_CumSum({x}, axis, PyArray_TYPE({x}), {z});
+                else if({params}->mode == MODE_MUL)
+                    t = PyArray_CumProd({x}, axis, PyArray_TYPE({x}), {z});
 
-                if (!{z})
+                if (!t){{
                     {fail};
-                {{
-
-                    PyObject * t = NULL;
-                    if({params}->mode == MODE_ADD)
-                        t = PyArray_CumSum(
-                            {x}, axis,
-                            PyArray_TYPE({x}), {z});
-                    else if({params}->mode == MODE_MUL)
-                        t = PyArray_CumProd(
-                            {x}, axis,
-                            PyArray_TYPE({x}), {z});
-
-                    if (!t){{
-                        {fail};
-                    }}
-                    // Because PyArray_CumSum/CumProd returns a newly created reference on t.
-                    Py_XDECREF(t);
                 }}
+
+                // Because PyArray_CumSum/CumProd returns a newly created reference on t.
+                Py_XDECREF(t);
+            }}
             """
         )
 
-        return code
-
     def c_code_cache_version(self):
-        return (9,)
+        return (10,)
 
     def __str__(self):
+        if self.mode == "add":
+            return f"Cumsum{{axis={self.axis}}}"
+        elif self.mode == "mul":
+            return f"Cumprod{{axis={self.axis}}}"
         return f"{self.__class__.__name__}{{{self.axis}, {self.mode}}}"
 
 
@@ -451,6 +413,12 @@ def cumsum(x, axis=None):
     .. versionadded:: 0.7
 
     """
+    x = ptb.as_tensor_variable(x)
+    if axis is None:
+        x = x.ravel()
+        axis = 0
+    else:
+        axis = normalize_axis_index(axis, x.ndim)
     return CumOp(axis=axis, mode="add")(x)
 
 
@@ -471,6 +439,12 @@ def cumprod(x, axis=None):
     .. versionadded:: 0.7
 
     """
+    x = ptb.as_tensor_variable(x)
+    if axis is None:
+        x = x.ravel()
+        axis = 0
+    else:
+        axis = normalize_axis_index(axis, x.ndim)
     return CumOp(axis=axis, mode="mul")(x)
 
 
@@ -479,18 +453,8 @@ def vectorize_cum_op(op: CumOp, node: Apply, batch_x):
     """Vectorize the CumOp to work on a batch of inputs."""
     [original_x] = node.inputs
     batch_ndim = batch_x.ndim - original_x.ndim
-    axis = op.axis
-    if axis is None and original_x.ndim == 1:
-        axis = 0
-    elif axis is not None:
-        axis = normalize_axis_index(op.axis, original_x.ndim)
-
-    if axis is None:
-        # Ravel all unbatched dimensions and perform CumOp on the last axis
-        batch_x_raveled = [batch_x.flatten(ndim=batch_ndim + 1) for x in batch_x]
-        return type(op)(axis=-1, mode=op.mode).make_node(batch_x_raveled)
-    else:
-        return type(op)(axis=axis + batch_ndim, mode=op.mode).make_node(batch_x)
+    # op.axis is already normalized and non-negative
+    return type(op)(axis=op.axis + batch_ndim, mode=op.mode).make_node(batch_x)
 
 
 def diff(x, n=1, axis=-1):
diff --git a/tests/tensor/test_extra_ops.py b/tests/tensor/test_extra_ops.py
index 8274ddbcea..bd8eb1ce73 100644
--- a/tests/tensor/test_extra_ops.py
+++ b/tests/tensor/test_extra_ops.py
@@ -194,7 +194,7 @@ class TestCumOp(utt.InferShapeTester):
     def setup_method(self):
        super().setup_method()
         self.op_class = CumOp
-        self.op = CumOp()
+        self.op = CumOp(axis=0)
 
     def test_cum_op(self):
         x = tensor3("x")
@@ -225,8 +225,8 @@ def test_infer_shape(self):
         x = tensor3("x")
         a = np.random.random((3, 5, 2)).astype(config.floatX)
 
-        # Test axis=None
-        self._compile_and_check([x], [self.op(x)], [a], self.op_class)
+        # Test default axis=None
+        self._compile_and_check([x], [cumsum(x)], [a], self.op_class)
 
         for axis in range(-len(a.shape), len(a.shape)):
             self._compile_and_check([x], [cumsum(x, axis=axis)], [a], self.op_class)
@@ -234,8 +234,9 @@ def test_grad(self):
         a = np.random.random((3, 5, 2)).astype(config.floatX)
 
-        utt.verify_grad(self.op_class(mode="add"), [a])  # Test axis=None
-        utt.verify_grad(self.op_class(mode="mul"), [a])  # Test axis=None
+        # Test default axis=None using cumsum/cumprod functions
+        utt.verify_grad(lambda x: cumsum(x), [a])  # Test axis=None for cumsum
+        utt.verify_grad(lambda x: cumprod(x), [a])  # Test axis=None for cumprod
 
         for axis in range(-len(a.shape), len(a.shape)):
             utt.verify_grad(self.op_class(axis=axis, mode="add"), [a], eps=4e-4)
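
Reviewer note: a minimal sanity-check script for the contract this diff establishes, assuming a pytensor build with the patch applied. The names used (CumOp, cumsum, cumprod, tensor3, function) are all part of pytensor's public modules touched or exercised by this change; the checks mirror the new behavior: helpers resolve axis=None by ravelling and normalize negative axes, while CumOp itself now rejects None and negative values.

    import numpy as np

    from pytensor import function
    from pytensor.tensor.extra_ops import CumOp, cumprod, cumsum
    from pytensor.tensor.type import tensor3

    x = tensor3("x")
    a = np.arange(30, dtype="float64").reshape(3, 5, 2)

    # axis=None is now resolved in the helpers: the input is ravelled and
    # CumOp receives axis=0, so the op itself never sees axis=None.
    f = function([x], [cumsum(x), cumprod(x)])
    res_sum, res_prod = f(a)
    np.testing.assert_allclose(res_sum, np.cumsum(a, axis=None))
    np.testing.assert_allclose(res_prod, np.cumprod(a, axis=None))

    # Negative axes are normalized by the helpers (axis=-1 becomes axis=2 here),
    g = function([x], cumsum(x, axis=-1))
    np.testing.assert_allclose(g(a), np.cumsum(a, axis=-1))

    # while CumOp itself now rejects None and negative axes outright.
    for bad_axis, err in [(None, TypeError), (-1, ValueError)]:
        try:
            CumOp(axis=bad_axis, mode="add")
        except err:
            pass
        else:
            raise AssertionError(f"CumOp accepted axis={bad_axis}")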