|
7 | 7 |
|
8 | 8 | import pytensor.sparse as ps |
9 | 9 | import pytensor.tensor as pt |
| 10 | +from pytensor.sparse.variable import SparseConstant |
10 | 11 |
|
11 | 12 |
|
12 | 13 | numba = pytest.importorskip("numba") |
@@ -161,6 +162,27 @@ def csr_matrix_constructor(data, indices, indptr): |
161 | 162 | assert (out_pt.indptr == inp.indptr).all() |
162 | 163 |
|
163 | 164 |
|
| 165 | +@pytest.mark.xfail(reason="We cannot lower constant SparseVariables yet") |
| 166 | +@pytest.mark.parametrize("cache", [True, False]) |
| 167 | +@pytest.mark.parametrize("format", ["csr", "csc"]) |
| 168 | +def test_constant(format, cache): |
| 169 | + x = sp.sparse.random(3, 3, density=0.5, format=format, random_state=166) |
| 170 | + x = ps.as_sparse(x) |
| 171 | + assert isinstance(x, SparseConstant) |
| 172 | + assert x.type.format == format |
| 173 | + y = pt.vector("y", shape=(3,)) |
| 174 | + out = x * y |
| 175 | + |
| 176 | + y_test = np.array([np.pi, np.e, np.euler_gamma]) |
| 177 | + with config.change_flags(numba__cache=cache): |
| 178 | + compare_numba_and_py_sparse( |
| 179 | + [y], |
| 180 | + [out], |
| 181 | + [y_test], |
| 182 | + eval_obj_mode=False, |
| 183 | + ) |
| 184 | + |
| 185 | + |
164 | 186 | @pytest.mark.parametrize("format", ["csr", "csc"]) |
165 | 187 | def test_simple_graph(format): |
166 | 188 | ps_matrix = ps.csr_matrix if format == "csr" else ps.csc_matrix |
|
0 commit comments