|
7 | 7 |
|
8 | 8 | import pytensor.sparse as ps |
9 | 9 | import pytensor.tensor as pt |
| 10 | +from pytensor.graph import Apply, Op |
| 11 | +from pytensor.tensor.type import DenseTensorType |
10 | 12 |
|
11 | 13 |
|
12 | 14 | numba = pytest.importorskip("numba") |
|
15 | 17 | # Make sure the Numba customizations are loaded |
16 | 18 | import pytensor.link.numba.dispatch.sparse # noqa: F401 |
17 | 19 | from pytensor import config |
18 | | -from pytensor.sparse import Dot, SparseTensorType |
| 20 | +from pytensor.sparse import SparseTensorType |
19 | 21 | from tests.link.numba.test_basic import compare_numba_and_py |
20 | 22 |
|
21 | 23 |
|
@@ -108,20 +110,36 @@ def test_fn(x): |
108 | 110 | assert y is not x and np.all(x.data == y.data) and np.all(x.indices == y.indices) |
109 | 111 |
|
110 | 112 |
|
111 | | -def test_sparse_objmode(): |
112 | | - x = SparseTensorType("csc", dtype=config.floatX)() |
113 | | - y = SparseTensorType("csc", dtype=config.floatX)() |
| 113 | +@pytest.mark.parametrize("format", ["csc", "csr"]) |
| 114 | +@pytest.mark.parametrize("dense_out", [True, False]) |
| 115 | +def test_sparse_objmode(format, dense_out): |
| 116 | + class SparseTestOp(Op): |
| 117 | + def make_node(self, x): |
| 118 | + out = x.type.clone(shape=(1, x.type.shape[-1]))() |
| 119 | + if dense_out: |
| 120 | + out = out.todense().type() |
| 121 | + return Apply(self, [x], [out]) |
114 | 122 |
|
115 | | - out = Dot()(x, y) |
| 123 | + def perform(self, node, inputs, output_storage): |
| 124 | + [x] = inputs |
| 125 | + [out] = output_storage |
| 126 | + out[0] = x[0] |
| 127 | + if dense_out: |
| 128 | + out[0] = out[0].todense() |
116 | 129 |
|
117 | | - x_val = sp.sparse.random(2, 2, density=0.25, dtype=config.floatX) |
118 | | - y_val = sp.sparse.random(2, 2, density=0.25, dtype=config.floatX) |
| 130 | + x = SparseTensorType(format, dtype=config.floatX, shape=(5, 5))() |
| 131 | + |
| 132 | + out = SparseTestOp()(x) |
| 133 | + assert out.type.shape == (1, 5) |
| 134 | + assert isinstance(out.type, DenseTensorType if dense_out else SparseTensorType) |
| 135 | + |
| 136 | + x_val = sp.sparse.random(5, 5, density=0.25, dtype=config.floatX, format=format) |
119 | 137 |
|
120 | 138 | with pytest.warns( |
121 | 139 | UserWarning, |
122 | | - match="Numba will use object mode to run SparseDot's perform method", |
| 140 | + match="Numba will use object mode to run SparseTestOp's perform method", |
123 | 141 | ): |
124 | | - compare_numba_and_py([x, y], out, [x_val, y_val]) |
| 142 | + compare_numba_and_py_sparse([x], out, [x_val]) |
125 | 143 |
|
126 | 144 |
|
127 | 145 | def test_overload_csr_matrix_constructor(): |
|
0 commit comments