
Commit 98d7c87

Copilot authored and ricardoV94 committed
Implement symbolic raveling for CumOp: handle axis=None symbolically using flatten
Co-authored-by: ricardoV94 <[email protected]>
1 parent 6a297ff commit 98d7c87

File tree

pytensor/link/numba/dispatch/extra_ops.py
pytensor/link/pytorch/dispatch/extra_ops.py
pytensor/tensor/extra_ops.py
tests/tensor/test_extra_ops.py

4 files changed: +39 −65 lines changed
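In user-facing terms, the change moves axis=None handling out of the Op and into the graph: `cumsum(x)`/`cumprod(x)` with the default `axis=None` now build an explicit flatten followed by a `CumOp` with a concrete axis. A minimal sketch of the new behavior (assuming the pytensor API as of this commit; not part of the diff):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.extra_ops import CumOp

x = pt.tensor3("x")
y = pt.cumsum(x)  # axis=None: x is flattened symbolically, then cumulated along axis 0

# The graph now contains a CumOp with a concrete integer axis...
assert isinstance(y.owner.op, CumOp)
assert y.owner.op.axis == 0

# ...and the result still matches NumPy's axis=None (raveling) semantics.
f = pytensor.function([x], y)
a = np.random.random((3, 5, 2)).astype(pytensor.config.floatX)
np.testing.assert_allclose(f(a), np.cumsum(a), rtol=1e-5)
```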

pytensor/link/numba/dispatch/extra_ops.py

Lines changed: 11 additions & 12 deletions
```diff
@@ -37,21 +37,20 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs):
     mode = op.mode
     ndim = cast(TensorVariable, node.outputs[0]).ndim
 
-    if axis is not None:
-        if axis < 0:
-            axis = ndim + axis
-        if axis < 0 or axis >= ndim:
-            raise ValueError(f"Invalid axis {axis} for array with ndim {ndim}")
+    if axis < 0:
+        axis = ndim + axis
+    if axis < 0 or axis >= ndim:
+        raise ValueError(f"Invalid axis {axis} for array with ndim {ndim}")
 
-        reaxis_first = (axis, *(i for i in range(ndim) if i != axis))
-        reaxis_first_inv = tuple(np.argsort(reaxis_first))
+    reaxis_first = (axis, *(i for i in range(ndim) if i != axis))
+    reaxis_first_inv = tuple(np.argsort(reaxis_first))
 
     if mode == "add":
-        if axis is None or ndim == 1:
+        if ndim == 1:
 
             @numba_basic.numba_njit
             def cumop(x):
-                return np.cumsum(x)
+                return np.cumsum(x, axis=axis)
 
         else:
 
@@ -71,11 +70,11 @@ def cumop(x):
                 return res.transpose(reaxis_first_inv)
 
     else:
-        if axis is None or ndim == 1:
+        if ndim == 1:
 
             @numba_basic.numba_njit
             def cumop(x):
-                return np.cumprod(x)
+                return np.cumprod(x, axis=axis)
 
         else:
 
@@ -92,7 +91,7 @@ def cumop(x):
             for m in range(1, x.shape[axis]):
                 res[m] = res[m - 1] * x_axis_first[m]
 
-            return res.transpose(reaxis_first)
+            return res.transpose(reaxis_first_inv)
 
     return cumop
```
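Note that the last hunk is also a genuine bug fix, independent of the axis=None cleanup: after cumulating with the target axis moved to the front, the result must be restored with the *inverse* permutation, and `np.argsort` of a permutation yields exactly that inverse. A NumPy-only sketch of the pattern (the helper name `cumprod_along_axis` is hypothetical; this mirrors, but is not, the generated kernel):

```python
import numpy as np

def cumprod_along_axis(x, axis):
    ndim = x.ndim
    # Permutation that moves `axis` to the front.
    reaxis_first = (axis, *(i for i in range(ndim) if i != axis))
    # argsort of a permutation is its inverse.
    reaxis_first_inv = tuple(np.argsort(reaxis_first))

    x_axis_first = x.transpose(reaxis_first)
    res = np.empty_like(x_axis_first)
    res[0] = x_axis_first[0]
    for m in range(1, x.shape[axis]):
        res[m] = res[m - 1] * x_axis_first[m]
    # Transposing with the inverse permutation restores the original layout.
    # Reusing `reaxis_first` is only correct when the permutation happens to
    # be its own inverse (e.g. axis=0), which is why the bug could go unnoticed.
    return res.transpose(reaxis_first_inv)

a = np.random.random((3, 5, 2))
for ax in range(a.ndim):
    np.testing.assert_allclose(cumprod_along_axis(a, ax), np.cumprod(a, axis=ax))
```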

pytensor/link/pytorch/dispatch/extra_ops.py

Lines changed: 2 additions & 7 deletions
```diff
@@ -10,15 +10,10 @@ def pytorch_funcify_Cumop(op, **kwargs):
     mode = op.mode
 
     def cumop(x):
-        if axis is None:
-            x = x.reshape(-1)
-            dim = 0
-        else:
-            dim = axis
         if mode == "add":
-            return torch.cumsum(x, dim=dim)
+            return torch.cumsum(x, dim=axis)
         else:
-            return torch.cumprod(x, dim=dim)
+            return torch.cumprod(x, dim=axis)
 
     return cumop
```
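With axis=None resolved into an explicit flatten before dispatch, the PyTorch kernel can assume `axis` is a plain integer. A standalone sketch of the equivalence the simplification relies on (plain torch and NumPy; not part of the diff):

```python
import numpy as np
import torch

a = np.random.random((3, 5, 2))
x = torch.from_numpy(a)

# Flattening first, then cumulating along dim=0, reproduces NumPy's
# axis=None (raveling) semantics.
np.testing.assert_allclose(torch.cumsum(x.reshape(-1), dim=0).numpy(), np.cumsum(a))
```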

pytensor/tensor/extra_ops.py

Lines changed: 20 additions & 41 deletions
```diff
@@ -28,7 +28,7 @@
 from pytensor.scalar import upcast
 from pytensor.tensor import TensorLike, as_tensor_variable
 from pytensor.tensor import basic as ptb
-from pytensor.tensor.basic import alloc, join, second
+from pytensor.tensor.basic import alloc, join, second, flatten
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import abs as pt_abs
 from pytensor.tensor.math import all as pt_all
@@ -297,27 +297,23 @@ class CumOp(COp):
         c_axis=int_t, mode=EnumList(("MODE_ADD", "add"), ("MODE_MUL", "mul"))
     )
 
-    def __init__(self, axis: int | None = None, mode="add"):
+    def __init__(self, axis: int, mode="add"):
         if mode not in ("add", "mul"):
             raise ValueError(f'{type(self).__name__}: Unknown mode "{mode}"')
-        if not (isinstance(axis, int) or axis is None):
-            raise TypeError("axis must be an integer or None.")
+        if not isinstance(axis, int):
+            raise TypeError("axis must be an integer.")
         self.axis = axis
         self.mode = mode
 
     @property
     def c_axis(self) -> int:
-        if self.axis is None:
-            return numpy_axis_is_none_flag
         return self.axis
 
     def make_node(self, x):
         x = ptb.as_tensor_variable(x)
         out_type = x.type()
 
-        if self.axis is None:
-            out_type = vector(dtype=x.dtype)  # Flatten
-        elif self.axis >= x.ndim or self.axis < -x.ndim:
+        if self.axis >= x.ndim or self.axis < -x.ndim:
             raise ValueError(f"axis(={self.axis}) out of bounds")
 
         return Apply(self, [x], [out_type])
@@ -334,17 +330,6 @@ def grad(self, inputs, output_gradients):
         (x,) = inputs
         (gi,) = output_gradients
 
-        if self.axis is None:
-            if self.mode == "add":
-                return [cumsum(gi[::-1])[::-1].reshape(x.shape)]
-            elif self.mode == "mul":
-                fx = cumprod(x, axis=self.axis)
-                return [cumsum((fx * gi)[::-1])[::-1].reshape(x.shape) / x]
-            else:
-                raise NotImplementedError(
-                    f'{type(self).__name__}: unknown gradient for mode "{self.mode}"'
-                )
-
         reverse_slicing = [slice(None, None, None)] * gi.ndim
         reverse_slicing[self.axis] = slice(None, None, -1)
         reverse_slicing = tuple(reverse_slicing)
@@ -361,9 +346,6 @@ def grad(self, inputs, output_gradients):
         )
 
     def infer_shape(self, fgraph, node, shapes):
-        if self.axis is None and len(shapes[0]) > 1:
-            return [(prod(shapes[0]),)]  # Flatten
-
         return shapes
 
     def c_support_code_apply(self, node: Apply, name: str) -> str:
@@ -376,10 +358,7 @@ def c_code(self, node, name, inames, onames, sub):
         fail = sub["fail"]
         params = sub["params"]
 
-        if self.axis is None:
-            axis_code = "int axis = NPY_RAVEL_AXIS;\n"
-        else:
-            axis_code = f"int axis = {params}->c_axis;\n"
+        axis_code = f"int axis = {params}->c_axis;\n"
 
         code = (
             axis_code
@@ -451,7 +430,12 @@ def cumsum(x, axis=None):
     .. versionadded:: 0.7
 
     """
-    return CumOp(axis=axis, mode="add")(x)
+    if axis is None:
+        # Handle raveling symbolically by flattening first, then applying cumsum with axis=0
+        x_flattened = flatten(x, ndim=1)  # This creates a 1D tensor
+        return CumOp(axis=0, mode="add")(x_flattened)
+    else:
+        return CumOp(axis=axis, mode="add")(x)
 
 
 def cumprod(x, axis=None):
@@ -471,26 +455,21 @@ def cumprod(x, axis=None):
     .. versionadded:: 0.7
 
     """
-    return CumOp(axis=axis, mode="mul")(x)
+    if axis is None:
+        # Handle raveling symbolically by flattening first, then applying cumprod with axis=0
+        x_flattened = flatten(x, ndim=1)  # This creates a 1D tensor
+        return CumOp(axis=0, mode="mul")(x_flattened)
+    else:
+        return CumOp(axis=axis, mode="mul")(x)
 
 
 @_vectorize_node.register(CumOp)
 def vectorize_cum_op(op: CumOp, node: Apply, batch_x):
     """Vectorize the CumOp to work on a batch of inputs."""
     [original_x] = node.inputs
     batch_ndim = batch_x.ndim - original_x.ndim
-    axis = op.axis
-    if axis is None and original_x.ndim == 1:
-        axis = 0
-    elif axis is not None:
-        axis = normalize_axis_index(op.axis, original_x.ndim)
-
-    if axis is None:
-        # Ravel all unbatched dimensions and perform CumOp on the last axis
-        batch_x_raveled = [batch_x.flatten(ndim=batch_ndim + 1) for x in batch_x]
-        return type(op)(axis=-1, mode=op.mode).make_node(batch_x_raveled)
-    else:
-        return type(op)(axis=axis + batch_ndim, mode=op.mode).make_node(batch_x)
+    axis = normalize_axis_index(op.axis, original_x.ndim)
+    return type(op)(axis=axis + batch_ndim, mode=op.mode).make_node(batch_x)
 
 
 def diff(x, n=1, axis=-1):
```
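The deleted axis=None branch in `CumOp.grad` is not lost functionality: the flatten that `cumsum`/`cumprod` now insert is itself differentiable, so its gradient un-ravels the cotangent back to the input shape, and `CumOp.grad` only ever sees a concrete axis. A quick check of that claim (a sketch using pytensor's public grad API; not part of the diff):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
g = pytensor.grad(pt.cumsum(x).sum(), x)  # exercises the axis=None path

f = pytensor.function([x], g)
a = np.random.random((3, 4)).astype(pytensor.config.floatX)

# For y = cumsum(ravel(x)), d sum(y) / d x_i counts how many partial sums
# the i-th raveled element feeds into: n - i.
n = a.size
expected = (n - np.arange(n)).reshape(a.shape).astype(a.dtype)
np.testing.assert_allclose(f(a), expected)
```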

tests/tensor/test_extra_ops.py

Lines changed: 6 additions & 5 deletions
```diff
@@ -194,7 +194,7 @@ class TestCumOp(utt.InferShapeTester):
     def setup_method(self):
         super().setup_method()
         self.op_class = CumOp
-        self.op = CumOp()
+        self.op = CumOp(axis=0)  # Use a specific axis since None is no longer supported
 
     def test_cum_op(self):
         x = tensor3("x")
@@ -225,17 +225,18 @@ def test_infer_shape(self):
         x = tensor3("x")
         a = np.random.random((3, 5, 2)).astype(config.floatX)
 
-        # Test axis=None
-        self._compile_and_check([x], [self.op(x)], [a], self.op_class)
+        # Test axis=None using cumsum function (which now handles it symbolically)
+        self._compile_and_check([x], [cumsum(x)], [a], type(cumsum(x).owner.op))
 
         for axis in range(-len(a.shape), len(a.shape)):
             self._compile_and_check([x], [cumsum(x, axis=axis)], [a], self.op_class)
 
     def test_grad(self):
         a = np.random.random((3, 5, 2)).astype(config.floatX)
 
-        utt.verify_grad(self.op_class(mode="add"), [a])  # Test axis=None
-        utt.verify_grad(self.op_class(mode="mul"), [a])  # Test axis=None
+        # Test axis=None using cumsum/cumprod functions (which now handle it symbolically)
+        utt.verify_grad(lambda x: cumsum(x), [a])  # Test axis=None for cumsum
+        utt.verify_grad(lambda x: cumprod(x), [a])  # Test axis=None for cumprod
 
         for axis in range(-len(a.shape), len(a.shape)):
             utt.verify_grad(self.op_class(axis=axis, mode="add"), [a], eps=4e-4)
```