Skip to content

Commit 007ec2d

Browse files
CopilotricardoV94
andcommitted
Address review comments: remove redundant checks and c_axis property, eliminate code duplication
Co-authored-by: ricardoV94 <[email protected]>
1 parent d3bfb15 commit 007ec2d

File tree

3 files changed

+13
-17
lines changed

3 files changed

+13
-17
lines changed

pytensor/link/numba/dispatch/extra_ops.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,7 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs):
3737
mode = op.mode
3838
ndim = cast(TensorVariable, node.outputs[0]).ndim
3939

40-
if axis < 0:
41-
axis = ndim + axis
42-
if axis < 0 or axis >= ndim:
43-
raise ValueError(f"Invalid axis {axis} for array with ndim {ndim}")
40+
4441

4542
reaxis_first = (axis, *(i for i in range(ndim) if i != axis))
4643
reaxis_first_inv = tuple(np.argsort(reaxis_first))

pytensor/tensor/extra_ops.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ class CumOp(COp):
294294
__props__ = ("axis", "mode")
295295
check_input = False
296296
params_type = ParamsType(
297-
c_axis=int_t, mode=EnumList(("MODE_ADD", "add"), ("MODE_MUL", "mul"))
297+
axis=int_t, mode=EnumList(("MODE_ADD", "add"), ("MODE_MUL", "mul"))
298298
)
299299

300300
def __init__(self, axis: int, mode="add"):
@@ -307,9 +307,6 @@ def __init__(self, axis: int, mode="add"):
307307
self.axis = axis
308308
self.mode = mode
309309

310-
@property
311-
def c_axis(self) -> int:
312-
return self.axis
313310

314311
def make_node(self, x):
315312
x = ptb.as_tensor_variable(x)
@@ -360,7 +357,7 @@ def c_code(self, node, name, inames, onames, sub):
360357
fail = sub["fail"]
361358
params = sub["params"]
362359

363-
axis_code = f"int axis = {params}->c_axis;\n"
360+
axis_code = f"int axis = {params}->axis;\n"
364361

365362
code = (
366363
axis_code
@@ -432,12 +429,13 @@ def cumsum(x, axis=None):
432429
.. versionadded:: 0.7
433430
434431
"""
432+
x = ptb.as_tensor_variable(x)
435433
if axis is None:
436-
return CumOp(axis=0, mode="add")(ptb.as_tensor_variable(x).ravel())
434+
x = x.ravel()
435+
axis = 0
437436
else:
438-
x = ptb.as_tensor_variable(x)
439437
axis = normalize_axis_index(axis, x.ndim)
440-
return CumOp(axis=axis, mode="add")(x)
438+
return CumOp(axis=axis, mode="add")(x)
441439

442440

443441
def cumprod(x, axis=None):
@@ -457,12 +455,13 @@ def cumprod(x, axis=None):
457455
.. versionadded:: 0.7
458456
459457
"""
458+
x = ptb.as_tensor_variable(x)
460459
if axis is None:
461-
return CumOp(axis=0, mode="mul")(ptb.as_tensor_variable(x).ravel())
460+
x = x.ravel()
461+
axis = 0
462462
else:
463-
x = ptb.as_tensor_variable(x)
464463
axis = normalize_axis_index(axis, x.ndim)
465-
return CumOp(axis=axis, mode="mul")(x)
464+
return CumOp(axis=axis, mode="mul")(x)
466465

467466

468467
@_vectorize_node.register(CumOp)

tests/tensor/test_extra_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def test_infer_shape(self):
225225
x = tensor3("x")
226226
a = np.random.random((3, 5, 2)).astype(config.floatX)
227227

228-
# Test axis=None using cumsum function (which now handles it symbolically)
228+
# Test default axis=None
229229
self._compile_and_check([x], [cumsum(x)], [a], self.op_class)
230230

231231
for axis in range(-len(a.shape), len(a.shape)):
@@ -234,7 +234,7 @@ def test_infer_shape(self):
234234
def test_grad(self):
235235
a = np.random.random((3, 5, 2)).astype(config.floatX)
236236

237-
# Test axis=None using cumsum/cumprod functions (which now handle it symbolically)
237+
# Test default axis=None using cumsum/cumprod functions
238238
utt.verify_grad(lambda x: cumsum(x), [a]) # Test axis=None for cumsum
239239
utt.verify_grad(lambda x: cumprod(x), [a]) # Test axis=None for cumprod
240240

0 commit comments

Comments
 (0)