Skip to content

Commit 437f046

Browse files
antonwolfyvtavana
andauthored
Check type of input in dpnp.repeat to raise a proper validation exception if any (#1894)
* Check type of input to raise a proper validation exception if any * Update dpnp/dpnp_iface_manipulation.py Co-authored-by: vtavana <[email protected]> --------- Co-authored-by: vtavana <[email protected]>
1 parent 067a784 commit 437f046

File tree

4 files changed

+237
-141
lines changed

4 files changed

+237
-141
lines changed

dpnp/dpnp_iface_manipulation.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1248,12 +1248,16 @@ def repeat(a, repeats, axis=None):
12481248
----------
12491249
x : {dpnp.ndarray, usm_ndarray}
12501250
Input array.
1251-
repeat : int or array of int
1251+
repeats : {int, tuple, list, range, dpnp.ndarray, usm_ndarray}
12521252
The number of repetitions for each element. `repeats` is broadcasted to
12531253
fit the shape of the given axis.
1254-
axis : int, optional
1254+
If `repeats` is an array, it must have an integer data type.
1255+
Otherwise, `repeats` must be a Python integer or sequence of Python
1256+
integers (i.e., a tuple, list, or range).
1257+
axis : {None, int}, optional
12551258
The axis along which to repeat values. By default, use the flattened
12561259
input array, and return a flat output array.
1260+
Default: ``None``.
12571261
12581262
Returns
12591263
-------
@@ -1263,8 +1267,8 @@ def repeat(a, repeats, axis=None):
12631267
12641268
See Also
12651269
--------
1266-
:obj:`dpnp.tile` : Construct an array by repeating A the number of times
1267-
given by reps.
1270+
:obj:`dpnp.tile` : Tile an array.
1271+
:obj:`dpnp.unique` : Find the unique elements of an array.
12681272
12691273
Examples
12701274
--------
@@ -1286,14 +1290,15 @@ def repeat(a, repeats, axis=None):
12861290
12871291
"""
12881292

1289-
rep = repeats
1290-
if isinstance(repeats, dpnp_array):
1291-
rep = dpnp.get_usm_ndarray(repeats)
1293+
dpnp.check_supported_arrays_type(a)
1294+
if not isinstance(repeats, (int, tuple, list, range)):
1295+
repeats = dpnp.get_usm_ndarray(repeats)
1296+
12921297
if axis is None and a.ndim > 1:
1293-
usm_arr = dpnp.get_usm_ndarray(a.flatten())
1294-
else:
1295-
usm_arr = dpnp.get_usm_ndarray(a)
1296-
usm_arr = dpt.repeat(usm_arr, rep, axis=axis)
1298+
a = dpnp.ravel(a)
1299+
1300+
usm_arr = dpnp.get_usm_ndarray(a)
1301+
usm_arr = dpt.repeat(usm_arr, repeats, axis=axis)
12971302
return dpnp_array._create_from_usm_ndarray(usm_arr)
12981303

12991304

tests/test_arraymanipulation.py

Lines changed: 0 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -1016,114 +1016,3 @@ def test_can_cast():
10161016
assert dpnp.can_cast(X, "float32") == numpy.can_cast(X_np, "float32")
10171017
assert dpnp.can_cast(X, dpnp.int32) == numpy.can_cast(X_np, numpy.int32)
10181018
assert dpnp.can_cast(X, dpnp.int64) == numpy.can_cast(X_np, numpy.int64)
1019-
1020-
1021-
def test_repeat_scalar_sequence_agreement():
1022-
x = dpnp.arange(5, dtype="i4")
1023-
expected_res = dpnp.empty(10, dtype="i4")
1024-
expected_res[1::2], expected_res[::2] = x, x
1025-
1026-
# scalar case
1027-
reps = 2
1028-
res = dpnp.repeat(x, reps)
1029-
assert dpnp.all(res == expected_res)
1030-
1031-
# tuple
1032-
reps = (2, 2, 2, 2, 2)
1033-
res = dpnp.repeat(x, reps)
1034-
assert dpnp.all(res == expected_res)
1035-
1036-
1037-
def test_repeat_as_broadcasting():
1038-
reps = 5
1039-
x = dpnp.arange(reps, dtype="i4")
1040-
x1 = x[:, dpnp.newaxis]
1041-
expected_res = dpnp.broadcast_to(x1, (reps, reps))
1042-
1043-
res = dpnp.repeat(x1, reps, axis=1)
1044-
assert dpnp.all(res == expected_res)
1045-
1046-
x2 = x[dpnp.newaxis, :]
1047-
expected_res = dpnp.broadcast_to(x2, (reps, reps))
1048-
1049-
res = dpnp.repeat(x2, reps, axis=0)
1050-
assert dpnp.all(res == expected_res)
1051-
1052-
1053-
def test_repeat_axes():
1054-
reps = 2
1055-
x = dpnp.reshape(dpnp.arange(5 * 10, dtype="i4"), (5, 10))
1056-
expected_res = dpnp.empty((x.shape[0] * 2, x.shape[1]), dtype=x.dtype)
1057-
expected_res[::2, :], expected_res[1::2] = x, x
1058-
res = dpnp.repeat(x, reps, axis=0)
1059-
assert dpnp.all(res == expected_res)
1060-
1061-
expected_res = dpnp.empty((x.shape[0], x.shape[1] * 2), dtype=x.dtype)
1062-
expected_res[:, ::2], expected_res[:, 1::2] = x, x
1063-
res = dpnp.repeat(x, reps, axis=1)
1064-
assert dpnp.all(res == expected_res)
1065-
1066-
1067-
def test_repeat_size_0_outputs():
1068-
x = dpnp.ones((3, 0, 5), dtype="i4")
1069-
reps = 10
1070-
res = dpnp.repeat(x, reps, axis=0)
1071-
assert res.size == 0
1072-
assert res.shape == (30, 0, 5)
1073-
1074-
res = dpnp.repeat(x, reps, axis=1)
1075-
assert res.size == 0
1076-
assert res.shape == (3, 0, 5)
1077-
1078-
res = dpnp.repeat(x, (2, 2, 2), axis=0)
1079-
assert res.size == 0
1080-
assert res.shape == (6, 0, 5)
1081-
1082-
x = dpnp.ones((3, 2, 5))
1083-
res = dpnp.repeat(x, 0, axis=1)
1084-
assert res.size == 0
1085-
assert res.shape == (3, 0, 5)
1086-
1087-
x = dpnp.ones((3, 2, 5))
1088-
res = dpnp.repeat(x, (0, 0), axis=1)
1089-
assert res.size == 0
1090-
assert res.shape == (3, 0, 5)
1091-
1092-
1093-
def test_repeat_strides():
1094-
reps = 2
1095-
x = dpnp.reshape(dpnp.arange(10 * 10, dtype="i4"), (10, 10))
1096-
x1 = x[:, ::-2]
1097-
expected_res = dpnp.empty((10, 10), dtype="i4")
1098-
expected_res[:, ::2], expected_res[:, 1::2] = x1, x1
1099-
res = dpnp.repeat(x1, reps, axis=1)
1100-
assert dpnp.all(res == expected_res)
1101-
res = dpnp.repeat(x1, (reps,) * x1.shape[1], axis=1)
1102-
assert dpnp.all(res == expected_res)
1103-
1104-
x1 = x[::-2, :]
1105-
expected_res = dpnp.empty((10, 10), dtype="i4")
1106-
expected_res[::2, :], expected_res[1::2, :] = x1, x1
1107-
res = dpnp.repeat(x1, reps, axis=0)
1108-
assert dpnp.all(res == expected_res)
1109-
res = dpnp.repeat(x1, (reps,) * x1.shape[0], axis=0)
1110-
assert dpnp.all(res == expected_res)
1111-
1112-
1113-
def test_repeat_casting():
1114-
x = dpnp.arange(5, dtype="i4")
1115-
# i4 is cast to i8
1116-
reps = dpnp.ones(5, dtype="i4")
1117-
res = dpnp.repeat(x, reps)
1118-
assert res.shape == x.shape
1119-
assert dpnp.all(res == x)
1120-
1121-
1122-
def test_repeat_strided_repeats():
1123-
x = dpnp.arange(5, dtype="i4")
1124-
reps = dpnp.ones(10, dtype="i8")
1125-
reps[::2] = 0
1126-
reps = reps[::-2]
1127-
res = dpnp.repeat(x, reps)
1128-
assert res.shape == x.shape
1129-
assert dpnp.all(res == x)

0 commit comments

Comments
 (0)