Skip to content

Commit 91a53bc

Browse files
Adding __array_function__ so that Numpy calls with dparrays work. (#261)
* Calling numpy.sum with a dparray should work now. __array_function__ is now implemented. I changed the way that we determine if a class or function is defined in this file from using eval to getting the module from sys.modules[__name__] and then using hasattr on the module. * Efficiency improvements 1. Create class_list, function_list in one pass over members of np module 2. Renamed isdef into _isdef, since it is an internal function 3. Create certain objects in __array_function__ only if we intend to use them. * Darkened numpy_with_usm_shared.py with black * account of C-API defined functions of NumPy * Don't need atypes since that just verified it was the same as the types argument. Don't need fatypes since it wasn't being used. * black changes * Don't need explicit if debug once we get rid of atypes calculation. That allows dprint to move left and back to one line. Co-authored-by: Oleksandr Pavlyk <[email protected]>
1 parent 609c8fd commit 91a53bc

File tree

2 files changed

+31
-14
lines changed

2 files changed

+31
-14
lines changed

dpctl/dptensor/numpy_usm_shared.py

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import inspect
3636
import dpctl
3737
from dpctl.memory import MemoryUSMShared
38+
import builtins
3839

3940
debug = False
4041

@@ -45,8 +46,15 @@ def dprint(*args):
4546
sys.stdout.flush()
4647

4748

48-
functions_list = [o[0] for o in getmembers(np) if isfunction(o[1]) or isbuiltin(o[1])]
49-
class_list = [o for o in getmembers(np) if isclass(o[1])]
49+
functions_list = []
50+
class_list = []
51+
for o in getmembers(np):
52+
s = o[1]
53+
if isfunction(s) or isbuiltin(s):
54+
functions_list.append(o[0])
55+
elif isclass(s):
56+
class_list.append(o)
57+
5058

5159
array_interface_property = "__sycl_usm_array_interface__"
5260

@@ -90,7 +98,9 @@ def __new__(
9098
nelems = np.prod(shape)
9199
dt = np.dtype(dtype)
92100
isz = dt.itemsize
93-
nbytes = int(isz * max(1, nelems))
101+
# Have to use builtins.max explicitly since this module will
102+
# import numpy's max function.
103+
nbytes = int(isz * builtins.max(1, nelems))
94104
buf = MemoryUSMShared(nbytes)
95105
new_obj = np.ndarray.__new__(
96106
subtype,
@@ -224,7 +234,7 @@ def __array_finalize__(self, obj):
224234

225235
# Convert to a NumPy ndarray.
226236
def as_ndarray(self):
227-
return np.copy(self)
237+
return np.copy(np.ndarray(self.shape, self.dtype, self))
228238

229239
def __array__(self):
230240
return self
@@ -277,18 +287,26 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
277287
else:
278288
return NotImplemented
279289

290+
def __array_function__(self, func, types, args, kwargs):
291+
fname = func.__name__
292+
has_func = _isdef(fname)
293+
dprint("__array_function__:", func, fname, type(func), types, has_func)
294+
if has_func:
295+
cm = sys.modules[__name__]
296+
affunc = getattr(cm, fname)
297+
fargs = [x.view(np.ndarray) if isinstance(x, ndarray) else x for x in args]
298+
return affunc(*fargs, **kwargs)
299+
return NotImplemented
280300

281-
def isdef(x):
282-
try:
283-
eval(x)
284-
return True
285-
except NameError:
286-
return False
301+
302+
def _isdef(x):
303+
cm = sys.modules[__name__]
304+
return hasattr(cm, x)
287305

288306

289307
for c in class_list:
290308
cname = c[0]
291-
if isdef(cname):
309+
if _isdef(cname):
292310
continue
293311
# For now we do the simple thing and copy the types from NumPy module
294312
# into numpy_usm_shared module.
@@ -305,7 +323,7 @@ def isdef(x):
305323
# instead. This is a stop-gap. We should eventually find a
306324
# way to do the allocation correct to start with.
307325
for fname in functions_list:
308-
if isdef(fname):
326+
if _isdef(fname):
309327
continue
310328
new_func = "def %s(*args, **kwargs):\n" % fname
311329
new_func += " ret = np.%s(*args, **kwargs)\n" % fname
@@ -321,4 +339,4 @@ def from_ndarray(x):
321339

322340

323341
def as_ndarray(x):
324-
return np.copy(x)
342+
return np.copy(np.ndarray(x.shape, x.dtype, x))

dpctl/tests/test_dparray.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@ def test_numpy_ravel_with_dparray(self):
7272
res = numpy.ravel(self.X)
7373
self.assertEqual(res.shape, (1024,))
7474

75-
@unittest.expectedFailure
7675
def test_numpy_sum_with_dparray(self):
7776
res = numpy.sum(self.X)
7877
self.assertEqual(res, 1024.0)

0 commit comments

Comments
 (0)