1
1
import warnings
2
2
from collections .abc import Collection , Iterable
3
+ from textwrap import dedent
3
4
4
5
import numpy as np
5
6
20
21
from pytensor .npy_2_compat import (
21
22
normalize_axis_index ,
22
23
npy_2_compat_header ,
23
- numpy_axis_is_none_flag ,
24
24
old_np_unique ,
25
25
)
26
26
from pytensor .raise_op import Assert
27
27
from pytensor .scalar import int64 as int_t
28
28
from pytensor .scalar import upcast
29
29
from pytensor .tensor import TensorLike , as_tensor_variable
30
30
from pytensor .tensor import basic as ptb
31
- from pytensor .tensor .basic import alloc , join , second , flatten
31
+ from pytensor .tensor .basic import alloc , join , second
32
32
from pytensor .tensor .exceptions import NotScalarConstantError
33
33
from pytensor .tensor .math import abs as pt_abs
34
34
from pytensor .tensor .math import all as pt_all
48
48
from pytensor .tensor .math import sum as pt_sum
49
49
from pytensor .tensor .shape import Shape_i
50
50
from pytensor .tensor .subtensor import advanced_inc_subtensor1 , set_subtensor
51
- from pytensor .tensor .type import TensorType , dvector , int_dtypes , integer_dtypes , vector
51
+ from pytensor .tensor .type import TensorType , dvector , int_dtypes , integer_dtypes
52
52
from pytensor .tensor .utils import normalize_reduce_axis
53
53
from pytensor .tensor .variable import TensorVariable
54
54
from pytensor .utils import LOCAL_BITWIDTH , PYTHON_INT_BITWIDTH
@@ -307,7 +307,6 @@ def __init__(self, axis: int, mode="add"):
307
307
self .axis = axis
308
308
self .mode = mode
309
309
310
-
311
310
def make_node (self , x ):
312
311
x = ptb .as_tensor_variable (x )
313
312
out_type = x .type ()
@@ -325,7 +324,7 @@ def perform(self, node, inputs, output_storage):
325
324
else :
326
325
z [0 ] = np .cumprod (x , axis = self .axis )
327
326
328
- def grad (self , inputs , output_gradients ):
327
+ def L_op (self , inputs , outputs , output_gradients ):
329
328
(x ,) = inputs
330
329
(gi ,) = output_gradients
331
330
@@ -357,58 +356,43 @@ def c_code(self, node, name, inames, onames, sub):
357
356
fail = sub ["fail" ]
358
357
params = sub ["params" ]
359
358
360
- axis_code = f"int axis = { params } ->axis;\n "
361
-
362
- code = (
363
- axis_code
364
- + f"""
365
- #undef NPY_UF_DBG_TRACING
366
- #define NPY_UF_DBG_TRACING 1
367
-
368
- if (axis == 0 && PyArray_NDIM({ x } ) == 1)
369
- axis = NPY_RAVEL_AXIS;
370
- npy_intp shape[1] = {{ PyArray_SIZE({ x } ) }};
371
- if(axis == NPY_RAVEL_AXIS && !({ z } && PyArray_DIMS({ z } )[0] == shape[0]))
372
- {{
373
- Py_XDECREF({ z } );
374
- { z } = (PyArrayObject*) PyArray_SimpleNew(1, shape, PyArray_TYPE({ x } ));
375
- }}
359
+ return dedent (
360
+ f"""
361
+ int axis = { params } ->axis;
376
362
377
- else if(axis != NPY_RAVEL_AXIS && !({ z } && PyArray_CompareLists(PyArray_DIMS({ z } ), PyArray_DIMS({ x } ), PyArray_NDIM({ x } ))))
378
- {{
379
- Py_XDECREF({ z } );
380
- { z } = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM({ x } ), PyArray_DIMS({ x } ), PyArray_TYPE({ x } ));
381
- }}
363
+ if (!({ z } && PyArray_CompareLists(PyArray_DIMS({ z } ), PyArray_DIMS({ x } ), PyArray_NDIM({ x } ))))
364
+ {{
365
+ Py_XDECREF({ z } );
366
+ { z } = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM({ x } ), PyArray_DIMS({ x } ), PyArray_TYPE({ x } ));
367
+ if (!{ z } ){{ { fail } }};
368
+ }}
369
+
370
+ {{
382
371
383
- if (!{ z } )
372
+ PyObject * t = NULL;
373
+ if({ params } ->mode == MODE_ADD)
374
+ t = PyArray_CumSum({ x } , axis, PyArray_TYPE({ x } ), { z } );
375
+ else if({ params } ->mode == MODE_MUL)
376
+ t = PyArray_CumProd({ x } , axis, PyArray_TYPE({ x } ), { z } );
377
+
378
+ if (!t){{
384
379
{ fail } ;
385
- {{
386
-
387
- PyObject * t = NULL;
388
- if({ params } ->mode == MODE_ADD)
389
- t = PyArray_CumSum(
390
- { x } , axis,
391
- PyArray_TYPE({ x } ), { z } );
392
- else if({ params } ->mode == MODE_MUL)
393
- t = PyArray_CumProd(
394
- { x } , axis,
395
- PyArray_TYPE({ x } ), { z } );
396
-
397
- if (!t){{
398
- { fail } ;
399
- }}
400
- // Because PyArray_CumSum/CumProd returns a newly created reference on t.
401
- Py_XDECREF(t);
402
380
}}
381
+
382
+ // Because PyArray_CumSum/CumProd returns a newly created reference on t.
383
+ Py_XDECREF(t);
384
+ }}
403
385
"""
404
386
)
405
387
406
- return code
407
-
408
388
def c_code_cache_version (self ):
409
- return (9 ,)
389
+ return (10 ,)
410
390
411
391
def __str__ (self ):
392
+ if self .mode == "add" :
393
+ return f"Cumsum{{axis={ self .axis } }}"
394
+ elif self .mode == "mul" :
395
+ return f"Cumprod{{axis={ self .axis } }}"
412
396
return f"{ self .__class__ .__name__ } {{{ self .axis } , { self .mode } }}"
413
397
414
398
0 commit comments