-
Notifications
You must be signed in to change notification settings - Fork 139
Open
Description
Description
When `cumsum` is applied to a matrix with `axis=None`, the input is first raveled and a regular cumsum is then performed on the resulting vector.
We could perform this raveling symbolically (in the graph), which would simplify the implementation of `CumOp`, since the Op itself would no longer need to handle the `axis=None` case:
pytensor/pytensor/tensor/extra_ops.py
Lines 290 to 428 in b83ca3f
class CumOp(COp):
    """Cumulative sum or cumulative product of a tensor.

    With an integer ``axis`` the accumulation runs along that axis and the
    output has the same type as the input. With ``axis=None`` the input is
    flattened first and the output is a vector.

    See function cumsum/cumprod for the user-facing docstring.
    """

    __props__ = ("axis", "mode")
    check_input = False
    params_type = ParamsType(
        c_axis=int_t, mode=EnumList(("MODE_ADD", "add"), ("MODE_MUL", "mul"))
    )

    def __init__(self, axis: int | None = None, mode="add"):
        """Store and validate the accumulation axis and mode.

        Parameters
        ----------
        axis
            Axis to accumulate along, or ``None`` to accumulate over the
            flattened input.
        mode
            ``"add"`` for a cumulative sum, ``"mul"`` for a cumulative product.

        Raises
        ------
        ValueError
            If ``mode`` is neither ``"add"`` nor ``"mul"``.
        TypeError
            If ``axis`` is neither an ``int`` nor ``None``.
        """
        if mode not in ("add", "mul"):
            raise ValueError(f'{type(self).__name__}: Unknown mode "{mode}"')
        if not (isinstance(axis, int) or axis is None):
            raise TypeError("axis must be an integer or None.")

        self.axis = axis
        self.mode = mode

    @property
    def c_axis(self) -> int:
        """Integer axis handed to the C params struct.

        ``None`` cannot be stored in an ``int_t`` param, so it is replaced by
        the ``numpy_axis_is_none_flag`` sentinel.
        """
        if self.axis is None:
            return numpy_axis_is_none_flag
        return self.axis

    def make_node(self, x):
        """Build the Apply node, validating ``axis`` against ``x.ndim``.

        Raises
        ------
        ValueError
            If ``axis`` is an integer outside ``[-x.ndim, x.ndim)``.
        """
        x = ptb.as_tensor_variable(x)
        out_type = x.type()

        if self.axis is None:
            out_type = vector(dtype=x.dtype)  # Flatten
        elif self.axis >= x.ndim or self.axis < -x.ndim:
            raise ValueError(f"axis(={self.axis}) out of bounds")

        return Apply(self, [x], [out_type])

    def perform(self, node, inputs, output_storage):
        """Compute the accumulation with NumPy and store it in the output."""
        x = inputs[0]
        z = output_storage[0]
        accumulate = np.cumsum if self.mode == "add" else np.cumprod
        # Pin the result dtype to the input dtype. Without ``dtype`` NumPy
        # promotes integer inputs narrower than the platform integer, which
        # would contradict the output type declared in ``make_node`` and the
        # C implementation (which passes ``PyArray_TYPE(x)``).
        z[0] = accumulate(x, axis=self.axis, dtype=x.dtype)

    def grad(self, inputs, output_gradients):
        """Return the gradient of the accumulation with respect to the input.

        The gradient is a reversed cumulative sum of the output gradient:
        reverse along the accumulation axis, cumsum, reverse again. For
        ``mode="mul"`` the output gradient is first multiplied by the
        cumulative product and the result is divided by ``x``.

        Raises
        ------
        NotImplementedError
            If ``mode`` is neither ``"add"`` nor ``"mul"``.
        """
        (x,) = inputs
        (gi,) = output_gradients

        if self.axis is None:
            # ``gi`` is a vector here; reshape the result back to ``x.shape``.
            if self.mode == "add":
                return [cumsum(gi[::-1])[::-1].reshape(x.shape)]
            elif self.mode == "mul":
                fx = cumprod(x, axis=self.axis)
                return [cumsum((fx * gi)[::-1])[::-1].reshape(x.shape) / x]
            else:
                raise NotImplementedError(
                    f'{type(self).__name__}: unknown gradient for mode "{self.mode}"'
                )

        reverse_slicing = [slice(None, None, None)] * gi.ndim
        reverse_slicing[self.axis] = slice(None, None, -1)
        reverse_slicing = tuple(reverse_slicing)
        # We need to reverse the gradients along ``self.axis``,
        # compute cumsum, then reverse again
        if self.mode == "add":
            return [cumsum(gi[reverse_slicing], self.axis)[reverse_slicing]]
        elif self.mode == "mul":
            fx = cumprod(x, axis=self.axis)
            return [cumsum((fx * gi)[reverse_slicing], self.axis)[reverse_slicing] / x]
        else:
            raise NotImplementedError(
                f'{type(self).__name__}: unknown gradient for mode "{self.mode}"'
            )

    def infer_shape(self, fgraph, node, shapes):
        """Return the output shape: flattened for ``axis=None``, else unchanged."""
        # NOTE(review): a 0-d input with ``axis=None`` falls through and returns
        # the 0-d shape, although ``make_node`` declares a vector output —
        # confirm whether 0-d inputs can reach this point.
        if self.axis is None and len(shapes[0]) > 1:
            return [(prod(shapes[0]),)]  # Flatten

        return shapes

    def c_support_code_apply(self, node: Apply, name: str) -> str:
        """Needed to define NPY_RAVEL_AXIS"""
        return npy_2_compat_header()

    def c_code(self, node, name, inames, onames, sub):
        """Generate C code that allocates the output and calls NumPy's C API.

        The output buffer is reused when it already has the required shape;
        otherwise it is released and reallocated. The accumulation itself is
        delegated to ``PyArray_CumSum`` / ``PyArray_CumProd``.
        """
        (x,) = inames
        (z,) = onames
        fail = sub["fail"]
        params = sub["params"]

        if self.axis is None:
            axis_code = "int axis = NPY_RAVEL_AXIS;\n"
        else:
            axis_code = f"int axis = {params}->c_axis;\n"

        code = (
            axis_code
            + f"""
                #undef NPY_UF_DBG_TRACING
                #define NPY_UF_DBG_TRACING 1

                if (axis == 0 && PyArray_NDIM({x}) == 1)
                    axis = NPY_RAVEL_AXIS;
                npy_intp shape[1] = {{ PyArray_SIZE({x}) }};
                if(axis == NPY_RAVEL_AXIS && !({z} && PyArray_DIMS({z})[0] == shape[0]))
                {{
                    Py_XDECREF({z});
                    {z} = (PyArrayObject*) PyArray_SimpleNew(1, shape, PyArray_TYPE({x}));
                }}

                else if(axis != NPY_RAVEL_AXIS && !({z} && PyArray_CompareLists(PyArray_DIMS({z}), PyArray_DIMS({x}), PyArray_NDIM({x}))))
                {{
                    Py_XDECREF({z});
                    {z} = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM({x}), PyArray_DIMS({x}), PyArray_TYPE({x}));
                }}

                if (!{z})
                    {fail};
                {{

                    PyObject * t = NULL;
                    if({params}->mode == MODE_ADD)
                        t = PyArray_CumSum(
                            {x}, axis,
                            PyArray_TYPE({x}), {z});
                    else if({params}->mode == MODE_MUL)
                        t = PyArray_CumProd(
                            {x}, axis,
                            PyArray_TYPE({x}), {z});

                    if (!t){{
                       {fail};
                    }}
                    // Because PyArray_CumSum/CumProd returns a newly created reference on t.
                    Py_XDECREF(t);
                }}
            """
        )

        return code