Skip to content

Commit 8bd0e7f

Browse files
committed
refactor test_hyper_out_type
1 parent 39d1198 commit 8bd0e7f

File tree

1 file changed

+5
-16
lines changed

1 file changed

+5
-16
lines changed

dpctl/tests/elementwise/test_hyperbolic.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import dpctl.tensor as dpt
2626
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
2727

28-
from .utils import _all_dtypes, _map_to_device_dtype
28+
from .utils import _all_dtypes
2929

3030
_hyper_funcs = [(np.sinh, dpt.sinh), (np.cosh, dpt.cosh), (np.tanh, dpt.tanh)]
3131
_inv_hyper_funcs = [
@@ -37,25 +37,14 @@
3737
_dpt_funcs = [t[1] for t in _all_funcs]
3838

3939

40-
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
40+
@pytest.mark.parametrize("func", _dpt_funcs)
4141
@pytest.mark.parametrize("dtype", _all_dtypes)
42-
def test_hyper_out_type(np_call, dpt_call, dtype):
42+
def test_hyper_out_type(func, dtype):
4343
q = get_queue_or_skip()
4444
skip_if_dtype_not_supported(dtype, q)
4545

46-
a = 1 if np_call == np.arccosh else 0
47-
48-
X = dpt.asarray(a, dtype=dtype, sycl_queue=q)
49-
expected_dtype = np_call(np.array(a, dtype=dtype)).dtype
50-
expected_dtype = _map_to_device_dtype(expected_dtype, q.sycl_device)
51-
assert dpt_call(X).dtype == expected_dtype
52-
53-
X = dpt.asarray(a, dtype=dtype, sycl_queue=q)
54-
expected_dtype = np_call(np.array(a, dtype=dtype)).dtype
55-
expected_dtype = _map_to_device_dtype(expected_dtype, q.sycl_device)
56-
Y = dpt.empty_like(X, dtype=expected_dtype)
57-
dpt_call(X, out=Y)
58-
assert_allclose(dpt.asnumpy(dpt_call(X)), dpt.asnumpy(Y))
46+
x = dpt.asarray(0, dtype=dtype, sycl_queue=q)
47+
assert func(x).dtype == x.dtype
5948

6049

6150
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)

0 commit comments

Comments
 (0)