
Handle special ravelling behavior of CumOp symbolically #1549

@ricardoV94

Description

When you perform cumsum on a matrix with axis=None, the input is first ravelled and then a regular 1D cumsum is performed.
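
For reference, this is the NumPy behaviour the Op mirrors (a small example, not from the issue):

import numpy as np

x = np.arange(6).reshape(2, 3)

np.cumsum(x)           # array([ 0,  1,  3,  6, 10, 15]) -- input ravelled first
np.cumsum(x.ravel())   # same result, with the ravel made explicit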

We could handle this ravelling symbolically, in the graph, and simplify the implementation of CumOp so it no longer has to worry about this case:

class CumOp(COp):
    # See function cumsum/cumprod for docstring
    __props__ = ("axis", "mode")
    check_input = False
    params_type = ParamsType(
        c_axis=int_t, mode=EnumList(("MODE_ADD", "add"), ("MODE_MUL", "mul"))
    )

    def __init__(self, axis: int | None = None, mode="add"):
        if mode not in ("add", "mul"):
            raise ValueError(f'{type(self).__name__}: Unknown mode "{mode}"')
        if not (isinstance(axis, int) or axis is None):
            raise TypeError("axis must be an integer or None.")
        self.axis = axis
        self.mode = mode

    @property
    def c_axis(self) -> int:
        if self.axis is None:
            return numpy_axis_is_none_flag
        return self.axis

    def make_node(self, x):
        x = ptb.as_tensor_variable(x)
        out_type = x.type()

        if self.axis is None:
            out_type = vector(dtype=x.dtype)  # Flatten
        elif self.axis >= x.ndim or self.axis < -x.ndim:
            raise ValueError(f"axis(={self.axis}) out of bounds")

        return Apply(self, [x], [out_type])

    def perform(self, node, inputs, output_storage):
        x = inputs[0]
        z = output_storage[0]
        if self.mode == "add":
            z[0] = np.cumsum(x, axis=self.axis)
        else:
            z[0] = np.cumprod(x, axis=self.axis)

    def grad(self, inputs, output_gradients):
        (x,) = inputs
        (gi,) = output_gradients

        if self.axis is None:
            if self.mode == "add":
                return [cumsum(gi[::-1])[::-1].reshape(x.shape)]
            elif self.mode == "mul":
                fx = cumprod(x, axis=self.axis)
                return [cumsum((fx * gi)[::-1])[::-1].reshape(x.shape) / x]
            else:
                raise NotImplementedError(
                    f'{type(self).__name__}: unknown gradient for mode "{self.mode}"'
                )

        reverse_slicing = [slice(None, None, None)] * gi.ndim
        reverse_slicing[self.axis] = slice(None, None, -1)
        reverse_slicing = tuple(reverse_slicing)
        # We need to reverse the gradients along ``self.axis``,
        # compute cumsum, then reverse again
        if self.mode == "add":
            return [cumsum(gi[reverse_slicing], self.axis)[reverse_slicing]]
        elif self.mode == "mul":
            fx = cumprod(x, axis=self.axis)
            return [cumsum((fx * gi)[reverse_slicing], self.axis)[reverse_slicing] / x]
        else:
            raise NotImplementedError(
                f'{type(self).__name__}: unknown gradient for mode "{self.mode}"'
            )

    def infer_shape(self, fgraph, node, shapes):
        if self.axis is None and len(shapes[0]) > 1:
            return [(prod(shapes[0]),)]  # Flatten
        return shapes

    def c_support_code_apply(self, node: Apply, name: str) -> str:
        """Needed to define NPY_RAVEL_AXIS"""
        return npy_2_compat_header()

    def c_code(self, node, name, inames, onames, sub):
        (x,) = inames
        (z,) = onames
        fail = sub["fail"]
        params = sub["params"]

        if self.axis is None:
            axis_code = "int axis = NPY_RAVEL_AXIS;\n"
        else:
            axis_code = f"int axis = {params}->c_axis;\n"

        code = (
            axis_code
            + f"""
            #undef NPY_UF_DBG_TRACING
            #define NPY_UF_DBG_TRACING 1

            if (axis == 0 && PyArray_NDIM({x}) == 1)
                axis = NPY_RAVEL_AXIS;
            npy_intp shape[1] = {{ PyArray_SIZE({x}) }};
            if(axis == NPY_RAVEL_AXIS && !({z} && PyArray_DIMS({z})[0] == shape[0]))
            {{
                Py_XDECREF({z});
                {z} = (PyArrayObject*) PyArray_SimpleNew(1, shape, PyArray_TYPE({x}));
            }}
            else if(axis != NPY_RAVEL_AXIS && !({z} && PyArray_CompareLists(PyArray_DIMS({z}), PyArray_DIMS({x}), PyArray_NDIM({x}))))
            {{
                Py_XDECREF({z});
                {z} = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM({x}), PyArray_DIMS({x}), PyArray_TYPE({x}));
            }}

            if (!{z})
                {fail};
            {{
                PyObject * t = NULL;
                if({params}->mode == MODE_ADD)
                    t = PyArray_CumSum(
                        {x}, axis,
                        PyArray_TYPE({x}), {z});
                else if({params}->mode == MODE_MUL)
                    t = PyArray_CumProd(
                        {x}, axis,
                        PyArray_TYPE({x}), {z});

                if (!t){{
                    {fail};
                }}
                // Because PyArray_CumSum/CumProd returns a newly created reference on t.
                Py_XDECREF(t);
            }}
            """
        )

        return code
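
A minimal sketch of what handling the ravel symbolically could look like, assuming it happens in the user-facing cumsum helper next to CumOp in pytensor/tensor/extra_ops.py (the placement, and restricting CumOp to a concrete integer axis, are illustrative rather than a final design; cumprod would be analogous):

# Sketch only: reuses the module-level imports from the listing above.
def cumsum(x, axis=None):
    """Return the cumulative sum of the elements along a given axis."""
    x = ptb.as_tensor_variable(x)
    if axis is None:
        # Express the ravel in the graph instead of inside CumOp,
        # so the Op never has to deal with axis=None / NPY_RAVEL_AXIS.
        return CumOp(axis=0, mode="add")(x.flatten())
    return CumOp(axis=axis, mode="add")(x)

With the ravel expressed in the graph, make_node would no longer need the vector output special case, perform could always pass a concrete axis to np.cumsum/np.cumprod, grad would only need the reverse_slicing branch, and the NPY_RAVEL_AXIS handling in c_code could be dropped.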
