Skip to content

Commit b0057c4

Browse files
committed
Numba sparse: Remove codebase xfails
1 parent 89f0fa7 commit b0057c4

File tree

3 files changed

+14
-16
lines changed

3 files changed

+14
-16
lines changed

pytensor/link/numba/dispatch/typed_list.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
register_funcify_default_op_cache_key,
1010
)
1111
from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy
12+
from pytensor.link.numba.dispatch.sparse.basic import CSCMatrixType, CSRMatrixType
1213
from pytensor.tensor.type_other import SliceType
1314
from pytensor.typed_list import (
1415
Append,
@@ -64,6 +65,18 @@ def all_equal(x, y):
6465
def all_equal(x, y):
6566
return x == y
6667

68+
if (isinstance(x, CSRMatrixType) and isinstance(y, CSRMatrixType)) or (
69+
isinstance(x, CSCMatrixType) and isinstance(y, CSCMatrixType)
70+
):
71+
72+
def all_equal(x, y):
73+
return (
74+
x.shape == y.shape
75+
and (x.data == y.data).all()
76+
and (x.indptr == y.indptr).all()
77+
and (x.indices == y.indices).all()
78+
)
79+
6780
return all_equal
6881

6982

tests/compile/function/test_pfunc.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,13 @@
33
import scipy as sp
44

55
import pytensor.tensor as pt
6-
from pytensor.compile import UnusedInputError, get_default_mode, get_mode
6+
from pytensor.compile import UnusedInputError, get_mode
77
from pytensor.compile.function import function, pfunc
88
from pytensor.compile.function.pfunc import rebuild_collect_shared
99
from pytensor.compile.io import In
1010
from pytensor.compile.sharedvalue import shared
1111
from pytensor.configdefaults import config
1212
from pytensor.graph.utils import MissingInputError
13-
from pytensor.link.numba import NumbaLinker
1413
from pytensor.sparse import SparseTensorType
1514
from pytensor.tensor.math import sum as pt_sum
1615
from pytensor.tensor.type import (
@@ -766,10 +765,6 @@ def test_shared_constructor_copies(self):
766765
# rule #2 reading back from pytensor-managed memory
767766
assert not np.may_share_memory(A.get_value(borrow=False), data_of(A))
768767

769-
@pytest.mark.xfail(
770-
condition=isinstance(get_default_mode().linker, NumbaLinker),
771-
reason="Numba does not support Sparse Ops yet",
772-
)
773768
def test_sparse_input_aliasing_affecting_inplace_operations(self):
774769
# Note: to trigger this bug with pytensor rev 4586:2bc6fc7f218b,
775770
# you need to make in inputs mutable (so that inplace

tests/typed_list/test_basic.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
import pytensor
88
import pytensor.typed_list
99
from pytensor import sparse
10-
from pytensor.compile import get_default_mode
11-
from pytensor.link.numba import NumbaLinker
1210
from pytensor.tensor.type import (
1311
TensorType,
1412
integer_dtypes,
@@ -454,10 +452,6 @@ def test_non_tensor_type(self):
454452

455453
assert f([[x, y], [x, y, y]], [x, y]) == 0
456454

457-
@pytest.mark.xfail(
458-
condition=isinstance(get_default_mode().linker, NumbaLinker),
459-
reason="Numba does not support Sparse Ops yet",
460-
)
461455
def test_sparse(self):
462456
mySymbolicSparseList = TypedListType(
463457
sparse.SparseTensorType("csr", pytensor.config.floatX)
@@ -525,10 +519,6 @@ def test_non_tensor_type(self):
525519

526520
assert f([[x, y], [x, y, y]], [x, y]) == 1
527521

528-
@pytest.mark.xfail(
529-
condition=isinstance(get_default_mode().linker, NumbaLinker),
530-
reason="Numba does not support Sparse Ops yet",
531-
)
532522
def test_sparse(self):
533523
mySymbolicSparseList = TypedListType(
534524
sparse.SparseTensorType("csr", pytensor.config.floatX)

0 commit comments

Comments
 (0)