20
20
import dpctl .tensor ._tensor_impl as ti
21
21
from dpctl .tensor ._device import normalize_queue_device
22
22
23
+ __doc__ = (
24
+ "Implementation module for copy- and cast- operations on "
25
+ ":class:`dpctl.tensor.usm_ndarray`."
26
+ )
27
+
23
28
24
29
def _has_memory_overlap (x1 , x2 ):
25
30
if x1 .size and x2 .size :
@@ -33,15 +38,13 @@ def _has_memory_overlap(x1, x2):
33
38
p2_end = p2_beg + m2 .nbytes
34
39
# may intersect if not ((p1_beg >= p2_end) or (p2_beg >= p2_end))
35
40
return (p1_beg < p2_end ) and (p2_beg < p1_end )
36
- else :
37
- return False
38
- else :
39
- # zero element array do not overlap anything
40
41
return False
42
+ # zero element array do not overlap anything
43
+ return False
41
44
42
45
43
46
def _copy_to_numpy (ary ):
44
- if type (ary ) is not dpt .usm_ndarray :
47
+ if not isinstance (ary , dpt .usm_ndarray ) :
45
48
raise TypeError
46
49
h = ary .usm_data .copy_to_host ().view (ary .dtype )
47
50
itsz = ary .itemsize
@@ -78,9 +81,9 @@ def _copy_from_numpy(np_ary, usm_type="device", sycl_queue=None):
78
81
def _copy_from_numpy_into (dst , np_ary ):
79
82
"Copies `np_ary` into `dst` of type :class:`dpctl.tensor.usm_ndarray"
80
83
if not isinstance (np_ary , np .ndarray ):
81
- raise TypeError ("Expected numpy.ndarray, got {}" . format ( type (np_ary )) )
84
+ raise TypeError (f "Expected numpy.ndarray, got { type (np_ary )} " )
82
85
if not isinstance (dst , dpt .usm_ndarray ):
83
- raise TypeError ("Expected usm_ndarray, got {}" . format ( type (dst )) )
86
+ raise TypeError (f "Expected usm_ndarray, got { type (dst )} " )
84
87
src_ary = np .broadcast_to (np_ary , dst .shape )
85
88
copy_q = dst .sycl_queue
86
89
if copy_q .sycl_device .has_aspect_fp64 is False :
@@ -143,6 +146,8 @@ def asnumpy(usm_ary):
143
146
144
147
145
148
class Dummy :
149
+ "Helper class with specified __sycl_usm_array_interface__ attribute"
150
+
146
151
def __init__ (self , iface ):
147
152
self .__sycl_usm_array_interface__ = iface
148
153
@@ -160,7 +165,7 @@ def _copy_overlapping(dst, src):
160
165
hcp1 , cp1 = ti ._copy_usm_ndarray_into_usm_ndarray (
161
166
src = src , dst = tmp , sycl_queue = q
162
167
)
163
- hcp2 , cp2 = ti ._copy_usm_ndarray_into_usm_ndarray (
168
+ hcp2 , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
164
169
src = tmp , dst = dst , sycl_queue = q , depends = [cp1 ]
165
170
)
166
171
hcp2 .wait ()
@@ -174,7 +179,7 @@ def _copy_same_shape(dst, src):
174
179
_copy_overlapping (src = src , dst = dst )
175
180
return
176
181
177
- hev , ev = ti ._copy_usm_ndarray_into_usm_ndarray (
182
+ hev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
178
183
src = src , dst = dst , sycl_queue = dst .sycl_queue
179
184
)
180
185
hev .wait ()
@@ -197,7 +202,13 @@ def _broadcast_shapes(sh1, sh2):
197
202
198
203
199
204
def _copy_from_usm_ndarray_to_usm_ndarray (dst , src ):
200
- if type (dst ) is not dpt .usm_ndarray or type (src ) is not dpt .usm_ndarray :
205
+ if any (
206
+ not isinstance (arg , dpt .usm_ndarray )
207
+ for arg in (
208
+ dst ,
209
+ src ,
210
+ )
211
+ ):
201
212
raise TypeError (
202
213
"Both types are expected to be dpctl.tensor.usm_ndarray, "
203
214
f"got { type (dst )} and { type (src )} ."
@@ -209,8 +220,8 @@ def _copy_from_usm_ndarray_to_usm_ndarray(dst, src):
209
220
210
221
try :
211
222
common_shape = _broadcast_shapes (dst .shape , src .shape )
212
- except ValueError :
213
- raise ValueError ("Shapes of two arrays are not compatible" )
223
+ except ValueError as exc :
224
+ raise ValueError ("Shapes of two arrays are not compatible" ) from exc
214
225
215
226
if dst .size < src .size :
216
227
raise ValueError ("Destination is smaller " )
@@ -251,9 +262,7 @@ def copy(usm_ary, order="K"):
251
262
"""
252
263
if not isinstance (usm_ary , dpt .usm_ndarray ):
253
264
return TypeError (
254
- "Expected object of type dpt.usm_ndarray, got {}" .format (
255
- type (usm_ary )
256
- )
265
+ f"Expected object of type dpt.usm_ndarray, got { type (usm_ary )} "
257
266
)
258
267
copy_order = "C"
259
268
if order == "C" :
@@ -308,9 +317,7 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
308
317
"""
309
318
if not isinstance (usm_ary , dpt .usm_ndarray ):
310
319
return TypeError (
311
- "Expected object of type dpt.usm_ndarray, got {}" .format (
312
- type (usm_ary )
313
- )
320
+ f"Expected object of type dpt.usm_ndarray, got { type (usm_ary )} "
314
321
)
315
322
if not isinstance (order , str ) or order not in ["A" , "C" , "F" , "K" ]:
316
323
raise ValueError (
@@ -321,56 +328,54 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
321
328
target_dtype = dpt .dtype (newdtype )
322
329
if not dpt .can_cast (ary_dtype , target_dtype , casting = casting ):
323
330
raise TypeError (
324
- "Can not cast from {} to {} according to rule {}" .format (
325
- ary_dtype , newdtype , casting
326
- )
331
+ f"Can not cast from { ary_dtype } to { newdtype } "
332
+ f"according to rule { casting } ."
327
333
)
328
334
c_contig = usm_ary .flags .c_contiguous
329
335
f_contig = usm_ary .flags .f_contiguous
330
- needs_copy = copy or not ( ary_dtype == target_dtype )
336
+ needs_copy = copy or not ary_dtype == target_dtype
331
337
if not needs_copy and (order != "K" ):
332
338
needs_copy = (c_contig and order not in ["A" , "C" ]) or (
333
339
f_contig and order not in ["A" , "F" ]
334
340
)
335
- if needs_copy :
336
- copy_order = "C"
337
- if order == "C" :
338
- pass
339
- elif order == "F" :
340
- copy_order = order
341
- elif order == "A" :
342
- if usm_ary .flags .f_contiguous :
343
- copy_order = "F"
344
- elif order == "K" :
345
- if usm_ary .flags .f_contiguous :
346
- copy_order = "F"
347
- else :
348
- raise ValueError (
349
- "Unrecognized value of the order keyword. "
350
- "Recognized values are 'A', 'C', 'F', or 'K'"
351
- )
341
+ if not needs_copy :
342
+ return usm_ary
343
+ copy_order = "C"
344
+ if order == "C" :
345
+ pass
346
+ elif order == "F" :
347
+ copy_order = order
348
+ elif order == "A" :
349
+ if usm_ary .flags .f_contiguous :
350
+ copy_order = "F"
351
+ elif order == "K" :
352
+ if usm_ary .flags .f_contiguous :
353
+ copy_order = "F"
354
+ else :
355
+ raise ValueError (
356
+ "Unrecognized value of the order keyword. "
357
+ "Recognized values are 'A', 'C', 'F', or 'K'"
358
+ )
359
+ R = dpt .usm_ndarray (
360
+ usm_ary .shape ,
361
+ dtype = target_dtype ,
362
+ buffer = usm_ary .usm_type ,
363
+ order = copy_order ,
364
+ buffer_ctor_kwargs = {"queue" : usm_ary .sycl_queue },
365
+ )
366
+ if order == "K" and (not c_contig and not f_contig ):
367
+ original_strides = usm_ary .strides
368
+ ind = sorted (
369
+ range (usm_ary .ndim ),
370
+ key = lambda i : abs (original_strides [i ]),
371
+ reverse = True ,
372
+ )
373
+ new_strides = tuple (R .strides [ind [i ]] for i in ind )
352
374
R = dpt .usm_ndarray (
353
375
usm_ary .shape ,
354
376
dtype = target_dtype ,
355
- buffer = usm_ary .usm_type ,
356
- order = copy_order ,
357
- buffer_ctor_kwargs = {"queue" : usm_ary .sycl_queue },
377
+ buffer = R .usm_data ,
378
+ strides = new_strides ,
358
379
)
359
- if order == "K" and (not c_contig and not f_contig ):
360
- original_strides = usm_ary .strides
361
- ind = sorted (
362
- range (usm_ary .ndim ),
363
- key = lambda i : abs (original_strides [i ]),
364
- reverse = True ,
365
- )
366
- new_strides = tuple (R .strides [ind [i ]] for i in ind )
367
- R = dpt .usm_ndarray (
368
- usm_ary .shape ,
369
- dtype = target_dtype ,
370
- buffer = R .usm_data ,
371
- strides = new_strides ,
372
- )
373
- _copy_from_usm_ndarray_to_usm_ndarray (R , usm_ary )
374
- return R
375
- else :
376
- return usm_ary
380
+ _copy_from_usm_ndarray_to_usm_ndarray (R , usm_ary )
381
+ return R
0 commit comments