Skip to content

Commit 1de9697

Browse files
committed
Add xfail test for constants
1 parent ffe08e5 commit 1de9697

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

tests/link/numba/sparse/test_basic.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import pytensor.sparse as ps
99
import pytensor.tensor as pt
10+
from pytensor.sparse.variable import SparseConstant
1011

1112

1213
numba = pytest.importorskip("numba")
@@ -161,6 +162,27 @@ def csr_matrix_constructor(data, indices, indptr):
161162
assert (out_pt.indptr == inp.indptr).all()
162163

163164

165+
@pytest.mark.xfail(reason="We cannot lower constant SparseVariables yet")
166+
@pytest.mark.parametrize("cache", [True, False])
167+
@pytest.mark.parametrize("format", ["csr", "csc"])
168+
def test_constant(format, cache):
169+
x = sp.sparse.random(3, 3, density=0.5, format=format, random_state=166)
170+
x = ps.as_sparse(x)
171+
assert isinstance(x, SparseConstant)
172+
assert x.type.format == format
173+
y = pt.vector("y", shape=(3,))
174+
out = x * y
175+
176+
y_test = np.array([np.pi, np.e, np.euler_gamma])
177+
with config.change_flags(numba__cache=cache):
178+
compare_numba_and_py_sparse(
179+
[y],
180+
[out],
181+
[y_test],
182+
eval_obj_mode=False,
183+
)
184+
185+
164186
@pytest.mark.parametrize("format", ["csr", "csc"])
165187
def test_simple_graph(format):
166188
ps_matrix = ps.csr_matrix if format == "csr" else ps.csc_matrix

0 commit comments

Comments
 (0)