Skip to content

Commit fc52460

Browse files
committed
Add xfail test with psi inner function
1 parent b6fbacd commit fc52460

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

tests/link/numba/test_elemwise.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,9 @@ def test_scalar_loop():
606606
)
607607

608608

609+
@pytest.mark.xfail(
610+
reason="Numba fails due to https://github.com/numba/numba/issues/10098"
611+
)
609612
def test_gammainc_wrt_k_grad():
610613
x = pt.vector("x", dtype="float64")
611614
k = pt.vector("k", dtype="float64")

tests/link/numba/test_scalar.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,3 +245,15 @@ def test_scalar_while_loop(self):
245245
np.testing.assert_allclose(fn(n_steps=20, x0=1), [10, True])
246246
np.testing.assert_allclose(fn(n_steps=5, x0=1), [6, False])
247247
np.testing.assert_allclose(fn(n_steps=0, x0=1), [1, False])
248+
249+
@pytest.mark.xfail("Fails due to https://github.com/numba/numba/issues/10098")
250+
def test_loop_with_cython_wrapped_op(self):
251+
x = ps.float64("x")
252+
op = ScalarLoop(init=[x], update=[ps.psi(x)])
253+
out = op(1, x)
254+
255+
fn = function([x], out, mode=numba_mode)
256+
x_test = np.float64(0.5)
257+
res = fn(x_test)
258+
expected_res = ps.psi(x).eval({x: x_test})
259+
np.testing.assert_allclose(res, expected_res)

0 commit comments

Comments
 (0)