Skip to content

Commit d3bfb15

Browse files
Copilot and ricardoV94 committed
Address review comments: use .ravel(), normalize axis, simplify negative axis handling
Co-authored-by: ricardoV94 <[email protected]>
1 parent 98d7c87 commit d3bfb15

File tree

2 files changed

+12
-10
lines changed

2 files changed

+12
-10
lines changed

pytensor/tensor/extra_ops.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,8 @@ def __init__(self, axis: int, mode="add"):
302302
raise ValueError(f'{type(self).__name__}: Unknown mode "{mode}"')
303303
if not isinstance(axis, int):
304304
raise TypeError("axis must be an integer.")
305+
if axis < 0:
306+
raise ValueError("axis must be non-negative.")
305307
self.axis = axis
306308
self.mode = mode
307309

@@ -313,7 +315,7 @@ def make_node(self, x):
313315
x = ptb.as_tensor_variable(x)
314316
out_type = x.type()
315317

316-
if self.axis >= x.ndim or self.axis < -x.ndim:
318+
if self.axis >= x.ndim:
317319
raise ValueError(f"axis(={self.axis}) out of bounds")
318320

319321
return Apply(self, [x], [out_type])
@@ -431,10 +433,10 @@ def cumsum(x, axis=None):
431433
432434
"""
433435
if axis is None:
434-
# Handle raveling symbolically by flattening first, then applying cumsum with axis=0
435-
x_flattened = flatten(x, ndim=1) # This creates a 1D tensor
436-
return CumOp(axis=0, mode="add")(x_flattened)
436+
return CumOp(axis=0, mode="add")(ptb.as_tensor_variable(x).ravel())
437437
else:
438+
x = ptb.as_tensor_variable(x)
439+
axis = normalize_axis_index(axis, x.ndim)
438440
return CumOp(axis=axis, mode="add")(x)
439441

440442

@@ -456,10 +458,10 @@ def cumprod(x, axis=None):
456458
457459
"""
458460
if axis is None:
459-
# Handle raveling symbolically by flattening first, then applying cumprod with axis=0
460-
x_flattened = flatten(x, ndim=1) # This creates a 1D tensor
461-
return CumOp(axis=0, mode="mul")(x_flattened)
461+
return CumOp(axis=0, mode="mul")(ptb.as_tensor_variable(x).ravel())
462462
else:
463+
x = ptb.as_tensor_variable(x)
464+
axis = normalize_axis_index(axis, x.ndim)
463465
return CumOp(axis=axis, mode="mul")(x)
464466

465467

@@ -468,8 +470,8 @@ def vectorize_cum_op(op: CumOp, node: Apply, batch_x):
468470
"""Vectorize the CumOp to work on a batch of inputs."""
469471
[original_x] = node.inputs
470472
batch_ndim = batch_x.ndim - original_x.ndim
471-
axis = normalize_axis_index(op.axis, original_x.ndim)
472-
return type(op)(axis=axis + batch_ndim, mode=op.mode).make_node(batch_x)
473+
# op.axis is already normalized and non-negative
474+
return type(op)(axis=op.axis + batch_ndim, mode=op.mode).make_node(batch_x)
473475

474476

475477
def diff(x, n=1, axis=-1):

tests/tensor/test_extra_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def test_infer_shape(self):
226226
a = np.random.random((3, 5, 2)).astype(config.floatX)
227227

228228
# Test axis=None using cumsum function (which now handles it symbolically)
229-
self._compile_and_check([x], [cumsum(x)], [a], type(cumsum(x).owner.op))
229+
self._compile_and_check([x], [cumsum(x)], [a], self.op_class)
230230

231231
for axis in range(-len(a.shape), len(a.shape)):
232232
self._compile_and_check([x], [cumsum(x, axis=axis)], [a], self.op_class)

0 commit comments

Comments (0)