diff --git a/tests/link/numba/test_elemwise.py b/tests/link/numba/test_elemwise.py index 25efd69a8d..84875dac97 100644 --- a/tests/link/numba/test_elemwise.py +++ b/tests/link/numba/test_elemwise.py @@ -31,55 +31,47 @@ @pytest.mark.parametrize( - "inputs, input_vals, output_fn, exc", + "inputs, input_vals, output_fn", [ ( [pt.vector()], [rng.uniform(size=100).astype(config.floatX)], lambda x: pt.gammaln(x), - None, ), ( [pt.vector()], [rng.standard_normal(100).astype(config.floatX)], lambda x: pt.sigmoid(x), - None, ), ( [pt.vector()], [rng.standard_normal(100).astype(config.floatX)], lambda x: pt.log1mexp(x), - None, ), ( [pt.vector()], [rng.standard_normal(100).astype(config.floatX)], lambda x: pt.erf(x), - None, ), ( [pt.vector()], [rng.standard_normal(100).astype(config.floatX)], lambda x: pt.erfc(x), - None, ), ( [pt.vector()], [rng.standard_normal(100).astype(config.floatX)], lambda x: pt.erfcx(x), - None, ), ( [pt.vector() for i in range(4)], [rng.standard_normal(100).astype(config.floatX) for i in range(4)], lambda x, y, x1, y1: (x + y) * (x1 + y1) * y, - None, ), ( [pt.matrix(), pt.scalar()], [rng.normal(size=(2, 2)).astype(config.floatX), 0.0], lambda a, b: pt.switch(a, b, a), - None, ), ( [pt.scalar(), pt.scalar()], @@ -88,7 +80,6 @@ np.array(1.0, dtype=config.floatX), ], lambda x, y: pti.add_inplace(deep_copy_op(x), deep_copy_op(y)), - None, ), ( [pt.vector(), pt.vector()], @@ -97,7 +88,6 @@ rng.standard_normal(100).astype(config.floatX), ], lambda x, y: pti.add_inplace(deep_copy_op(x), deep_copy_op(y)), - None, ), ( [pt.vector(), pt.vector()], @@ -106,20 +96,30 @@ rng.standard_normal(100).astype(config.floatX), ], lambda x, y: scalar_my_multi_out(x, y), - None, ), ], + ids=[ + "gammaln", + "sigmoid", + "log1mexp", + "erf", + "erfc", + "erfcx", + "complex_arithmetic", + "switch", + "add_inplace_scalar", + "add_inplace_vector", + "scalar_multi_out", + ], ) -def test_Elemwise(inputs, input_vals, output_fn, exc): +def test_Elemwise(inputs, input_vals, output_fn): outputs = output_fn(*inputs) - cm = contextlib.suppress() if exc is None else pytest.raises(exc) - with cm: - compare_numba_and_py( - inputs, - outputs, - input_vals, - ) + compare_numba_and_py( + inputs, + outputs, + input_vals, + ) @pytest.mark.xfail(reason="Logic had to be reversed due to surprising segfaults")