Skip to content

Commit 6c3e757

Browse files
committed
Fix numba implementation of CumOp when axis is None
1 parent 9a5deee commit 6c3e757

File tree

4 files changed

+31
-13
lines changed

4 files changed

+31
-13
lines changed

pytensor/link/numba/dispatch/extra_ops.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
import warnings
2+
from typing import cast
23

34
import numba
45
import numpy as np
56

67
from pytensor import config
8+
from pytensor.graph import Apply
79
from pytensor.link.numba.dispatch import basic as numba_basic
810
from pytensor.link.numba.dispatch.basic import get_numba_type, numba_funcify
911
from pytensor.raise_op import CheckAndRaise
12+
from pytensor.tensor import TensorVariable
1013
from pytensor.tensor.extra_ops import (
1114
Bartlett,
1215
CumOp,
@@ -30,21 +33,22 @@ def bartlett(x):
3033

3134

3235
@numba_funcify.register(CumOp)
33-
def numba_funcify_CumOp(op, node, **kwargs):
36+
def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs):
3437
axis = op.axis
3538
mode = op.mode
36-
ndim = node.outputs[0].ndim
39+
ndim = cast(TensorVariable, node.outputs[0]).ndim
3740

38-
if axis < 0:
39-
axis = ndim + axis
40-
if axis < 0 or axis >= ndim:
41-
raise ValueError(f"Invalid axis {axis} for array with ndim {ndim}")
41+
if axis is not None:
42+
if axis < 0:
43+
axis = ndim + axis
44+
if axis < 0 or axis >= ndim:
45+
raise ValueError(f"Invalid axis {axis} for array with ndim {ndim}")
4246

43-
reaxis_first = (axis,) + tuple(i for i in range(ndim) if i != axis)
44-
reaxis_first_inv = tuple(np.argsort(reaxis_first))
47+
reaxis_first = (axis,) + tuple(i for i in range(ndim) if i != axis)
48+
reaxis_first_inv = tuple(np.argsort(reaxis_first))
4549

4650
if mode == "add":
47-
if ndim == 1:
51+
if axis is None or ndim == 1:
4852

4953
@numba_basic.numba_njit(fastmath=config.numba__fastmath)
5054
def cumop(x):
@@ -68,7 +72,7 @@ def cumop(x):
6872
return res.transpose(reaxis_first_inv)
6973

7074
else:
71-
if ndim == 1:
75+
if axis is None or ndim == 1:
7276

7377
@numba_basic.numba_njit(fastmath=config.numba__fastmath)
7478
def cumop(x):

pytensor/tensor/extra_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections.abc import Collection
2-
from typing import Iterable, Set, Tuple, Union
2+
from typing import Iterable, Optional, Set, Tuple, Union
33

44
import numpy as np
55
from numpy.core.multiarray import normalize_axis_index
@@ -291,7 +291,7 @@ class CumOp(COp):
291291
c_axis=int_t, mode=EnumList(("MODE_ADD", "add"), ("MODE_MUL", "mul"))
292292
)
293293

294-
def __init__(self, axis=None, mode="add"):
294+
def __init__(self, axis: Optional[int] = None, mode="add"):
295295
if mode not in ("add", "mul"):
296296
raise ValueError(f'{type(self).__name__}: Unknown mode "{mode}"')
297297
self.axis = axis

pytensor/tensor/var.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -619,7 +619,7 @@ def __iter__(self):
619619
)
620620

621621
@property
622-
def ndim(self):
622+
def ndim(self) -> int:
623623
"""The rank of this tensor."""
624624
return self.type.ndim
625625

tests/link/numba/test_extra_ops.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,13 @@ def test_Bartlett(val):
6767
1,
6868
"add",
6969
),
70+
(
71+
set_test_value(
72+
at.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))
73+
),
74+
None,
75+
"add",
76+
),
7077
(
7178
set_test_value(
7279
at.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))
@@ -81,6 +88,13 @@ def test_Bartlett(val):
8188
1,
8289
"mul",
8390
),
91+
(
92+
set_test_value(
93+
at.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))
94+
),
95+
None,
96+
"mul",
97+
),
8498
],
8599
)
86100
def test_CumOp(val, axis, mode):

0 commit comments

Comments
 (0)