Skip to content

Commit 9032ab8

Browse files
Heeded pylint warnings on _reshape.py
1 parent 533a625 commit 9032ab8

File tree

1 file changed

+17
-18
lines changed

1 file changed

+17
-18
lines changed

dpctl/tensor/_reshape.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
from dpctl.tensor._copy_utils import _copy_from_usm_ndarray_to_usm_ndarray
2222
from dpctl.tensor._tensor_impl import _copy_usm_ndarray_for_reshape
2323

24+
__doc__ = "Implementation module for :func:`dpctl.tensor.reshape`."
25+
2426

2527
def _make_unit_indexes(shape):
2628
"""
@@ -67,10 +69,8 @@ def reshaped_strides(old_sh, old_sts, new_sh, order="C"):
6769
]
6870
]
6971
valid = all(
70-
[
71-
check_st == old_st or old_dim == 1
72-
for check_st, old_st, old_dim in zip(check_sts, old_sts, old_sh)
73-
]
72+
check_st == old_st or old_dim == 1
73+
for check_st, old_st, old_dim in zip(check_sts, old_sts, old_sh)
7474
)
7575
return new_sts if valid else None
7676

@@ -82,7 +82,7 @@ def reshape(X, newshape, order="C", copy=None):
8282
Reshapes given usm_ndarray into new shape. Returns a view, if possible,
8383
a copy otherwise. Memory layout of the copy is controlled by order keyword.
8484
"""
85-
if type(X) is not dpt.usm_ndarray:
85+
if not isinstance(X, dpt.usm_ndarray):
8686
raise TypeError
8787
if not isinstance(newshape, (list, tuple)):
8888
newshape = (newshape,)
@@ -99,10 +99,10 @@ def reshape(X, newshape, order="C", copy=None):
9999
)
100100
newshape = [operator.index(d) for d in newshape]
101101
negative_ones_count = 0
102-
for i in range(len(newshape)):
103-
if newshape[i] == -1:
102+
for nshi in newshape:
103+
if nshi == -1:
104104
negative_ones_count = negative_ones_count + 1
105-
if (newshape[i] < -1) or negative_ones_count > 1:
105+
if (nshi < -1) or negative_ones_count > 1:
106106
raise ValueError(
107107
"Target shape should have at most 1 negative "
108108
"value which can only be -1"
@@ -111,7 +111,7 @@ def reshape(X, newshape, order="C", copy=None):
111111
v = X.size // (-np.prod(newshape))
112112
newshape = [v if d == -1 else d for d in newshape]
113113
if X.size != np.prod(newshape):
114-
raise ValueError("Can not reshape into {}".format(newshape))
114+
raise ValueError(f"Can not reshape into {newshape}")
115115
if X.size:
116116
newsts = reshaped_strides(X.shape, X.strides, newshape, order=order)
117117
else:
@@ -143,12 +143,11 @@ def reshape(X, newshape, order="C", copy=None):
143143
return dpt.usm_ndarray(
144144
tuple(newshape), dtype=X.dtype, buffer=flat_res, order=order
145145
)
146-
else:
147-
# can form a view
148-
return dpt.usm_ndarray(
149-
newshape,
150-
dtype=X.dtype,
151-
buffer=X,
152-
strides=tuple(newsts),
153-
offset=X.__sycl_usm_array_interface__.get("offset", 0),
154-
)
146+
# can form a view
147+
return dpt.usm_ndarray(
148+
newshape,
149+
dtype=X.dtype,
150+
buffer=X,
151+
strides=tuple(newsts),
152+
offset=X.__sycl_usm_array_interface__.get("offset", 0),
153+
)

0 commit comments

Comments
 (0)