Skip to content

Commit 4a7b94e

Browse files
dpctl.dptensor.ndarray provides proper __sycl_usm_array_interface__
1 parent aebe5b7 commit 4a7b94e

File tree

1 file changed

+40
-3
lines changed

1 file changed

+40
-3
lines changed

dpctl/dptensor/dparray.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,19 @@ def dprint(*args):
4949
def has_array_interface(x):
5050
return hasattr(x, array_interface_property)
5151

52+
def _get_usm_base(ary):
53+
ob = ary
54+
while True:
55+
if ob is None:
56+
return None
57+
elif hasattr(ob, '__sycl_usm_array_interface__'):
58+
return ob
59+
elif isinstance(ob, np.ndarray):
60+
ob = ob.base
61+
elif isinstance(ob, memoryview):
62+
ob = ob.obj
63+
else:
64+
return None
5265

5366
class ndarray(np.ndarray):
5467
"""
@@ -80,7 +93,9 @@ def __new__(
8093
dprint("buffer None new_obj already has sycl_usm")
8194
else:
8295
dprint("buffer None new_obj will add sycl_usm")
83-
setattr(new_obj, array_interface_property, {})
96+
setattr(new_obj,
97+
array_interface_property,
98+
new_obj._getter_sycl_usm_array_interface_())
8499
return new_obj
85100
# zero copy if buffer is a usm backed array-like thing
86101
elif hasattr(buffer, array_interface_property):
@@ -99,7 +114,8 @@ def __new__(
99114
dprint("buffer None new_obj already has sycl_usm")
100115
else:
101116
dprint("buffer None new_obj will add sycl_usm")
102-
setattr(new_obj, array_interface_property, {})
117+
setattr(new_obj, array_interface_property,
118+
new_obj._getter_sycl_usm_array_interface_())
103119
return new_obj
104120
else:
105121
dprint("dparray::ndarray __new__ buffer not None and not sycl_usm")
@@ -129,9 +145,29 @@ def __new__(
129145
dprint("buffer None new_obj already has sycl_usm")
130146
else:
131147
dprint("buffer None new_obj will add sycl_usm")
132-
setattr(new_obj, array_interface_property, {})
148+
setattr(new_obj, array_interface_property,
149+
new_obj._getter_sycl_usm_array_interface_())
133150
return new_obj
134151

152+
153+
def _getter_sycl_usm_array_interface_(self):
154+
ary_iface = self.__array_interface__
155+
_base = _get_usm_base(self)
156+
if _base is None:
157+
raise TypeError
158+
159+
usm_iface = getattr(_base, '__sycl_usm_array_interface__', None)
160+
if usm_iface is None:
161+
raise TypeError
162+
163+
if (ary_iface['data'][0] == usm_iface['data'][0]):
164+
ary_iface['version'] = usm_iface['version']
165+
ary_iface['syclobj'] = usm_iface['syclobj']
166+
else:
167+
raise TypeError
168+
return ary_iface
169+
170+
135171
def __array_finalize__(self, obj):
136172
dprint("__array_finalize__:", obj, hex(id(obj)), type(obj))
137173
# When called from the explicit constructor, obj is None
@@ -156,6 +192,7 @@ def __array_finalize__(self, obj):
156192
"Non-USM allocated ndarray can not viewed as a USM-allocated one without a copy"
157193
)
158194

195+
159196
# Tell Numba to not treat this type just like a NumPy ndarray but to propagate its type.
160197
# This way it will use the custom dparray allocator.
161198
__numba_no_subtype_ndarray__ = True

0 commit comments

Comments
 (0)