Skip to content

Commit 39d1198

Browse files
committed
refactor test_trig_out_type
1 parent 8c94751 commit 39d1198

File tree

1 file changed

+5
-14
lines changed

1 file changed

+5
-14
lines changed

dpctl/tests/elementwise/test_trigonometric.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import dpctl.tensor as dpt
2626
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
2727

28-
from .utils import _all_dtypes, _map_to_device_dtype
28+
from .utils import _all_dtypes
2929

3030
_trig_funcs = [(np.sin, dpt.sin), (np.cos, dpt.cos), (np.tan, dpt.tan)]
3131
_inv_trig_funcs = [
@@ -37,23 +37,14 @@
3737
_dpt_funcs = [t[1] for t in _all_funcs]
3838

3939

40-
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
40+
@pytest.mark.parametrize("func", _dpt_funcs)
4141
@pytest.mark.parametrize("dtype", _all_dtypes)
42-
def test_trig_out_type(np_call, dpt_call, dtype):
42+
def test_trig_out_type(func, dtype):
4343
q = get_queue_or_skip()
4444
skip_if_dtype_not_supported(dtype, q)
4545

46-
X = dpt.asarray(0, dtype=dtype, sycl_queue=q)
47-
expected_dtype = np_call(np.array(0, dtype=dtype)).dtype
48-
expected_dtype = _map_to_device_dtype(expected_dtype, q.sycl_device)
49-
assert dpt_call(X).dtype == expected_dtype
50-
51-
X = dpt.asarray(0, dtype=dtype, sycl_queue=q)
52-
expected_dtype = np_call(np.array(0, dtype=dtype)).dtype
53-
expected_dtype = _map_to_device_dtype(expected_dtype, q.sycl_device)
54-
Y = dpt.empty_like(X, dtype=expected_dtype)
55-
dpt_call(X, out=Y)
56-
assert_allclose(dpt.asnumpy(dpt_call(X)), dpt.asnumpy(Y))
46+
x = dpt.asarray(0, dtype=dtype, sycl_queue=q)
47+
assert func(x).dtype == x.dtype
5748

5849

5950
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)

0 commit comments

Comments
 (0)