|
25 | 25 | import dpctl.tensor as dpt
|
26 | 26 | from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
|
27 | 27 |
|
28 |
| -from .utils import _all_dtypes, _map_to_device_dtype |
| 28 | +from .utils import _all_dtypes |
29 | 29 |
|
30 | 30 | _trig_funcs = [(np.sin, dpt.sin), (np.cos, dpt.cos), (np.tan, dpt.tan)]
|
31 | 31 | _inv_trig_funcs = [
|
|
37 | 37 | _dpt_funcs = [t[1] for t in _all_funcs]
|
38 | 38 |
|
39 | 39 |
|
40 |
| -@pytest.mark.parametrize("np_call, dpt_call", _all_funcs) |
| 40 | +@pytest.mark.parametrize("func", _dpt_funcs) |
41 | 41 | @pytest.mark.parametrize("dtype", _all_dtypes)
|
42 |
| -def test_trig_out_type(np_call, dpt_call, dtype): |
| 42 | +def test_trig_out_type(func, dtype): |
43 | 43 | q = get_queue_or_skip()
|
44 | 44 | skip_if_dtype_not_supported(dtype, q)
|
45 | 45 |
|
46 |
| - X = dpt.asarray(0, dtype=dtype, sycl_queue=q) |
47 |
| - expected_dtype = np_call(np.array(0, dtype=dtype)).dtype |
48 |
| - expected_dtype = _map_to_device_dtype(expected_dtype, q.sycl_device) |
49 |
| - assert dpt_call(X).dtype == expected_dtype |
50 |
| - |
51 |
| - X = dpt.asarray(0, dtype=dtype, sycl_queue=q) |
52 |
| - expected_dtype = np_call(np.array(0, dtype=dtype)).dtype |
53 |
| - expected_dtype = _map_to_device_dtype(expected_dtype, q.sycl_device) |
54 |
| - Y = dpt.empty_like(X, dtype=expected_dtype) |
55 |
| - dpt_call(X, out=Y) |
56 |
| - assert_allclose(dpt.asnumpy(dpt_call(X)), dpt.asnumpy(Y)) |
| 46 | + x = dpt.asarray(0, dtype=dtype, sycl_queue=q) |
| 47 | + assert func(x).dtype == x.dtype |
57 | 48 |
|
58 | 49 |
|
59 | 50 | @pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
|
|
0 commit comments