Skip to content

Commit 68f485c

Browse files
Changes prompted by pylint warnings
1 parent 4f421a5 commit 68f485c

File tree

1 file changed

+65
-60
lines changed

1 file changed

+65
-60
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 65 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,11 @@
2020
import dpctl.tensor._tensor_impl as ti
2121
from dpctl.tensor._device import normalize_queue_device
2222

23+
__doc__ = (
24+
"Implementation module for copy- and cast- operations on "
25+
":class:`dpctl.tensor.usm_ndarray`."
26+
)
27+
2328

2429
def _has_memory_overlap(x1, x2):
2530
if x1.size and x2.size:
@@ -33,15 +38,13 @@ def _has_memory_overlap(x1, x2):
3338
p2_end = p2_beg + m2.nbytes
3439
# may intersect if not ((p1_beg >= p2_end) or (p2_beg >= p2_end))
3540
return (p1_beg < p2_end) and (p2_beg < p1_end)
36-
else:
37-
return False
38-
else:
39-
# zero element array do not overlap anything
4041
return False
42+
# zero element array do not overlap anything
43+
return False
4144

4245

4346
def _copy_to_numpy(ary):
44-
if type(ary) is not dpt.usm_ndarray:
47+
if not isinstance(ary, dpt.usm_ndarray):
4548
raise TypeError
4649
h = ary.usm_data.copy_to_host().view(ary.dtype)
4750
itsz = ary.itemsize
@@ -78,9 +81,9 @@ def _copy_from_numpy(np_ary, usm_type="device", sycl_queue=None):
7881
def _copy_from_numpy_into(dst, np_ary):
7982
"Copies `np_ary` into `dst` of type :class:`dpctl.tensor.usm_ndarray"
8083
if not isinstance(np_ary, np.ndarray):
81-
raise TypeError("Expected numpy.ndarray, got {}".format(type(np_ary)))
84+
raise TypeError(f"Expected numpy.ndarray, got {type(np_ary)}")
8285
if not isinstance(dst, dpt.usm_ndarray):
83-
raise TypeError("Expected usm_ndarray, got {}".format(type(dst)))
86+
raise TypeError(f"Expected usm_ndarray, got {type(dst)}")
8487
src_ary = np.broadcast_to(np_ary, dst.shape)
8588
copy_q = dst.sycl_queue
8689
if copy_q.sycl_device.has_aspect_fp64 is False:
@@ -143,6 +146,8 @@ def asnumpy(usm_ary):
143146

144147

145148
class Dummy:
149+
"Helper class with specified __sycl_usm_array_interface__ attribute"
150+
146151
def __init__(self, iface):
147152
self.__sycl_usm_array_interface__ = iface
148153

@@ -160,7 +165,7 @@ def _copy_overlapping(dst, src):
160165
hcp1, cp1 = ti._copy_usm_ndarray_into_usm_ndarray(
161166
src=src, dst=tmp, sycl_queue=q
162167
)
163-
hcp2, cp2 = ti._copy_usm_ndarray_into_usm_ndarray(
168+
hcp2, _ = ti._copy_usm_ndarray_into_usm_ndarray(
164169
src=tmp, dst=dst, sycl_queue=q, depends=[cp1]
165170
)
166171
hcp2.wait()
@@ -174,7 +179,7 @@ def _copy_same_shape(dst, src):
174179
_copy_overlapping(src=src, dst=dst)
175180
return
176181

177-
hev, ev = ti._copy_usm_ndarray_into_usm_ndarray(
182+
hev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
178183
src=src, dst=dst, sycl_queue=dst.sycl_queue
179184
)
180185
hev.wait()
@@ -197,7 +202,13 @@ def _broadcast_shapes(sh1, sh2):
197202

198203

199204
def _copy_from_usm_ndarray_to_usm_ndarray(dst, src):
200-
if type(dst) is not dpt.usm_ndarray or type(src) is not dpt.usm_ndarray:
205+
if any(
206+
not isinstance(arg, dpt.usm_ndarray)
207+
for arg in (
208+
dst,
209+
src,
210+
)
211+
):
201212
raise TypeError(
202213
"Both types are expected to be dpctl.tensor.usm_ndarray, "
203214
f"got {type(dst)} and {type(src)}."
@@ -209,8 +220,8 @@ def _copy_from_usm_ndarray_to_usm_ndarray(dst, src):
209220

210221
try:
211222
common_shape = _broadcast_shapes(dst.shape, src.shape)
212-
except ValueError:
213-
raise ValueError("Shapes of two arrays are not compatible")
223+
except ValueError as exc:
224+
raise ValueError("Shapes of two arrays are not compatible") from exc
214225

215226
if dst.size < src.size:
216227
raise ValueError("Destination is smaller ")
@@ -251,9 +262,7 @@ def copy(usm_ary, order="K"):
251262
"""
252263
if not isinstance(usm_ary, dpt.usm_ndarray):
253264
return TypeError(
254-
"Expected object of type dpt.usm_ndarray, got {}".format(
255-
type(usm_ary)
256-
)
265+
f"Expected object of type dpt.usm_ndarray, got {type(usm_ary)}"
257266
)
258267
copy_order = "C"
259268
if order == "C":
@@ -308,9 +317,7 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
308317
"""
309318
if not isinstance(usm_ary, dpt.usm_ndarray):
310319
return TypeError(
311-
"Expected object of type dpt.usm_ndarray, got {}".format(
312-
type(usm_ary)
313-
)
320+
f"Expected object of type dpt.usm_ndarray, got {type(usm_ary)}"
314321
)
315322
if not isinstance(order, str) or order not in ["A", "C", "F", "K"]:
316323
raise ValueError(
@@ -321,56 +328,54 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
321328
target_dtype = dpt.dtype(newdtype)
322329
if not dpt.can_cast(ary_dtype, target_dtype, casting=casting):
323330
raise TypeError(
324-
"Can not cast from {} to {} according to rule {}".format(
325-
ary_dtype, newdtype, casting
326-
)
331+
f"Can not cast from {ary_dtype} to {newdtype} "
332+
f"according to rule {casting}."
327333
)
328334
c_contig = usm_ary.flags.c_contiguous
329335
f_contig = usm_ary.flags.f_contiguous
330-
needs_copy = copy or not (ary_dtype == target_dtype)
336+
needs_copy = copy or not ary_dtype == target_dtype
331337
if not needs_copy and (order != "K"):
332338
needs_copy = (c_contig and order not in ["A", "C"]) or (
333339
f_contig and order not in ["A", "F"]
334340
)
335-
if needs_copy:
336-
copy_order = "C"
337-
if order == "C":
338-
pass
339-
elif order == "F":
340-
copy_order = order
341-
elif order == "A":
342-
if usm_ary.flags.f_contiguous:
343-
copy_order = "F"
344-
elif order == "K":
345-
if usm_ary.flags.f_contiguous:
346-
copy_order = "F"
347-
else:
348-
raise ValueError(
349-
"Unrecognized value of the order keyword. "
350-
"Recognized values are 'A', 'C', 'F', or 'K'"
351-
)
341+
if not needs_copy:
342+
return usm_ary
343+
copy_order = "C"
344+
if order == "C":
345+
pass
346+
elif order == "F":
347+
copy_order = order
348+
elif order == "A":
349+
if usm_ary.flags.f_contiguous:
350+
copy_order = "F"
351+
elif order == "K":
352+
if usm_ary.flags.f_contiguous:
353+
copy_order = "F"
354+
else:
355+
raise ValueError(
356+
"Unrecognized value of the order keyword. "
357+
"Recognized values are 'A', 'C', 'F', or 'K'"
358+
)
359+
R = dpt.usm_ndarray(
360+
usm_ary.shape,
361+
dtype=target_dtype,
362+
buffer=usm_ary.usm_type,
363+
order=copy_order,
364+
buffer_ctor_kwargs={"queue": usm_ary.sycl_queue},
365+
)
366+
if order == "K" and (not c_contig and not f_contig):
367+
original_strides = usm_ary.strides
368+
ind = sorted(
369+
range(usm_ary.ndim),
370+
key=lambda i: abs(original_strides[i]),
371+
reverse=True,
372+
)
373+
new_strides = tuple(R.strides[ind[i]] for i in ind)
352374
R = dpt.usm_ndarray(
353375
usm_ary.shape,
354376
dtype=target_dtype,
355-
buffer=usm_ary.usm_type,
356-
order=copy_order,
357-
buffer_ctor_kwargs={"queue": usm_ary.sycl_queue},
377+
buffer=R.usm_data,
378+
strides=new_strides,
358379
)
359-
if order == "K" and (not c_contig and not f_contig):
360-
original_strides = usm_ary.strides
361-
ind = sorted(
362-
range(usm_ary.ndim),
363-
key=lambda i: abs(original_strides[i]),
364-
reverse=True,
365-
)
366-
new_strides = tuple(R.strides[ind[i]] for i in ind)
367-
R = dpt.usm_ndarray(
368-
usm_ary.shape,
369-
dtype=target_dtype,
370-
buffer=R.usm_data,
371-
strides=new_strides,
372-
)
373-
_copy_from_usm_ndarray_to_usm_ndarray(R, usm_ary)
374-
return R
375-
else:
376-
return usm_ary
380+
_copy_from_usm_ndarray_to_usm_ndarray(R, usm_ary)
381+
return R

0 commit comments

Comments
 (0)