Skip to content

Commit a107236

Browse files
committed
Add more tests
1 parent 65a75fb commit a107236

File tree

2 files changed

+72
-4
lines changed

2 files changed

+72
-4
lines changed

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -633,8 +633,10 @@ def resolve_weak_types_2nd_arg_int(o1_dtype, o2_dtype, sycl_dev):
633633
o1_kind_num = dtu._strong_dtype_num_kind(o1_dtype)
634634
o2_kind_num = dtu._weak_type_num_kind(o2_dtype)
635635
if o2_kind_num < o1_kind_num:
636-
if isinstance(o2_dtype, (dtu.WeakBooleanType, dtu.WeakIntegralType)):
637-
print()
638-
print(o1_dtype, dpt.dtype(dti.default_device_int_type(sycl_dev)))
639-
return o1_dtype, dpt.dtype(dti.default_device_int_type(sycl_dev))
636+
if isinstance(
637+
o2_dtype, (dtu.WeakBooleanType, dtu.WeakIntegralType)
638+
):
639+
return o1_dtype, dpt.dtype(
640+
dti.default_device_int_type(sycl_dev)
641+
)
640642
return dtu._resolve_weak_types(o1_dtype, o2_dtype, sycl_dev)

tests/test_mathematical.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1289,6 +1289,72 @@ def test_op_multiple_dtypes(dtype1, func, dtype2, data):
12891289
assert_allclose(result, expected)
12901290

12911291

1292+
class TestLdexp:
1293+
@pytest.mark.parametrize("mant_dt", get_float_dtypes())
1294+
@pytest.mark.parametrize("exp_dt", get_integer_dtypes())
1295+
def test_basic(self, mant_dt, exp_dt):
1296+
mant = numpy.array(2.0, dtype=mant_dt)
1297+
exp = numpy.array(3, dtype=exp_dt)
1298+
imant, iexp = dpnp.array(mant), dpnp.array(exp)
1299+
1300+
result = dpnp.ldexp(imant, iexp)
1301+
expected = numpy.ldexp(mant, exp)
1302+
assert_almost_equal(result, expected)
1303+
1304+
def test_float_scalar(self):
1305+
a = numpy.array(3)
1306+
ia = dpnp.array(a)
1307+
1308+
result = dpnp.ldexp(2.0, ia)
1309+
expected = numpy.ldexp(2.0, a)
1310+
assert_almost_equal(result, expected)
1311+
1312+
@pytest.mark.parametrize("max_min", ["max", "min"])
1313+
def test_overflow(self, max_min):
1314+
exp_val = getattr(numpy.iinfo(numpy.dtype("l")), max_min)
1315+
1316+
result = dpnp.ldexp(dpnp.array(2.0), exp_val)
1317+
with numpy.errstate(over="ignore"):
1318+
expected = numpy.ldexp(numpy.array(2.0), exp_val)
1319+
assert_equal(result, expected)
1320+
1321+
@pytest.mark.parametrize("val", [numpy.nan, numpy.inf, -numpy.inf])
1322+
def test_nan_int_mant(self, val):
1323+
mant = numpy.array(val)
1324+
imant = dpnp.array(mant)
1325+
1326+
result = dpnp.ldexp(imant, 5)
1327+
expected = numpy.ldexp(mant, 5)
1328+
assert_equal(result, expected)
1329+
1330+
def test_zero_exp(self):
1331+
exp = numpy.array(0)
1332+
iexp = dpnp.array(exp)
1333+
1334+
result = dpnp.ldexp(-2.5, iexp)
1335+
expected = numpy.ldexp(-2.5, exp)
1336+
assert_equal(result, expected)
1337+
1338+
@pytest.mark.parametrize("stride", [-4, -2, -1, 1, 2, 4])
1339+
@pytest.mark.parametrize("dt", get_float_dtypes())
1340+
def test_strides(self, stride, dt):
1341+
mant = numpy.array(
1342+
[0.125, 0.25, 0.5, 1.0, 1.0, 2.0, 4.0, 8.0], dtype=dt
1343+
)
1344+
exp = numpy.array([3, 2, 1, 0, 0, -1, -2, -3], dtype="i")
1345+
out = numpy.zeros(8, dtype=dt)
1346+
imant, iexp, iout = dpnp.array(mant), dpnp.array(exp), dpnp.array(out)
1347+
1348+
result = dpnp.ldexp(imant[::stride], iexp[::stride], out=iout[::stride])
1349+
expected = numpy.ldexp(mant[::stride], exp[::stride], out=out[::stride])
1350+
assert_equal(result, expected)
1351+
1352+
@pytest.mark.parametrize("xp", [dpnp, numpy])
1353+
def test_uint64_exp(self, xp):
1354+
x = xp.array(4, dtype=numpy.uint64)
1355+
assert_raises((ValueError, TypeError), xp.ldexp, 7.3, x)
1356+
1357+
12921358
@pytest.mark.parametrize(
12931359
"rhs", [[[1, 2, 3], [4, 5, 6]], [2.0, 1.5, 1.0], 3, 0.3]
12941360
)

0 commit comments

Comments
 (0)