Skip to content

Commit 14e1044

Browse files
author
Vahid Tavanashad
committed
address comments
1 parent a65bb9a commit 14e1044

File tree

2 files changed

+32
-30
lines changed

2 files changed

+32
-30
lines changed

dpnp/tests/test_binary_ufuncs.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,7 @@ def test_invalid_out(self, xp, out):
141141

142142
@pytest.mark.parametrize("func", ["fmax", "fmin", "maximum", "minimum"])
143143
class TestBoundFuncs:
144-
@pytest.mark.parametrize(
145-
"dtype", get_all_dtypes(no_bool=True, no_complex=True)
146-
)
144+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
147145
def test_out(self, func, dtype):
148146
a = generate_random_numpy_array(10, dtype)
149147
b = generate_random_numpy_array(10, dtype)
@@ -157,7 +155,7 @@ def test_out(self, func, dtype):
157155
assert_dtype_allclose(result, expected)
158156

159157
@pytest.mark.parametrize(
160-
"dtype", get_all_dtypes(no_bool=True, no_complex=True)
158+
"dtype", get_all_dtypes(no_none=True, no_bool=True)
161159
)
162160
def test_out_overlap(self, func, dtype):
163161
size = 15
@@ -190,17 +188,15 @@ def test_invalid_out(self, func, xp, out):
190188

191189
class TestDivide:
192190
@pytest.mark.usefixtures("suppress_divide_invalid_numpy_warnings")
193-
@pytest.mark.parametrize(
194-
"dtype", get_all_dtypes(no_none=True, no_bool=True)
195-
)
191+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
196192
def test_divide(self, dtype):
197193
a = generate_random_numpy_array(10, dtype)
198194
b = generate_random_numpy_array(10, dtype)
199195
expected = numpy.divide(a, b)
200196

201197
ia, ib = dpnp.array(a), dpnp.array(b)
202-
if numpy.issubdtype(dtype, numpy.integer):
203-
out_dtype = map_dtype_to_device(dpnp.float64, ia.sycl_device)
198+
if numpy.issubdtype(dtype, numpy.bool):
199+
out_dtype = dpnp.float64
204200
else:
205201
out_dtype = _get_output_data_type(dtype)
206202
iout = dpnp.empty(expected.shape, dtype=out_dtype)
@@ -379,9 +375,9 @@ def test_invalid_out(self, func, xp, out):
379375
assert_raises(TypeError, getattr(xp, func), a, 2, out)
380376

381377

378+
@pytest.mark.parametrize("func", ["fmax", "fmin"])
382379
class TestFmaxFmin:
383380
@pytest.mark.skipif(not has_support_aspect16(), reason="no fp16 support")
384-
@pytest.mark.parametrize("func", ["fmax", "fmin"])
385381
def test_half(self, func):
386382
a = numpy.array([0, 1, 2, 4, 2], dtype=numpy.float16)
387383
b = numpy.array([-2, 5, 1, 4, 3], dtype=numpy.float16)
@@ -396,7 +392,6 @@ def test_half(self, func):
396392
expected = getattr(numpy, func)(b, c)
397393
assert_equal(result, expected)
398394

399-
@pytest.mark.parametrize("func", ["fmax", "fmin"])
400395
@pytest.mark.parametrize("dtype", get_float_dtypes())
401396
def test_float_nans(self, func, dtype):
402397
a = numpy.array([0, numpy.nan, numpy.nan], dtype=dtype)
@@ -407,7 +402,6 @@ def test_float_nans(self, func, dtype):
407402
expected = getattr(numpy, func)(a, b)
408403
assert_equal(result, expected)
409404

410-
@pytest.mark.parametrize("func", ["fmax", "fmin"])
411405
@pytest.mark.parametrize("dtype", get_complex_dtypes())
412406
@pytest.mark.parametrize(
413407
"nan_val",
@@ -427,7 +421,6 @@ def test_complex_nans(self, func, dtype, nan_val):
427421
expected = getattr(numpy, func)(a, b)
428422
assert_equal(result, expected)
429423

430-
@pytest.mark.parametrize("func", ["fmax", "fmin"])
431424
@pytest.mark.parametrize("dtype", get_float_dtypes(no_float16=False))
432425
def test_precision(self, func, dtype):
433426
dtmin = numpy.finfo(dtype).min

dpnp/tests/test_umath.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,6 @@ def get_id(val):
7474
return val.__str__()
7575

7676

77-
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
7877
@pytest.mark.usefixtures("suppress_divide_invalid_numpy_warnings")
7978
@pytest.mark.parametrize("test_cases", test_cases, ids=get_id)
8079
def test_umaths(test_cases):
@@ -134,7 +133,9 @@ def _get_output_data_type(dtype):
134133

135134

136135
class TestArctan2:
137-
@pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True))
136+
@pytest.mark.parametrize(
137+
"dtype", get_all_dtypes(no_none=True, no_complex=True)
138+
)
138139
def test_arctan2(self, dtype):
139140
a = generate_random_numpy_array(10, dtype, low=0)
140141
b = generate_random_numpy_array(10, dtype, low=0)
@@ -149,10 +150,10 @@ def test_arctan2(self, dtype):
149150
assert_dtype_allclose(result, expected)
150151

151152
@pytest.mark.parametrize(
152-
"dtype", get_all_dtypes(no_complex=True, no_none=True)[:-1]
153+
"dtype", get_all_dtypes(no_none=True, no_complex=True)[:-1]
153154
)
154155
def test_invalid_dtype(self, dtype):
155-
dpnp_dtype = get_all_dtypes(no_complex=True, no_none=True)[-1]
156+
dpnp_dtype = get_all_dtypes(no_none=True, no_complex=True)[-1]
156157
a = dpnp.arange(10, dtype=dpnp_dtype)
157158
iout = dpnp.empty(10, dtype=dtype)
158159

@@ -178,7 +179,9 @@ def test_alias(self):
178179

179180

180181
class TestCbrt:
181-
@pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True))
182+
@pytest.mark.parametrize(
183+
"dtype", get_all_dtypes(no_none=True, no_complex=True)
184+
)
182185
def test_cbrt(self, dtype):
183186
a = generate_random_numpy_array(10, dtype)
184187
expected = numpy.cbrt(a)
@@ -192,10 +195,10 @@ def test_cbrt(self, dtype):
192195
assert_dtype_allclose(result, expected)
193196

194197
@pytest.mark.parametrize(
195-
"dtype", get_all_dtypes(no_complex=True, no_none=True)[:-1]
198+
"dtype", get_all_dtypes(no_none=True, no_complex=True)[:-1]
196199
)
197200
def test_invalid_dtype(self, dtype):
198-
dpnp_dtype = get_all_dtypes(no_complex=True, no_none=True)[-1]
201+
dpnp_dtype = get_all_dtypes(no_none=True, no_complex=True)[-1]
199202
a = dpnp.arange(10, dtype=dpnp_dtype)
200203
iout = dpnp.empty(10, dtype=dtype)
201204

@@ -214,7 +217,9 @@ def test_invalid_shape(self, shape):
214217

215218

216219
class TestCopySign:
217-
@pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True))
220+
@pytest.mark.parametrize(
221+
"dtype", get_all_dtypes(no_none=True, no_complex=True)
222+
)
218223
def test_copysign(self, dtype):
219224
a = generate_random_numpy_array(10, dtype, low=0)
220225
b = generate_random_numpy_array(10, dtype, low=0)
@@ -229,10 +234,10 @@ def test_copysign(self, dtype):
229234
assert_dtype_allclose(result, expected)
230235

231236
@pytest.mark.parametrize(
232-
"dtype", get_all_dtypes(no_complex=True, no_none=True)[:-1]
237+
"dtype", get_all_dtypes(no_none=True, no_complex=True)[:-1]
233238
)
234239
def test_invalid_dtype(self, dtype):
235-
dpnp_dtype = get_all_dtypes(no_complex=True, no_none=True)[-1]
240+
dpnp_dtype = get_all_dtypes(no_none=True, no_complex=True)[-1]
236241
a = dpnp.arange(10, dtype=dpnp_dtype)
237242
iout = dpnp.empty(10, dtype=dtype)
238243
with pytest.raises(ValueError):
@@ -331,7 +336,9 @@ def test_nan_infs_base(self, exp_val, dtype):
331336

332337

333338
class TestLogAddExp:
334-
@pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True))
339+
@pytest.mark.parametrize(
340+
"dtype", get_all_dtypes(no_none=True, no_complex=True)
341+
)
335342
def test_logaddexp(self, dtype):
336343
a = generate_random_numpy_array(10, dtype, low=0)
337344
b = generate_random_numpy_array(10, dtype, low=0)
@@ -346,10 +353,10 @@ def test_logaddexp(self, dtype):
346353
assert_dtype_allclose(result, expected)
347354

348355
@pytest.mark.parametrize(
349-
"dtype", get_all_dtypes(no_complex=True, no_none=True)[:-1]
356+
"dtype", get_all_dtypes(no_none=True, no_complex=True)[:-1]
350357
)
351358
def test_invalid_dtype(self, dtype):
352-
dpnp_dtype = get_all_dtypes(no_complex=True, no_none=True)[-1]
359+
dpnp_dtype = get_all_dtypes(no_none=True, no_complex=True)[-1]
353360
a = dpnp.arange(10, dtype=dpnp_dtype)
354361
iout = dpnp.empty(10, dtype=dtype)
355362
with pytest.raises(ValueError):
@@ -510,7 +517,9 @@ def test_invalid_shape(self, shape):
510517

511518
class TestRsqrt:
512519
@pytest.mark.usefixtures("suppress_divide_numpy_warnings")
513-
@pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True))
520+
@pytest.mark.parametrize(
521+
"dtype", get_all_dtypes(no_none=True, no_complex=True)
522+
)
514523
def test_rsqrt(self, dtype):
515524
a = generate_random_numpy_array(10, dtype, low=0)
516525
expected = numpy.reciprocal(numpy.sqrt(a))
@@ -524,10 +533,10 @@ def test_rsqrt(self, dtype):
524533
assert_dtype_allclose(result, expected)
525534

526535
@pytest.mark.parametrize(
527-
"dtype", get_all_dtypes(no_complex=True, no_none=True)[:-1]
536+
"dtype", get_all_dtypes(no_none=True, no_complex=True)[:-1]
528537
)
529538
def test_invalid_dtype(self, dtype):
530-
dpnp_dtype = get_all_dtypes(no_complex=True, no_none=True)[-1]
539+
dpnp_dtype = get_all_dtypes(no_none=True, no_complex=True)[-1]
531540
a = dpnp.arange(10, dtype=dpnp_dtype)
532541
iout = dpnp.empty(10, dtype=dtype)
533542

@@ -554,7 +563,7 @@ def test_invalid_out(self, out):
554563

555564

556565
class TestSquare:
557-
@pytest.mark.parametrize("dtype", get_all_dtypes())
566+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
558567
def test_square(self, dtype):
559568
a = generate_random_numpy_array(10, dtype)
560569
expected = numpy.square(a)

0 commit comments

Comments
 (0)