11import warnings
22from collections .abc import Collection , Iterable
3+ from textwrap import dedent
34
45import numpy as np
56from numpy .lib .array_utils import normalize_axis_index
4445from pytensor .tensor .math import sum as pt_sum
4546from pytensor .tensor .shape import Shape_i
4647from pytensor .tensor .subtensor import advanced_inc_subtensor1 , set_subtensor
47- from pytensor .tensor .type import TensorType , dvector , int_dtypes , integer_dtypes , vector
48+ from pytensor .tensor .type import TensorType , dvector , int_dtypes , integer_dtypes
4849from pytensor .tensor .utils import normalize_reduce_axis
4950from pytensor .tensor .variable import TensorVariable
50- from pytensor .utils import LOCAL_BITWIDTH , NPY_RAVEL_AXIS , PYTHON_INT_BITWIDTH
51+ from pytensor .utils import LOCAL_BITWIDTH , PYTHON_INT_BITWIDTH
5152
5253
5354class CpuContiguous (COp ):
@@ -290,33 +291,28 @@ class CumOp(COp):
290291 __props__ = ("axis" , "mode" )
291292 check_input = False
292293 params_type = ParamsType (
293- c_axis = int_t , mode = EnumList (("MODE_ADD" , "add" ), ("MODE_MUL" , "mul" ))
294+ axis = int_t , mode = EnumList (("MODE_ADD" , "add" ), ("MODE_MUL" , "mul" ))
294295 )
295296
296- def __init__ (self , axis : int | None = None , mode = "add" ):
297+ def __init__ (self , axis : int , mode = "add" ):
297298 if mode not in ("add" , "mul" ):
298299 raise ValueError (f'{ type (self ).__name__ } : Unknown mode "{ mode } "' )
299- if not (isinstance (axis , int ) or axis is None ):
300- raise TypeError ("axis must be an integer or None." )
300+ if not isinstance (axis , int ):
301+ raise TypeError (f"axis must be an integer, got { axis } of type { type (axis )} " )
302+ if axis < 0 :
303+ raise ValueError (f"axis must be non-negative, got { axis } " )
301304 self .axis = axis
302305 self .mode = mode
303306
304- @property
305- def c_axis (self ) -> int :
306- if self .axis is None :
307- return NPY_RAVEL_AXIS
308- return self .axis
309-
310307 def make_node (self , x ):
311308 x = ptb .as_tensor_variable (x )
312- out_type = x .type ()
313309
314- if self .axis is None :
315- out_type = vector ( dtype = x . dtype ) # Flatten
316- elif self . axis >= x . ndim or self .axis < - x . ndim :
317- raise ValueError ( f"axis(= { self . axis } ) out of bounds" )
310+ if self .axis >= x . type . ndim :
311+ raise ValueError (
312+ f"axis(= { self .axis } ) out of bounds for variable { x } with { x . type . ndim } ndims"
313+ )
318314
319- return Apply (self , [x ], [out_type ])
315+ return Apply (self , [x ], [x . type () ])
320316
321317 def perform (self , node , inputs , output_storage ):
322318 x = inputs [0 ]
@@ -326,21 +322,10 @@ def perform(self, node, inputs, output_storage):
326322 else :
327323 z [0 ] = np .cumprod (x , axis = self .axis )
328324
329- def grad (self , inputs , output_gradients ):
325+ def L_op (self , inputs , outputs , output_gradients ):
330326 (x ,) = inputs
331327 (gi ,) = output_gradients
332328
333- if self .axis is None :
334- if self .mode == "add" :
335- return [cumsum (gi [::- 1 ])[::- 1 ].reshape (x .shape )]
336- elif self .mode == "mul" :
337- fx = cumprod (x , axis = self .axis )
338- return [cumsum ((fx * gi )[::- 1 ])[::- 1 ].reshape (x .shape ) / x ]
339- else :
340- raise NotImplementedError (
341- f'{ type (self ).__name__ } : unknown gradient for mode "{ self .mode } "'
342- )
343-
344329 reverse_slicing = [slice (None , None , None )] * gi .ndim
345330 reverse_slicing [self .axis ] = slice (None , None , - 1 )
346331 reverse_slicing = tuple (reverse_slicing )
@@ -357,9 +342,6 @@ def grad(self, inputs, output_gradients):
357342 )
358343
359344 def infer_shape (self , fgraph , node , shapes ):
360- if self .axis is None and len (shapes [0 ]) > 1 :
361- return [(prod (shapes [0 ]),)] # Flatten
362-
363345 return shapes
364346
365347 def c_code (self , node , name , inames , onames , sub ):
@@ -368,61 +350,43 @@ def c_code(self, node, name, inames, onames, sub):
368350 fail = sub ["fail" ]
369351 params = sub ["params" ]
370352
371- if self .axis is None :
372- axis_code = "int axis = NPY_RAVEL_AXIS;\n "
373- else :
374- axis_code = f"int axis = { params } ->c_axis;\n "
375-
376- code = (
377- axis_code
378- + f"""
379- #undef NPY_UF_DBG_TRACING
380- #define NPY_UF_DBG_TRACING 1
381-
382- if (axis == 0 && PyArray_NDIM({ x } ) == 1)
383- axis = NPY_RAVEL_AXIS;
384- npy_intp shape[1] = {{ PyArray_SIZE({ x } ) }};
385- if(axis == NPY_RAVEL_AXIS && !({ z } && PyArray_DIMS({ z } )[0] == shape[0]))
386- {{
387- Py_XDECREF({ z } );
388- { z } = (PyArrayObject*) PyArray_SimpleNew(1, shape, PyArray_TYPE({ x } ));
389- }}
353+ return dedent (
354+ f"""
355+ int axis = { params } ->axis;
390356
391- else if(axis != NPY_RAVEL_AXIS && !({ z } && PyArray_CompareLists(PyArray_DIMS({ z } ), PyArray_DIMS({ x } ), PyArray_NDIM({ x } ))))
392- {{
393- Py_XDECREF({ z } );
394- { z } = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM({ x } ), PyArray_DIMS({ x } ), PyArray_TYPE({ x } ));
395- }}
357+ if (!({ z } && PyArray_CompareLists(PyArray_DIMS({ z } ), PyArray_DIMS({ x } ), PyArray_NDIM({ x } ))))
358+ {{
359+ Py_XDECREF({ z } );
360+ { z } = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM({ x } ), PyArray_DIMS({ x } ), PyArray_TYPE({ x } ));
361+ if (!{ z } ){{ { fail } }};
362+ }}
363+
364+ {{
396365
397- if (!{ z } )
366+ PyObject * t = NULL;
367+ if({ params } ->mode == MODE_ADD)
368+ t = PyArray_CumSum({ x } , axis, PyArray_TYPE({ x } ), { z } );
369+ else if({ params } ->mode == MODE_MUL)
370+ t = PyArray_CumProd({ x } , axis, PyArray_TYPE({ x } ), { z } );
371+
372+ if (!t){{
398373 { fail } ;
399- {{
400-
401- PyObject * t = NULL;
402- if({ params } ->mode == MODE_ADD)
403- t = PyArray_CumSum(
404- { x } , axis,
405- PyArray_TYPE({ x } ), { z } );
406- else if({ params } ->mode == MODE_MUL)
407- t = PyArray_CumProd(
408- { x } , axis,
409- PyArray_TYPE({ x } ), { z } );
410-
411- if (!t){{
412- { fail } ;
413- }}
414- // Because PyArray_CumSum/CumProd returns a newly created reference on t.
415- Py_XDECREF(t);
416374 }}
375+
376+ // Because PyArray_CumSum/CumProd returns a newly created reference on t.
377+ Py_XDECREF(t);
378+ }}
417379 """
418380 )
419381
420- return code
421-
422382 def c_code_cache_version (self ):
423- return (10 ,)
383+ return (11 ,)
424384
425385 def __str__ (self ):
386+ if self .mode == "add" :
387+ return f"Cumsum{{axis={ self .axis } }}"
388+ elif self .mode == "mul" :
389+ return f"Cumprod{{axis={ self .axis } }}"
426390 return f"{ self .__class__ .__name__ } {{{ self .axis } , { self .mode } }}"
427391
428392
@@ -443,6 +407,12 @@ def cumsum(x, axis=None):
443407 .. versionadded:: 0.7
444408
445409 """
410+ x = ptb .as_tensor_variable (x )
411+ if axis is None :
412+ x = x .ravel ()
413+ axis = 0
414+ else :
415+ axis = normalize_axis_index (axis , x .ndim )
446416 return CumOp (axis = axis , mode = "add" )(x )
447417
448418
@@ -463,6 +433,12 @@ def cumprod(x, axis=None):
463433 .. versionadded:: 0.7
464434
465435 """
436+ x = ptb .as_tensor_variable (x )
437+ if axis is None :
438+ x = x .ravel ()
439+ axis = 0
440+ else :
441+ axis = normalize_axis_index (axis , x .ndim )
466442 return CumOp (axis = axis , mode = "mul" )(x )
467443
468444
@@ -471,18 +447,8 @@ def vectorize_cum_op(op: CumOp, node: Apply, batch_x):
471447 """Vectorize the CumOp to work on a batch of inputs."""
472448 [original_x ] = node .inputs
473449 batch_ndim = batch_x .ndim - original_x .ndim
474- axis = op .axis
475- if axis is None and original_x .ndim == 1 :
476- axis = 0
477- elif axis is not None :
478- axis = normalize_axis_index (op .axis , original_x .ndim )
479-
480- if axis is None :
481- # Ravel all unbatched dimensions and perform CumOp on the last axis
482- batch_x_raveled = [batch_x .flatten (ndim = batch_ndim + 1 ) for x in batch_x ]
483- return type (op )(axis = - 1 , mode = op .mode ).make_node (batch_x_raveled )
484- else :
485- return type (op )(axis = axis + batch_ndim , mode = op .mode ).make_node (batch_x )
450+ # op.axis is already normalized and non-negative
451+ return type (op )(axis = op .axis + batch_ndim , mode = op .mode ).make_node (batch_x )
486452
487453
488454def diff (x , n = 1 , axis = - 1 ):
0 commit comments