Skip to content

Commit e9c1320

Browse files
committed
Numba sparse: Remove codebase xfails
1 parent da224b3 commit e9c1320

File tree

6 files changed

+63
-54
lines changed

6 files changed

+63
-54
lines changed

tests/compile/function/test_pfunc.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,13 @@
33
import scipy as sp
44

55
import pytensor.tensor as pt
6-
from pytensor.compile import UnusedInputError, get_default_mode, get_mode
6+
from pytensor.compile import UnusedInputError, get_mode
77
from pytensor.compile.function import function, pfunc
88
from pytensor.compile.function.pfunc import rebuild_collect_shared
99
from pytensor.compile.io import In
1010
from pytensor.compile.sharedvalue import shared
1111
from pytensor.configdefaults import config
1212
from pytensor.graph.utils import MissingInputError
13-
from pytensor.link.numba import NumbaLinker
1413
from pytensor.sparse import SparseTensorType
1514
from pytensor.tensor.math import sum as pt_sum
1615
from pytensor.tensor.type import (
@@ -766,10 +765,6 @@ def test_shared_constructor_copies(self):
766765
# rule #2 reading back from pytensor-managed memory
767766
assert not np.may_share_memory(A.get_value(borrow=False), data_of(A))
768767

769-
@pytest.mark.xfail(
770-
condition=isinstance(get_default_mode().linker, NumbaLinker),
771-
reason="Numba does not support Sparse Ops yet",
772-
)
773768
def test_sparse_input_aliasing_affecting_inplace_operations(self):
774769
# Note: to trigger this bug with pytensor rev 4586:2bc6fc7f218b,
775770
# you need to make the inputs mutable (so that inplace

tests/sparse/__init__.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +0,0 @@
1-
import pytest
2-
3-
from pytensor.compile import get_default_mode
4-
from pytensor.link.numba import NumbaLinker
5-
6-
7-
if isinstance(get_default_mode().linker, NumbaLinker):
8-
pytest.skip(
9-
reason="Numba does not support Sparse Ops yet",
10-
allow_module_level=True,
11-
)

tests/sparse/test_math.py

Lines changed: 44 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
import pytensor
99
import pytensor.sparse.math as psm
1010
import pytensor.tensor as pt
11+
from pytensor.compile import get_default_mode
1112
from pytensor.configdefaults import config
13+
from pytensor.link.numba import NumbaLinker
1214
from pytensor.scalar import upcast
1315
from pytensor.sparse.basic import (
1416
CSR,
@@ -427,33 +429,54 @@ def test_opt_unpack(self):
427429
)
428430
f(kernvals, imvals)
429431

430-
def test_dot_sparse_sparse(self):
432+
@pytest.mark.parametrize(
433+
"sparse_format_a",
434+
(
435+
"csc",
436+
"csr",
437+
pytest.param(
438+
"bsr",
439+
marks=pytest.mark.xfail(
440+
isinstance(get_default_mode().linker, NumbaLinker),
441+
reason="Numba does not support bsr",
442+
),
443+
),
444+
),
445+
)
446+
@pytest.mark.parametrize(
447+
"sparse_format_b",
448+
(
449+
"csc",
450+
"csr",
451+
pytest.param(
452+
"bsr",
453+
marks=pytest.mark.xfail(
454+
isinstance(get_default_mode().linker, NumbaLinker),
455+
reason="Numba does not support bsr",
456+
),
457+
),
458+
),
459+
)
460+
def test_dot_sparse_sparse(self, sparse_format_a, sparse_format_b):
431461
sparse_dtype = "float64"
432462
sp_mat = {
433463
"csc": scipy_sparse.csc_matrix,
434464
"csr": scipy_sparse.csr_matrix,
435465
"bsr": scipy_sparse.csr_matrix,
436466
}
437-
438-
for sparse_format_a in ["csc", "csr", "bsr"]:
439-
for sparse_format_b in ["csc", "csr", "bsr"]:
440-
a = SparseTensorType(sparse_format_a, dtype=sparse_dtype)()
441-
b = SparseTensorType(sparse_format_b, dtype=sparse_dtype)()
442-
d = pt.dot(a, b)
443-
f = pytensor.function([a, b], d)
444-
for M, N, K, nnz in [
445-
(4, 3, 2, 3),
446-
(40, 30, 20, 3),
447-
(40, 30, 20, 30),
448-
(400, 3000, 200, 6000),
449-
]:
450-
a_val = sp_mat[sparse_format_a](
451-
random_lil((M, N), sparse_dtype, nnz)
452-
)
453-
b_val = sp_mat[sparse_format_b](
454-
random_lil((N, K), sparse_dtype, nnz)
455-
)
456-
f(a_val, b_val)
467+
a = SparseTensorType(sparse_format_a, dtype=sparse_dtype)()
468+
b = SparseTensorType(sparse_format_b, dtype=sparse_dtype)()
469+
d = pt.dot(a, b)
470+
f = pytensor.function([a, b], d)
471+
for M, N, K, nnz in [
472+
(4, 3, 2, 3),
473+
(40, 30, 20, 3),
474+
(40, 30, 20, 30),
475+
(400, 3000, 200, 6000),
476+
]:
477+
a_val = sp_mat[sparse_format_a](random_lil((M, N), sparse_dtype, nnz))
478+
b_val = sp_mat[sparse_format_b](random_lil((N, K), sparse_dtype, nnz))
479+
f(a_val, b_val) # TODO: Test something
457480

458481
def test_tensor_dot_types(self):
459482
x = csc_matrix("x")

tests/sparse/test_rewriting.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from pytensor import sparse
88
from pytensor.compile.mode import Mode, get_default_mode
99
from pytensor.configdefaults import config
10+
from pytensor.link.numba import NumbaLinker
1011
from pytensor.sparse.rewriting import SamplingDotCSR, sd_csc
1112
from pytensor.tensor.basic import as_tensor_variable
1213
from pytensor.tensor.math import sum as pt_sum
@@ -68,6 +69,10 @@ def test_local_csm_grad_c():
6869
@pytest.mark.skipif(
6970
not pytensor.config.cxx, reason="G++ not available, so we need to skip this test."
7071
)
72+
@pytest.mark.skipif(
73+
isinstance(get_default_mode().linker, NumbaLinker),
74+
reason="This is a C-specific test",
75+
)
7176
def test_local_mul_s_d():
7277
mode = get_default_mode()
7378
mode = mode.including("specialize", "local_mul_s_d")
@@ -86,6 +91,10 @@ def test_local_mul_s_d():
8691
@pytest.mark.skipif(
8792
not pytensor.config.cxx, reason="G++ not available, so we need to skip this test."
8893
)
94+
@pytest.mark.skipif(
95+
isinstance(get_default_mode().linker, NumbaLinker),
96+
reason="This is a C-specific test",
97+
)
8998
def test_local_mul_s_v():
9099
mode = get_default_mode()
91100
mode = mode.including("specialize", "local_mul_s_v")
@@ -104,6 +113,10 @@ def test_local_mul_s_v():
104113
@pytest.mark.skipif(
105114
not pytensor.config.cxx, reason="G++ not available, so we need to skip this test."
106115
)
116+
@pytest.mark.skipif(
117+
isinstance(get_default_mode().linker, NumbaLinker),
118+
reason="This is a C-specific test",
119+
)
107120
def test_local_structured_add_s_v():
108121
mode = get_default_mode()
109122
mode = mode.including("specialize", "local_structured_add_s_v")
@@ -122,6 +135,10 @@ def test_local_structured_add_s_v():
122135
@pytest.mark.skipif(
123136
not pytensor.config.cxx, reason="G++ not available, so we need to skip this test."
124137
)
138+
@pytest.mark.skipif(
139+
isinstance(get_default_mode().linker, NumbaLinker),
140+
reason="This is a C-specific test",
141+
)
125142
def test_local_sampling_dot_csr():
126143
mode = get_default_mode()
127144
mode = mode.including("specialize", "local_sampling_dot_csr")

tests/test_raise_op.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,9 @@
44

55
import pytensor
66
import pytensor.tensor as pt
7-
from pytensor.compile.mode import OPT_FAST_RUN, Mode, get_default_mode
7+
from pytensor.compile.mode import OPT_FAST_RUN, Mode
88
from pytensor.graph import vectorize_graph
99
from pytensor.graph.basic import Constant, equal_computations
10-
from pytensor.link.numba import NumbaLinker
1110
from pytensor.raise_op import Assert, CheckAndRaise, assert_op
1211
from pytensor.scalar.basic import ScalarType, float64
1312
from pytensor.sparse import as_sparse_variable
@@ -182,10 +181,6 @@ def test_infer_shape_scalar(self):
182181
)
183182

184183

185-
@pytest.mark.xfail(
186-
condition=isinstance(get_default_mode().linker, NumbaLinker),
187-
reason="Numba does not support Sparse Ops yet",
188-
)
189184
def test_CheckAndRaise_sparse_variable():
190185
check_and_raise = CheckAndRaise(ValueError, "sparse_check")
191186

tests/typed_list/test_basic.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
import pytensor
88
import pytensor.typed_list
99
from pytensor import sparse
10-
from pytensor.compile import get_default_mode
11-
from pytensor.link.numba import NumbaLinker
1210
from pytensor.tensor.type import (
1311
TensorType,
1412
integer_dtypes,
@@ -454,10 +452,6 @@ def test_non_tensor_type(self):
454452

455453
assert f([[x, y], [x, y, y]], [x, y]) == 0
456454

457-
@pytest.mark.xfail(
458-
condition=isinstance(get_default_mode().linker, NumbaLinker),
459-
reason="Numba does not support Sparse Ops yet",
460-
)
461455
def test_sparse(self):
462456
mySymbolicSparseList = TypedListType(
463457
sparse.SparseTensorType("csr", pytensor.config.floatX)
@@ -525,10 +519,6 @@ def test_non_tensor_type(self):
525519

526520
assert f([[x, y], [x, y, y]], [x, y]) == 1
527521

528-
@pytest.mark.xfail(
529-
condition=isinstance(get_default_mode().linker, NumbaLinker),
530-
reason="Numba does not support Sparse Ops yet",
531-
)
532522
def test_sparse(self):
533523
mySymbolicSparseList = TypedListType(
534524
sparse.SparseTensorType("csr", pytensor.config.floatX)

0 commit comments

Comments
 (0)