Skip to content

Commit 5bb857e

Browse files
committed
Cleanup C-code
1 parent 90e11bb commit 5bb857e

File tree

1 file changed

+31
-47
lines changed

1 file changed

+31
-47
lines changed

pytensor/tensor/extra_ops.py

Lines changed: 31 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import warnings
22
from collections.abc import Collection, Iterable
3+
from textwrap import dedent
34

45
import numpy as np
56

@@ -20,15 +21,14 @@
2021
from pytensor.npy_2_compat import (
2122
normalize_axis_index,
2223
npy_2_compat_header,
23-
numpy_axis_is_none_flag,
2424
old_np_unique,
2525
)
2626
from pytensor.raise_op import Assert
2727
from pytensor.scalar import int64 as int_t
2828
from pytensor.scalar import upcast
2929
from pytensor.tensor import TensorLike, as_tensor_variable
3030
from pytensor.tensor import basic as ptb
31-
from pytensor.tensor.basic import alloc, join, second, flatten
31+
from pytensor.tensor.basic import alloc, join, second
3232
from pytensor.tensor.exceptions import NotScalarConstantError
3333
from pytensor.tensor.math import abs as pt_abs
3434
from pytensor.tensor.math import all as pt_all
@@ -48,7 +48,7 @@
4848
from pytensor.tensor.math import sum as pt_sum
4949
from pytensor.tensor.shape import Shape_i
5050
from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
51-
from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector
51+
from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes
5252
from pytensor.tensor.utils import normalize_reduce_axis
5353
from pytensor.tensor.variable import TensorVariable
5454
from pytensor.utils import LOCAL_BITWIDTH, PYTHON_INT_BITWIDTH
@@ -307,7 +307,6 @@ def __init__(self, axis: int, mode="add"):
307307
self.axis = axis
308308
self.mode = mode
309309

310-
311310
def make_node(self, x):
312311
x = ptb.as_tensor_variable(x)
313312
out_type = x.type()
@@ -325,7 +324,7 @@ def perform(self, node, inputs, output_storage):
325324
else:
326325
z[0] = np.cumprod(x, axis=self.axis)
327326

328-
def grad(self, inputs, output_gradients):
327+
def L_op(self, inputs, outputs, output_gradients):
329328
(x,) = inputs
330329
(gi,) = output_gradients
331330

@@ -357,58 +356,43 @@ def c_code(self, node, name, inames, onames, sub):
357356
fail = sub["fail"]
358357
params = sub["params"]
359358

360-
axis_code = f"int axis = {params}->axis;\n"
361-
362-
code = (
363-
axis_code
364-
+ f"""
365-
#undef NPY_UF_DBG_TRACING
366-
#define NPY_UF_DBG_TRACING 1
367-
368-
if (axis == 0 && PyArray_NDIM({x}) == 1)
369-
axis = NPY_RAVEL_AXIS;
370-
npy_intp shape[1] = {{ PyArray_SIZE({x}) }};
371-
if(axis == NPY_RAVEL_AXIS && !({z} && PyArray_DIMS({z})[0] == shape[0]))
372-
{{
373-
Py_XDECREF({z});
374-
{z} = (PyArrayObject*) PyArray_SimpleNew(1, shape, PyArray_TYPE({x}));
375-
}}
359+
return dedent(
360+
f"""
361+
int axis = {params}->axis;
376362
377-
else if(axis != NPY_RAVEL_AXIS && !({z} && PyArray_CompareLists(PyArray_DIMS({z}), PyArray_DIMS({x}), PyArray_NDIM({x}))))
378-
{{
379-
Py_XDECREF({z});
380-
{z} = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM({x}), PyArray_DIMS({x}), PyArray_TYPE({x}));
381-
}}
363+
if (!({z} && PyArray_CompareLists(PyArray_DIMS({z}), PyArray_DIMS({x}), PyArray_NDIM({x}))))
364+
{{
365+
Py_XDECREF({z});
366+
{z} = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM({x}), PyArray_DIMS({x}), PyArray_TYPE({x}));
367+
if (!{z}){{ {fail} }};
368+
}}
369+
370+
{{
382371
383-
if (!{z})
372+
PyObject * t = NULL;
373+
if({params}->mode == MODE_ADD)
374+
t = PyArray_CumSum({x}, axis, PyArray_TYPE({x}), {z});
375+
else if({params}->mode == MODE_MUL)
376+
t = PyArray_CumProd({x}, axis, PyArray_TYPE({x}), {z});
377+
378+
if (!t){{
384379
{fail};
385-
{{
386-
387-
PyObject * t = NULL;
388-
if({params}->mode == MODE_ADD)
389-
t = PyArray_CumSum(
390-
{x}, axis,
391-
PyArray_TYPE({x}), {z});
392-
else if({params}->mode == MODE_MUL)
393-
t = PyArray_CumProd(
394-
{x}, axis,
395-
PyArray_TYPE({x}), {z});
396-
397-
if (!t){{
398-
{fail};
399-
}}
400-
// Because PyArray_CumSum/CumProd returns a newly created reference on t.
401-
Py_XDECREF(t);
402380
}}
381+
382+
// Because PyArray_CumSum/CumProd returns a newly created reference on t.
383+
Py_XDECREF(t);
384+
}}
403385
"""
404386
)
405387

406-
return code
407-
408388
def c_code_cache_version(self):
409-
return (9,)
389+
return (10,)
410390

411391
def __str__(self):
392+
if self.mode == "add":
393+
return f"Cumsum{{axis={self.axis}}}"
394+
elif self.mode == "mul":
395+
return f"Cumprod{{axis={self.axis}}}"
412396
return f"{self.__class__.__name__}{{{self.axis}, {self.mode}}}"
413397

414398

0 commit comments

Comments
 (0)