Skip to content

Commit 624564e

Browse files
Fixed #728
Special-cased copying of 0-element arrays.
1 parent 9c90442 commit 624564e

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -66,15 +66,20 @@ def contract_iter2(shape, strides1, strides2):
6666

6767

6868
def _has_memory_overlap(x1, x2):
69-
m1 = dpm.as_usm_memory(x1)
70-
m2 = dpm.as_usm_memory(x2)
71-
if m1.sycl_device == m2.sycl_device:
72-
p1_beg = m1._pointer
73-
p1_end = p1_beg + m1.nbytes
74-
p2_beg = m2._pointer
75-
p2_end = p2_beg + m2.nbytes
76-
return p1_beg > p2_end or p2_beg < p1_end
69+
if x1.size and x2.size:
70+
m1 = dpm.as_usm_memory(x1)
71+
m2 = dpm.as_usm_memory(x2)
72+
# can only overlap if bound to the same context
73+
if m1.sycl_context == m2.sycl_context:
74+
p1_beg = m1._pointer
75+
p1_end = p1_beg + m1.nbytes
76+
p2_beg = m2._pointer
77+
p2_end = p2_beg + m2.nbytes
78+
return p1_beg > p2_end or p2_beg < p1_end
79+
else:
80+
return False
7781
else:
82+
# zero element array do not overlap anything
7883
return False
7984

8085

@@ -193,6 +198,9 @@ def copy_same_dtype(dst, src):
193198
if dst.dtype != src.dtype:
194199
raise ValueError
195200

201+
if dst.size == 0:
202+
return
203+
196204
# check that memory regions do not overlap
197205
if _has_memory_overlap(dst, src):
198206
tmp = _copy_to_numpy(src)

0 commit comments

Comments
 (0)