@@ -48,12 +48,13 @@ def dprint(*args):
48
48
def has_array_interface (x ):
49
49
return hasattr (x , array_interface_property )
50
50
51
+
51
52
def _get_usm_base (ary ):
52
53
ob = ary
53
54
while True :
54
55
if ob is None :
55
56
return None
56
- elif hasattr (ob , ' __sycl_usm_array_interface__' ):
57
+ elif hasattr (ob , " __sycl_usm_array_interface__" ):
57
58
return ob
58
59
elif isinstance (ob , np .ndarray ):
59
60
ob = ob .base
@@ -92,9 +93,11 @@ def __new__(
92
93
dprint ("buffer None new_obj already has sycl_usm" )
93
94
else :
94
95
dprint ("buffer None new_obj will add sycl_usm" )
95
- setattr (new_obj ,
96
- array_interface_property ,
97
- new_obj ._getter_sycl_usm_array_interface_ ())
96
+ setattr (
97
+ new_obj ,
98
+ array_interface_property ,
99
+ new_obj ._getter_sycl_usm_array_interface_ (),
100
+ )
98
101
return new_obj
99
102
# zero copy if buffer is a usm backed array-like thing
100
103
elif hasattr (buffer , array_interface_property ):
@@ -113,8 +116,11 @@ def __new__(
113
116
dprint ("buffer None new_obj already has sycl_usm" )
114
117
else :
115
118
dprint ("buffer None new_obj will add sycl_usm" )
116
- setattr (new_obj , array_interface_property ,
117
- new_obj ._getter_sycl_usm_array_interface_ ())
119
+ setattr (
120
+ new_obj ,
121
+ array_interface_property ,
122
+ new_obj ._getter_sycl_usm_array_interface_ (),
123
+ )
118
124
return new_obj
119
125
else :
120
126
dprint ("dparray::ndarray __new__ buffer not None and not sycl_usm" )
@@ -144,29 +150,30 @@ def __new__(
144
150
dprint ("buffer None new_obj already has sycl_usm" )
145
151
else :
146
152
dprint ("buffer None new_obj will add sycl_usm" )
147
- setattr (new_obj , array_interface_property ,
148
- new_obj ._getter_sycl_usm_array_interface_ ())
153
+ setattr (
154
+ new_obj ,
155
+ array_interface_property ,
156
+ new_obj ._getter_sycl_usm_array_interface_ (),
157
+ )
149
158
return new_obj
150
159
151
-
152
160
def _getter_sycl_usm_array_interface_ (self ):
153
161
ary_iface = self .__array_interface__
154
162
_base = _get_usm_base (self )
155
163
if _base is None :
156
164
raise TypeError
157
165
158
- usm_iface = getattr (_base , ' __sycl_usm_array_interface__' , None )
166
+ usm_iface = getattr (_base , " __sycl_usm_array_interface__" , None )
159
167
if usm_iface is None :
160
168
raise TypeError
161
169
162
- if ( ary_iface [' data' ][0 ] == usm_iface [' data' ][0 ]) :
163
- ary_iface [' version' ] = usm_iface [' version' ]
164
- ary_iface [' syclobj' ] = usm_iface [' syclobj' ]
170
+ if ary_iface [" data" ][0 ] == usm_iface [" data" ][0 ]:
171
+ ary_iface [" version" ] = usm_iface [" version" ]
172
+ ary_iface [" syclobj" ] = usm_iface [" syclobj" ]
165
173
else :
166
174
raise TypeError
167
175
return ary_iface
168
176
169
-
170
177
def __array_finalize__ (self , obj ):
171
178
dprint ("__array_finalize__:" , obj , hex (id (obj )), type (obj ))
172
179
# When called from the explicit constructor, obj is None
@@ -191,7 +198,6 @@ def __array_finalize__(self, obj):
191
198
"Non-USM allocated ndarray can not viewed as a USM-allocated one without a copy"
192
199
)
193
200
194
-
195
201
# Tell Numba to not treat this type just like a NumPy ndarray but to propagate its type.
196
202
# This way it will use the custom dparray allocator.
197
203
__numba_no_subtype_ndarray__ = True
0 commit comments