Commit 49f76da

Copilot authored and ricardoV94 committed
Implement axis=None raveling behavior symbolically in CumOp
1 parent 9b522a8 commit 49f76da

File tree: 6 files changed, +76 −161 lines
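
Note on the change: axis=None is now resolved at graph-construction time. The
cumsum/cumprod wrappers ravel the input symbolically and pass CumOp a concrete,
non-negative integer axis, so the op, its gradient, and every backend dispatch
only ever see integer axes. A minimal sketch of the resulting user-facing
equivalence (variable names are illustrative, not part of the diff):

    import pytensor.tensor as pt

    x = pt.matrix("x")
    out = pt.cumsum(x)  # axis=None: input is raveled symbolically first
    # Equivalent graph built by hand; CumOp itself never sees axis=None:
    out_manual = pt.cumsum(x.ravel(), axis=0)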

pytensor/link/numba/dispatch/extra_ops.py

Lines changed: 7 additions & 13 deletions

@@ -41,21 +41,15 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs):
     mode = op.mode
     ndim = cast(TensorVariable, node.outputs[0]).ndim
 
-    if axis is not None:
-        if axis < 0:
-            axis = ndim + axis
-        if axis < 0 or axis >= ndim:
-            raise ValueError(f"Invalid axis {axis} for array with ndim {ndim}")
-
-        reaxis_first = (axis, *(i for i in range(ndim) if i != axis))
-        reaxis_first_inv = tuple(np.argsort(reaxis_first))
+    reaxis_first = (axis, *(i for i in range(ndim) if i != axis))
+    reaxis_first_inv = tuple(np.argsort(reaxis_first))
 
     if mode == "add":
-        if axis is None or ndim == 1:
+        if ndim == 1:
 
             @numba_basic.numba_njit
             def cumop(x):
-                return np.cumsum(x)
+                return np.cumsum(x, axis=axis)
 
         else:

@@ -75,11 +69,11 @@ def cumop(x):
                 return res.transpose(reaxis_first_inv)
 
     else:
-        if axis is None or ndim == 1:
+        if ndim == 1:
 
             @numba_basic.numba_njit
             def cumop(x):
-                return np.cumprod(x)
+                return np.cumprod(x, axis=axis)
 
         else:

@@ -96,7 +90,7 @@ def cumop(x):
             for m in range(1, x.shape[axis]):
                 res[m] = res[m - 1] * x_axis_first[m]
 
-            return res.transpose(reaxis_first)
+            return res.transpose(reaxis_first_inv)
 
     return cumop
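
Context for the transpose bookkeeping above: reaxis_first moves the accumulation
axis to the front so the njit kernel can fill res along its leading dimension,
and reaxis_first_inv inverts that permutation afterwards. The mul branch
previously applied the forward permutation twice, which this hunk fixes. A
plain-NumPy sketch of the identity the kernel relies on:

    import numpy as np

    x = np.arange(24).reshape(2, 3, 4)
    axis = 1
    reaxis_first = (axis, *(i for i in range(x.ndim) if i != axis))
    reaxis_first_inv = tuple(np.argsort(reaxis_first))

    # Accumulate along the leading axis after moving `axis` to the front,
    # then invert the permutation to restore the original layout:
    res = np.cumsum(x.transpose(reaxis_first), axis=0).transpose(reaxis_first_inv)
    assert np.array_equal(res, np.cumsum(x, axis=axis))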

pytensor/link/pytorch/dispatch/extra_ops.py

Lines changed: 2 additions & 7 deletions

@@ -10,15 +10,10 @@ def pytorch_funcify_Cumop(op, **kwargs):
     mode = op.mode
 
     def cumop(x):
-        if axis is None:
-            x = x.reshape(-1)
-            dim = 0
-        else:
-            dim = axis
         if mode == "add":
-            return torch.cumsum(x, dim=dim)
+            return torch.cumsum(x, dim=axis)
         else:
-            return torch.cumprod(x, dim=dim)
+            return torch.cumprod(x, dim=axis)
 
     return cumop
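
Since the wrappers normalize the axis before constructing CumOp, the torch
dispatch can forward it directly as dim. A quick eager-mode sketch of the
equivalence being relied on (illustrative values only):

    import numpy as np
    import torch

    x_np = np.arange(9, dtype="float64").reshape(3, 3)
    x_t = torch.from_numpy(x_np)

    assert np.allclose(torch.cumsum(x_t, dim=0).numpy(), np.cumsum(x_np, axis=0))
    assert np.allclose(torch.cumprod(x_t, dim=1).numpy(), np.cumprod(x_np, axis=1))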

pytensor/tensor/extra_ops.py

Lines changed: 56 additions & 90 deletions
@@ -1,5 +1,6 @@
 import warnings
 from collections.abc import Collection, Iterable
+from textwrap import dedent
 
 import numpy as np
 from numpy.lib.array_utils import normalize_axis_index

@@ -44,10 +45,10 @@
 from pytensor.tensor.math import sum as pt_sum
 from pytensor.tensor.shape import Shape_i
 from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
-from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector
+from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes
 from pytensor.tensor.utils import normalize_reduce_axis
 from pytensor.tensor.variable import TensorVariable
-from pytensor.utils import LOCAL_BITWIDTH, NPY_RAVEL_AXIS, PYTHON_INT_BITWIDTH
+from pytensor.utils import LOCAL_BITWIDTH, PYTHON_INT_BITWIDTH
 
 
 class CpuContiguous(COp):
@@ -290,33 +291,28 @@ class CumOp(COp):
     __props__ = ("axis", "mode")
     check_input = False
     params_type = ParamsType(
-        c_axis=int_t, mode=EnumList(("MODE_ADD", "add"), ("MODE_MUL", "mul"))
+        axis=int_t, mode=EnumList(("MODE_ADD", "add"), ("MODE_MUL", "mul"))
     )
 
-    def __init__(self, axis: int | None = None, mode="add"):
+    def __init__(self, axis: int, mode="add"):
         if mode not in ("add", "mul"):
             raise ValueError(f'{type(self).__name__}: Unknown mode "{mode}"')
-        if not (isinstance(axis, int) or axis is None):
-            raise TypeError("axis must be an integer or None.")
+        if not isinstance(axis, int):
+            raise TypeError(f"axis must be an integer, got {axis} of type {type(axis)}")
+        if axis < 0:
+            raise ValueError(f"axis must be non-negative, got {axis}")
         self.axis = axis
         self.mode = mode
 
-    @property
-    def c_axis(self) -> int:
-        if self.axis is None:
-            return NPY_RAVEL_AXIS
-        return self.axis
-
     def make_node(self, x):
         x = ptb.as_tensor_variable(x)
-        out_type = x.type()
 
-        if self.axis is None:
-            out_type = vector(dtype=x.dtype)  # Flatten
-        elif self.axis >= x.ndim or self.axis < -x.ndim:
-            raise ValueError(f"axis(={self.axis}) out of bounds")
+        if self.axis >= x.type.ndim:
+            raise ValueError(
+                f"axis(={self.axis}) out of bounds for variable {x} with {x.type.ndim} ndims"
+            )
 
-        return Apply(self, [x], [out_type])
+        return Apply(self, [x], [x.type()])
 
     def perform(self, node, inputs, output_storage):
         x = inputs[0]
@@ -326,21 +322,10 @@ def perform(self, node, inputs, output_storage):
         else:
             z[0] = np.cumprod(x, axis=self.axis)
 
-    def grad(self, inputs, output_gradients):
+    def L_op(self, inputs, outputs, output_gradients):
         (x,) = inputs
         (gi,) = output_gradients
 
-        if self.axis is None:
-            if self.mode == "add":
-                return [cumsum(gi[::-1])[::-1].reshape(x.shape)]
-            elif self.mode == "mul":
-                fx = cumprod(x, axis=self.axis)
-                return [cumsum((fx * gi)[::-1])[::-1].reshape(x.shape) / x]
-            else:
-                raise NotImplementedError(
-                    f'{type(self).__name__}: unknown gradient for mode "{self.mode}"'
-                )
-
         reverse_slicing = [slice(None, None, None)] * gi.ndim
         reverse_slicing[self.axis] = slice(None, None, -1)
         reverse_slicing = tuple(reverse_slicing)
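
The deleted branch duplicated the axis-aware path below for the raveled case;
that raveling now happens in the graph built by cumsum/cumprod. The surviving
path uses the fact that the adjoint of cumsum is a reversed cumsum of the
incoming gradient. A numerical sanity check of that identity, independent of
pytensor:

    import numpy as np

    rng = np.random.default_rng(0)
    g = rng.normal(size=5)  # upstream gradient for cumsum(x)
    # The Jacobian of cumsum is lower-triangular ones, so
    # J.T @ g == cumsum(g[::-1])[::-1]:
    J = np.tril(np.ones((5, 5)))
    assert np.allclose(J.T @ g, np.cumsum(g[::-1])[::-1])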
@@ -357,9 +342,6 @@ def grad(self, inputs, output_gradients):
         )
 
     def infer_shape(self, fgraph, node, shapes):
-        if self.axis is None and len(shapes[0]) > 1:
-            return [(prod(shapes[0]),)]  # Flatten
-
         return shapes
 
     def c_code(self, node, name, inames, onames, sub):
@@ -368,61 +350,43 @@ def c_code(self, node, name, inames, onames, sub):
         fail = sub["fail"]
         params = sub["params"]
 
-        if self.axis is None:
-            axis_code = "int axis = NPY_RAVEL_AXIS;\n"
-        else:
-            axis_code = f"int axis = {params}->c_axis;\n"
-
-        code = (
-            axis_code
-            + f"""
-                #undef NPY_UF_DBG_TRACING
-                #define NPY_UF_DBG_TRACING 1
-
-                if (axis == 0 && PyArray_NDIM({x}) == 1)
-                    axis = NPY_RAVEL_AXIS;
-                npy_intp shape[1] = {{ PyArray_SIZE({x}) }};
-                if(axis == NPY_RAVEL_AXIS && !({z} && PyArray_DIMS({z})[0] == shape[0]))
-                {{
-                    Py_XDECREF({z});
-                    {z} = (PyArrayObject*) PyArray_SimpleNew(1, shape, PyArray_TYPE({x}));
-                }}
-
-                else if(axis != NPY_RAVEL_AXIS && !({z} && PyArray_CompareLists(PyArray_DIMS({z}), PyArray_DIMS({x}), PyArray_NDIM({x}))))
-                {{
-                    Py_XDECREF({z});
-                    {z} = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM({x}), PyArray_DIMS({x}), PyArray_TYPE({x}));
-                }}
-
-                if (!{z})
-                    {fail};
-                {{
-
-                    PyObject * t = NULL;
-                    if({params}->mode == MODE_ADD)
-                        t = PyArray_CumSum(
-                            {x}, axis,
-                            PyArray_TYPE({x}), {z});
-                    else if({params}->mode == MODE_MUL)
-                        t = PyArray_CumProd(
-                            {x}, axis,
-                            PyArray_TYPE({x}), {z});
-
-                    if (!t){{
-                        {fail};
-                    }}
-                    // Because PyArray_CumSum/CumProd returns a newly created reference on t.
-                    Py_XDECREF(t);
-                }}
-            """
-        )
-
-        return code
+        return dedent(
+            f"""
+            int axis = {params}->axis;
+
+            if (!({z} && PyArray_CompareLists(PyArray_DIMS({z}), PyArray_DIMS({x}), PyArray_NDIM({x}))))
+            {{
+                Py_XDECREF({z});
+                {z} = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM({x}), PyArray_DIMS({x}), PyArray_TYPE({x}));
+                if (!{z}){{ {fail} }};
+            }}
+
+            {{
+                PyObject * t = NULL;
+                if({params}->mode == MODE_ADD)
+                    t = PyArray_CumSum({x}, axis, PyArray_TYPE({x}), {z});
+                else if({params}->mode == MODE_MUL)
+                    t = PyArray_CumProd({x}, axis, PyArray_TYPE({x}), {z});
+
+                if (!t){{
+                    {fail};
+                }}
+
+                // Because PyArray_CumSum/CumProd returns a newly created reference on t.
+                Py_XDECREF(t);
+            }}
+            """
+        )
 
     def c_code_cache_version(self):
-        return (10,)
+        return (11,)
 
     def __str__(self):
+        if self.mode == "add":
+            return f"Cumsum{{axis={self.axis}}}"
+        elif self.mode == "mul":
+            return f"Cumprod{{axis={self.axis}}}"
         return f"{self.__class__.__name__}{{{self.axis}, {self.mode}}}"

@@ -443,6 +407,12 @@ def cumsum(x, axis=None):
     .. versionadded:: 0.7
 
     """
+    x = ptb.as_tensor_variable(x)
+    if axis is None:
+        x = x.ravel()
+        axis = 0
+    else:
+        axis = normalize_axis_index(axis, x.ndim)
     return CumOp(axis=axis, mode="add")(x)

@@ -463,6 +433,12 @@ def cumprod(x, axis=None):
     .. versionadded:: 0.7
 
    """
+    x = ptb.as_tensor_variable(x)
+    if axis is None:
+        x = x.ravel()
+        axis = 0
+    else:
+        axis = normalize_axis_index(axis, x.ndim)
     return CumOp(axis=axis, mode="mul")(x)
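
Both wrappers now handle axis=None once, symbolically, and delegate negative
axes to NumPy's normalize_axis_index, which maps a possibly negative axis to a
concrete index or raises. For example:

    from numpy.lib.array_utils import normalize_axis_index

    assert normalize_axis_index(-1, 3) == 2  # negative axes become concrete
    assert normalize_axis_index(1, 3) == 1
    # normalize_axis_index(3, 3) raises AxisError, so CumOp only ever
    # receives a valid non-negative axis.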

@@ -471,18 +447,8 @@ def vectorize_cum_op(op: CumOp, node: Apply, batch_x):
     """Vectorize the CumOp to work on a batch of inputs."""
     [original_x] = node.inputs
     batch_ndim = batch_x.ndim - original_x.ndim
-    axis = op.axis
-    if axis is None and original_x.ndim == 1:
-        axis = 0
-    elif axis is not None:
-        axis = normalize_axis_index(op.axis, original_x.ndim)
-
-    if axis is None:
-        # Ravel all unbatched dimensions and perform CumOp on the last axis
-        batch_x_raveled = [batch_x.flatten(ndim=batch_ndim + 1) for x in batch_x]
-        return type(op)(axis=-1, mode=op.mode).make_node(batch_x_raveled)
-    else:
-        return type(op)(axis=axis + batch_ndim, mode=op.mode).make_node(batch_x)
+    # op.axis is already normalized and non-negative
+    return type(op)(axis=op.axis + batch_ndim, mode=op.mode).make_node(batch_x)
 
 
 def diff(x, n=1, axis=-1):
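
The vectorize rule now reduces to shifting the axis by the number of batch
dimensions, since op.axis is guaranteed non-negative. A plain-NumPy sketch of
why that shift is correct (shapes are illustrative):

    import numpy as np

    x = np.arange(6).reshape(2, 3)  # unbatched input, cum along axis=1
    batch = np.stack([x, x + 10])   # batch_ndim = 1
    # Per-example results equal the batched op with the axis shifted:
    assert np.array_equal(
        np.stack([np.cumsum(b, axis=1) for b in batch]),
        np.cumsum(batch, axis=1 + 1),
    )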

tests/link/numba/test_extra_ops.py

Lines changed: 0 additions & 15 deletions

@@ -38,11 +38,6 @@ def test_Bartlett(val):
             1,
             "add",
         ),
-        (
-            (pt.dtensor3(), np.arange(30, dtype=config.floatX).reshape((2, 3, 5))),
-            -1,
-            "add",
-        ),
         (
             (pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))),
             0,

@@ -53,11 +48,6 @@ def test_Bartlett(val):
             1,
             "add",
         ),
-        (
-            (pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))),
-            None,
-            "add",
-        ),
         (
             (pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))),
             0,

@@ -68,11 +58,6 @@ def test_Bartlett(val):
             1,
             "mul",
         ),
-        (
-            (pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))),
-            None,
-            "mul",
-        ),
     ],
 )
 def test_CumOp(val, axis, mode):

tests/link/pytorch/test_extra_ops.py

Lines changed: 4 additions & 30 deletions

@@ -5,39 +5,13 @@
 from tests.link.pytorch.test_basic import compare_pytorch_and_py
 
 
-@pytest.mark.parametrize(
-    "dtype",
-    ["float64", "int64"],
-)
-@pytest.mark.parametrize(
-    "axis",
-    [None, 1, (0,)],
-)
+@pytest.mark.parametrize("dtype", ["float64", "int64"])
+@pytest.mark.parametrize("axis", [None, -1])
 def test_pytorch_CumOp(axis, dtype):
-    """Test PyTorch conversion of the `CumOp` `Op`."""
-
-    # Create a symbolic input for the first input of `CumOp`
     a = pt.matrix("a", dtype=dtype)
-
-    # Create test value
     test_value = np.arange(9, dtype=dtype).reshape((3, 3))
-
-    # Create the output variable
-    if isinstance(axis, tuple):
-        with pytest.raises(TypeError, match="axis must be an integer or None\\."):
-            out = pt.cumsum(a, axis=axis)
-        with pytest.raises(TypeError, match="axis must be an integer or None\\."):
-            out = pt.cumprod(a, axis=axis)
-    else:
-        out = pt.cumsum(a, axis=axis)
-
-    # Pass the inputs and outputs to the testing function
-    compare_pytorch_and_py([a], [out], [test_value])
-
-    # For the second mode of CumOp
-    out = pt.cumprod(a, axis=axis)
-
-    compare_pytorch_and_py([a], [out], [test_value])
+    outs = [pt.cumsum(a, axis=axis), pt.cumprod(a, axis=axis)]
+    compare_pytorch_and_py([a], outs, [test_value])
 
 
 @pytest.mark.parametrize("axis, repeats", [(0, (1, 2, 3)), (1, (3, 3)), (None, 3)])
