from pytensor.scalar import upcast
from pytensor.tensor import TensorLike, as_tensor_variable
from pytensor.tensor import basic as ptb
-from pytensor.tensor.basic import alloc, join, second
+from pytensor.tensor.basic import alloc, join, second, flatten
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import abs as pt_abs
from pytensor.tensor.math import all as pt_all
@@ -297,27 +297,23 @@ class CumOp(COp):
        c_axis=int_t, mode=EnumList(("MODE_ADD", "add"), ("MODE_MUL", "mul"))
    )

-    def __init__(self, axis: int | None = None, mode="add"):
+    def __init__(self, axis: int, mode="add"):
        if mode not in ("add", "mul"):
            raise ValueError(f'{type(self).__name__}: Unknown mode "{mode}"')
-        if not (isinstance(axis, int) or axis is None):
-            raise TypeError("axis must be an integer or None.")
+        if not isinstance(axis, int):
+            raise TypeError("axis must be an integer.")
        self.axis = axis
        self.mode = mode

    @property
    def c_axis(self) -> int:
-        if self.axis is None:
-            return numpy_axis_is_none_flag
        return self.axis

    def make_node(self, x):
        x = ptb.as_tensor_variable(x)
        out_type = x.type()

-        if self.axis is None:
-            out_type = vector(dtype=x.dtype)  # Flatten
-        elif self.axis >= x.ndim or self.axis < -x.ndim:
+        if self.axis >= x.ndim or self.axis < -x.ndim:
            raise ValueError(f"axis(={self.axis}) out of bounds")

        return Apply(self, [x], [out_type])
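
With this change, `CumOp` only accepts a concrete integer axis; `axis=None` is handled entirely by the `cumsum`/`cumprod` helpers further down. A minimal sketch of the stricter contract (assuming `CumOp` is importable from `pytensor.tensor.extra_ops`):

from pytensor.tensor.extra_ops import CumOp

CumOp(axis=0, mode="add")        # still fine
try:
    CumOp(axis=None, mode="add")  # now rejected at construction time
except TypeError as err:
    print(err)                    # "axis must be an integer."
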
@@ -334,17 +330,6 @@ def grad(self, inputs, output_gradients):
        (x,) = inputs
        (gi,) = output_gradients

-        if self.axis is None:
-            if self.mode == "add":
-                return [cumsum(gi[::-1])[::-1].reshape(x.shape)]
-            elif self.mode == "mul":
-                fx = cumprod(x, axis=self.axis)
-                return [cumsum((fx * gi)[::-1])[::-1].reshape(x.shape) / x]
-            else:
-                raise NotImplementedError(
-                    f'{type(self).__name__}: unknown gradient for mode "{self.mode}"'
-                )
-
        reverse_slicing = [slice(None, None, None)] * gi.ndim
        reverse_slicing[self.axis] = slice(None, None, -1)
        reverse_slicing = tuple(reverse_slicing)
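
The deleted branch covered only the flattened (`axis=None`) case; the surviving path relies on the adjoint identity for cumulative sums. Since the Jacobian of `cumsum` is a lower-triangular matrix of ones, the vector-Jacobian product is a cumulative sum taken in reverse order, which is what the reversed slicing implements. A quick NumPy check of that identity (illustrative only):

import numpy as np

rng = np.random.default_rng(0)
g = rng.normal(size=5)  # upstream gradient

# Jacobian of cumsum: J[i, j] = 1 for j <= i (lower-triangular ones),
# so the VJP J.T @ g equals a reversed cumulative sum of g.
J = np.tril(np.ones((5, 5)))
assert np.allclose(J.T @ g, np.cumsum(g[::-1])[::-1])
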
@@ -361,9 +346,6 @@ def grad(self, inputs, output_gradients):
        )

    def infer_shape(self, fgraph, node, shapes):
-        if self.axis is None and len(shapes[0]) > 1:
-            return [(prod(shapes[0]),)]  # Flatten
-
        return shapes

    def c_support_code_apply(self, node: Apply, name: str) -> str:
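
With `axis` guaranteed to be an integer, a cumulative op always preserves the input shape, so `infer_shape` can return `shapes` unchanged; the dropped branch handled the only shape-changing case. A NumPy illustration of the two cases:

import numpy as np

a = np.ones((2, 3))
assert np.cumsum(a, axis=1).shape == a.shape  # integer axis: shape preserved
assert np.cumsum(a, axis=None).shape == (6,)  # the case CumOp no longer sees
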
@@ -376,10 +358,7 @@ def c_code(self, node, name, inames, onames, sub):
        fail = sub["fail"]
        params = sub["params"]

-        if self.axis is None:
-            axis_code = "int axis = NPY_RAVEL_AXIS;\n"
-        else:
-            axis_code = f"int axis = {params}->c_axis;\n"
+        axis_code = f"int axis = {params}->c_axis;\n"

        code = (
            axis_code
@@ -451,7 +430,12 @@ def cumsum(x, axis=None):
    .. versionadded:: 0.7

    """
-    return CumOp(axis=axis, mode="add")(x)
+    if axis is None:
+        # Handle raveling symbolically by flattening first, then applying cumsum with axis=0
+        x_flattened = flatten(x, ndim=1)  # This creates a 1D tensor
+        return CumOp(axis=0, mode="add")(x_flattened)
+    else:
+        return CumOp(axis=axis, mode="add")(x)


def cumprod(x, axis=None):
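
End-to-end, `cumsum(x, axis=None)` should still match NumPy's ravel-then-accumulate semantics, now via an explicit `flatten` node instead of a special-cased Op. A sanity check along those lines (a sketch, assuming the public `pytensor.tensor.cumsum` entry point):

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
f = pytensor.function([x], pt.cumsum(x, axis=None))

a = np.arange(6.0).reshape(2, 3)
# axis=None flattens first, matching np.cumsum on the raveled array
np.testing.assert_allclose(f(a), np.cumsum(a))
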
@@ -471,26 +455,21 @@ def cumprod(x, axis=None):
    .. versionadded:: 0.7

    """
-    return CumOp(axis=axis, mode="mul")(x)
+    if axis is None:
+        # Handle raveling symbolically by flattening first, then applying cumprod with axis=0
+        x_flattened = flatten(x, ndim=1)  # This creates a 1D tensor
+        return CumOp(axis=0, mode="mul")(x_flattened)
+    else:
+        return CumOp(axis=axis, mode="mul")(x)


@_vectorize_node.register(CumOp)
def vectorize_cum_op(op: CumOp, node: Apply, batch_x):
    """Vectorize the CumOp to work on a batch of inputs."""
    [original_x] = node.inputs
    batch_ndim = batch_x.ndim - original_x.ndim
-    axis = op.axis
-    if axis is None and original_x.ndim == 1:
-        axis = 0
-    elif axis is not None:
-        axis = normalize_axis_index(op.axis, original_x.ndim)
-
-    if axis is None:
-        # Ravel all unbatched dimensions and perform CumOp on the last axis
-        batch_x_raveled = [batch_x.flatten(ndim=batch_ndim + 1) for x in batch_x]
-        return type(op)(axis=-1, mode=op.mode).make_node(batch_x_raveled)
-    else:
-        return type(op)(axis=axis + batch_ndim, mode=op.mode).make_node(batch_x)
+    axis = normalize_axis_index(op.axis, original_x.ndim)
+    return type(op)(axis=axis + batch_ndim, mode=op.mode).make_node(batch_x)


def diff(x, n=1, axis=-1):
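
The simplified vectorization rule just normalizes the core axis and shifts it past the batch dimensions; the relationship it encodes can be sanity-checked at the NumPy level:

import numpy as np

batch = np.arange(24.0).reshape(2, 3, 4)  # batch_ndim = 1, core ndim = 2
# CumOp(axis=0) vectorized over one batch dimension becomes axis = 0 + 1
expected = np.stack([np.cumsum(b, axis=0) for b in batch])
np.testing.assert_allclose(np.cumsum(batch, axis=1), expected)
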