@@ -302,6 +302,8 @@ def __init__(self, axis: int, mode="add"):
302
302
raise ValueError (f'{ type (self ).__name__ } : Unknown mode "{ mode } "' )
303
303
if not isinstance (axis , int ):
304
304
raise TypeError ("axis must be an integer." )
305
+ if axis < 0 :
306
+ raise ValueError ("axis must be non-negative." )
305
307
self .axis = axis
306
308
self .mode = mode
307
309
@@ -313,7 +315,7 @@ def make_node(self, x):
313
315
x = ptb .as_tensor_variable (x )
314
316
out_type = x .type ()
315
317
316
- if self .axis >= x .ndim or self . axis < - x . ndim :
318
+ if self .axis >= x .ndim :
317
319
raise ValueError (f"axis(={ self .axis } ) out of bounds" )
318
320
319
321
return Apply (self , [x ], [out_type ])
@@ -431,10 +433,10 @@ def cumsum(x, axis=None):
431
433
432
434
"""
433
435
if axis is None :
434
- # Handle raveling symbolically by flattening first, then applying cumsum with axis=0
435
- x_flattened = flatten (x , ndim = 1 ) # This creates a 1D tensor
436
- return CumOp (axis = 0 , mode = "add" )(x_flattened )
436
+ return CumOp (axis = 0 , mode = "add" )(ptb .as_tensor_variable (x ).ravel ())
437
437
else :
438
+ x = ptb .as_tensor_variable (x )
439
+ axis = normalize_axis_index (axis , x .ndim )
438
440
return CumOp (axis = axis , mode = "add" )(x )
439
441
440
442
@@ -456,10 +458,10 @@ def cumprod(x, axis=None):
456
458
457
459
"""
458
460
if axis is None :
459
- # Handle raveling symbolically by flattening first, then applying cumprod with axis=0
460
- x_flattened = flatten (x , ndim = 1 ) # This creates a 1D tensor
461
- return CumOp (axis = 0 , mode = "mul" )(x_flattened )
461
+ return CumOp (axis = 0 , mode = "mul" )(ptb .as_tensor_variable (x ).ravel ())
462
462
else :
463
+ x = ptb .as_tensor_variable (x )
464
+ axis = normalize_axis_index (axis , x .ndim )
463
465
return CumOp (axis = axis , mode = "mul" )(x )
464
466
465
467
@@ -468,8 +470,8 @@ def vectorize_cum_op(op: CumOp, node: Apply, batch_x):
468
470
"""Vectorize the CumOp to work on a batch of inputs."""
469
471
[original_x ] = node .inputs
470
472
batch_ndim = batch_x .ndim - original_x .ndim
471
- axis = normalize_axis_index ( op .axis , original_x . ndim )
472
- return type (op )(axis = axis + batch_ndim , mode = op .mode ).make_node (batch_x )
473
+ # op.axis is already normalized and non-negative
474
+ return type (op )(axis = op . axis + batch_ndim , mode = op .mode ).make_node (batch_x )
473
475
474
476
475
477
def diff (x , n = 1 , axis = - 1 ):
0 commit comments