|
25 | 25 | import dpctl.tensor as dpt |
26 | 26 | from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported |
27 | 27 |
|
28 | | -from .utils import _all_dtypes, _map_to_device_dtype |
| 28 | +from .utils import _all_dtypes |
29 | 29 |
|
30 | 30 | _hyper_funcs = [(np.sinh, dpt.sinh), (np.cosh, dpt.cosh), (np.tanh, dpt.tanh)] |
31 | 31 | _inv_hyper_funcs = [ |
|
37 | 37 | _dpt_funcs = [t[1] for t in _all_funcs] |
38 | 38 |
|
39 | 39 |
|
40 | | -@pytest.mark.parametrize("np_call, dpt_call", _all_funcs) |
| 40 | +@pytest.mark.parametrize("func", _dpt_funcs) |
41 | 41 | @pytest.mark.parametrize("dtype", _all_dtypes) |
42 | | -def test_hyper_out_type(np_call, dpt_call, dtype): |
| 42 | +def test_hyper_out_type(func, dtype): |
43 | 43 | q = get_queue_or_skip() |
44 | 44 | skip_if_dtype_not_supported(dtype, q) |
45 | 45 |
|
46 | | - a = 1 if np_call == np.arccosh else 0 |
47 | | - |
48 | | - X = dpt.asarray(a, dtype=dtype, sycl_queue=q) |
49 | | - expected_dtype = np_call(np.array(a, dtype=dtype)).dtype |
50 | | - expected_dtype = _map_to_device_dtype(expected_dtype, q.sycl_device) |
51 | | - assert dpt_call(X).dtype == expected_dtype |
52 | | - |
53 | | - X = dpt.asarray(a, dtype=dtype, sycl_queue=q) |
54 | | - expected_dtype = np_call(np.array(a, dtype=dtype)).dtype |
55 | | - expected_dtype = _map_to_device_dtype(expected_dtype, q.sycl_device) |
56 | | - Y = dpt.empty_like(X, dtype=expected_dtype) |
57 | | - dpt_call(X, out=Y) |
58 | | - assert_allclose(dpt.asnumpy(dpt_call(X)), dpt.asnumpy(Y)) |
| 46 | + x = dpt.asarray(0, dtype=dtype, sycl_queue=q) |
| 47 | + assert func(x).dtype == x.dtype |
59 | 48 |
|
60 | 49 |
|
61 | 50 | @pytest.mark.parametrize("np_call, dpt_call", _all_funcs) |
|
0 commit comments