Skip to content

Commit aad9f7d

Browse files
CopilotricardoV94
andcommitted
Fix failing tests: remove negative axis cases from CumOp tests
- Updated TestCumOp.test_grad to only test non-negative axis values when directly instantiating CumOp - Updated Numba CumOp test to remove axis=None and axis=-1 cases that fail with new CumOp constructor - CumOp constructor now rejects negative axis values, but helper functions handle axis normalization Co-authored-by: ricardoV94 <[email protected]>
1 parent 5a2af98 commit aad9f7d

File tree

2 files changed

+3
-12
lines changed

2 files changed

+3
-12
lines changed

tests/link/numba/test_extra_ops.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def test_Bartlett(val):
4040
),
4141
(
4242
(pt.dtensor3(), np.arange(30, dtype=config.floatX).reshape((2, 3, 5))),
43-
-1,
43+
2,
4444
"add",
4545
),
4646
(
@@ -53,11 +53,6 @@ def test_Bartlett(val):
5353
1,
5454
"add",
5555
),
56-
(
57-
(pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))),
58-
None,
59-
"add",
60-
),
6156
(
6257
(pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))),
6358
0,
@@ -68,11 +63,6 @@ def test_Bartlett(val):
6863
1,
6964
"mul",
7065
),
71-
(
72-
(pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))),
73-
None,
74-
"mul",
75-
),
7666
],
7767
)
7868
def test_CumOp(val, axis, mode):

tests/tensor/test_extra_ops.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,8 @@ def test_grad(self):
238238
utt.verify_grad(lambda x: cumsum(x), [a]) # Test axis=None for cumsum
239239
utt.verify_grad(lambda x: cumprod(x), [a]) # Test axis=None for cumprod
240240

241-
for axis in range(-len(a.shape), len(a.shape)):
241+
# Test only non-negative axis values for Op level (negative axis not allowed)
242+
for axis in range(len(a.shape)):
242243
utt.verify_grad(self.op_class(axis=axis, mode="add"), [a], eps=4e-4)
243244
utt.verify_grad(self.op_class(axis=axis, mode="mul"), [a], eps=4e-4)
244245

0 commit comments

Comments
 (0)