Skip to content

Commit 383e162

Browse files
committed
More comprehensive sparse object mode test
1 parent bbf7b8e commit 383e162

File tree

1 file changed

+27
-9
lines changed

1 file changed

+27
-9
lines changed

tests/link/numba/sparse/test_basic.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
import pytensor.sparse as ps
99
import pytensor.tensor as pt
10+
from pytensor.graph import Apply, Op
11+
from pytensor.tensor.type import DenseTensorType
1012

1113

1214
numba = pytest.importorskip("numba")
@@ -15,7 +17,7 @@
1517
# Make sure the Numba customizations are loaded
1618
import pytensor.link.numba.dispatch.sparse # noqa: F401
1719
from pytensor import config
18-
from pytensor.sparse import Dot, SparseTensorType
20+
from pytensor.sparse import SparseTensorType
1921
from tests.link.numba.test_basic import compare_numba_and_py
2022

2123

@@ -108,20 +110,36 @@ def test_fn(x):
108110
assert y is not x and np.all(x.data == y.data) and np.all(x.indices == y.indices)
109111

110112

111-
def test_sparse_objmode():
112-
x = SparseTensorType("csc", dtype=config.floatX)()
113-
y = SparseTensorType("csc", dtype=config.floatX)()
113+
@pytest.mark.parametrize("format", ["csc", "csr"])
114+
@pytest.mark.parametrize("dense_out", [True, False])
115+
def test_sparse_objmode(format, dense_out):
116+
class SparseTestOp(Op):
117+
def make_node(self, x):
118+
out = x.type.clone(shape=(1, x.type.shape[-1]))()
119+
if dense_out:
120+
out = out.todense().type()
121+
return Apply(self, [x], [out])
114122

115-
out = Dot()(x, y)
123+
def perform(self, node, inputs, output_storage):
124+
[x] = inputs
125+
[out] = output_storage
126+
out[0] = x[0]
127+
if dense_out:
128+
out[0] = out[0].todense()
116129

117-
x_val = sp.sparse.random(2, 2, density=0.25, dtype=config.floatX)
118-
y_val = sp.sparse.random(2, 2, density=0.25, dtype=config.floatX)
130+
x = SparseTensorType(format, dtype=config.floatX, shape=(5, 5))()
131+
132+
out = SparseTestOp()(x)
133+
assert out.type.shape == (1, 5)
134+
assert isinstance(out.type, DenseTensorType if dense_out else SparseTensorType)
135+
136+
x_val = sp.sparse.random(5, 5, density=0.25, dtype=config.floatX, format=format)
119137

120138
with pytest.warns(
121139
UserWarning,
122-
match="Numba will use object mode to run SparseDot's perform method",
140+
match="Numba will use object mode to run SparseTestOp's perform method",
123141
):
124-
compare_numba_and_py([x, y], out, [x_val, y_val])
142+
compare_numba_and_py_sparse([x], out, [x_val])
125143

126144

127145
def test_overload_csr_matrix_constructor():

0 commit comments

Comments
 (0)