@@ -49,6 +49,19 @@ def dprint(*args):
49
49
def has_array_interface (x ):
50
50
return hasattr (x , array_interface_property )
51
51
52
+ def _get_usm_base (ary ):
53
+ ob = ary
54
+ while True :
55
+ if ob is None :
56
+ return None
57
+ elif hasattr (ob , '__sycl_usm_array_interface__' ):
58
+ return ob
59
+ elif isinstance (ob , np .ndarray ):
60
+ ob = ob .base
61
+ elif isinstance (ob , memoryview ):
62
+ ob = ob .obj
63
+ else :
64
+ return None
52
65
53
66
class ndarray (np .ndarray ):
54
67
"""
@@ -80,7 +93,9 @@ def __new__(
80
93
dprint ("buffer None new_obj already has sycl_usm" )
81
94
else :
82
95
dprint ("buffer None new_obj will add sycl_usm" )
83
- setattr (new_obj , array_interface_property , {})
96
+ setattr (new_obj ,
97
+ array_interface_property ,
98
+ new_obj ._getter_sycl_usm_array_interface_ ())
84
99
return new_obj
85
100
# zero copy if buffer is a usm backed array-like thing
86
101
elif hasattr (buffer , array_interface_property ):
@@ -99,7 +114,8 @@ def __new__(
99
114
dprint ("buffer None new_obj already has sycl_usm" )
100
115
else :
101
116
dprint ("buffer None new_obj will add sycl_usm" )
102
- setattr (new_obj , array_interface_property , {})
117
+ setattr (new_obj , array_interface_property ,
118
+ new_obj ._getter_sycl_usm_array_interface_ ())
103
119
return new_obj
104
120
else :
105
121
dprint ("dparray::ndarray __new__ buffer not None and not sycl_usm" )
@@ -129,9 +145,29 @@ def __new__(
129
145
dprint ("buffer None new_obj already has sycl_usm" )
130
146
else :
131
147
dprint ("buffer None new_obj will add sycl_usm" )
132
- setattr (new_obj , array_interface_property , {})
148
+ setattr (new_obj , array_interface_property ,
149
+ new_obj ._getter_sycl_usm_array_interface_ ())
133
150
return new_obj
134
151
152
+
153
+ def _getter_sycl_usm_array_interface_ (self ):
154
+ ary_iface = self .__array_interface__
155
+ _base = _get_usm_base (self )
156
+ if _base is None :
157
+ raise TypeError
158
+
159
+ usm_iface = getattr (_base , '__sycl_usm_array_interface__' , None )
160
+ if usm_iface is None :
161
+ raise TypeError
162
+
163
+ if (ary_iface ['data' ][0 ] == usm_iface ['data' ][0 ]):
164
+ ary_iface ['version' ] = usm_iface ['version' ]
165
+ ary_iface ['syclobj' ] = usm_iface ['syclobj' ]
166
+ else :
167
+ raise TypeError
168
+ return ary_iface
169
+
170
+
135
171
def __array_finalize__ (self , obj ):
136
172
dprint ("__array_finalize__:" , obj , hex (id (obj )), type (obj ))
137
173
# When called from the explicit constructor, obj is None
@@ -156,6 +192,7 @@ def __array_finalize__(self, obj):
156
192
"Non-USM allocated ndarray can not viewed as a USM-allocated one without a copy"
157
193
)
158
194
195
+
159
196
# Tell Numba to not treat this type just like a NumPy ndarray but to propagate its type.
160
197
# This way it will use the custom dparray allocator.
161
198
__numba_no_subtype_ndarray__ = True
0 commit comments