Skip to content

Commit f0c3c2b

Browse files
fix comments
1 parent fb163ec commit f0c3c2b

File tree

4 files changed

+94
-65
lines changed

4 files changed

+94
-65
lines changed

dpctl/__init__.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,4 @@
2929

3030
from dpctl._sycl_core cimport *
3131
from dpctl._memory import *
32-
from dparray import *
32+
from .dparray import *

dpctl/dparray.py

Lines changed: 72 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -3,47 +3,55 @@
33
from numbers import Number
44
from types import FunctionType as ftype, BuiltinFunctionType as bftype
55
import sys
6-
#import importlib
7-
#import functools
86
import inspect
7+
import dpctl
8+
from dpctl.memory import MemoryUSMShared
99

1010
debug = False
1111

12+
1213
def dprint(*args):
1314
if debug:
1415
print(*args)
1516
sys.stdout.flush()
1617

17-
import dpctl
18-
from dpctl.memory import MemoryUSMShared
1918

2019
functions_list = [o[0] for o in getmembers(np) if isfunction(o[1]) or isbuiltin(o[1])]
2120
class_list = [o for o in getmembers(np) if isclass(o[1])]
2221

2322
array_interface_property = "__array_interface__"
23+
24+
2425
def has_array_interface(x):
2526
return hasattr(x, array_interface_property)
2627

28+
2729
class ndarray(np.ndarray):
2830
"""
2931
numpy.ndarray subclass whose underlying memory buffer is allocated
3032
with a foreign allocator.
3133
"""
32-
def __new__(subtype, shape,
33-
dtype=float, buffer=None, offset=0,
34-
strides=None, order=None):
34+
35+
def __new__(
36+
subtype, shape, dtype=float, buffer=None, offset=0, strides=None, order=None
37+
):
3538
# Create a new array.
3639
if buffer is None:
3740
dprint("dparray::ndarray __new__ buffer None")
3841
nelems = np.prod(shape)
3942
dt = np.dtype(dtype)
4043
isz = dt.itemsize
41-
nbytes = int(isz*max(1, nelems))
44+
nbytes = int(isz * max(1, nelems))
4245
buf = MemoryUSMShared(nbytes)
4346
new_obj = np.ndarray.__new__(
44-
subtype, shape, dtype=dt,
45-
buffer=buf, offset=0,
46-
strides=strides, order=order)
47+
subtype,
48+
shape,
49+
dtype=dt,
50+
buffer=buf,
51+
offset=0,
52+
strides=strides,
53+
order=order,
54+
)
4755
if hasattr(new_obj, array_interface_property):
4856
dprint("buffer None new_obj already has sycl_usm")
4957
else:
@@ -55,9 +63,14 @@ def __new__(subtype, shape,
5563
dprint("dparray::ndarray __new__ buffer", array_interface_property)
5664
# also check for array interface
5765
new_obj = np.ndarray.__new__(
58-
subtype, shape, dtype=dtype,
59-
buffer=buffer, offset=offset,
60-
strides=strides, order=order)
66+
subtype,
67+
shape,
68+
dtype=dtype,
69+
buffer=buffer,
70+
offset=offset,
71+
strides=strides,
72+
order=order,
73+
)
6174
if hasattr(new_obj, array_interface_property):
6275
dprint("buffer None new_obj already has sycl_usm")
6376
else:
@@ -68,17 +81,26 @@ def __new__(subtype, shape,
6881
dprint("dparray::ndarray __new__ buffer not None and not sycl_usm")
6982
nelems = np.prod(shape)
7083
# must copy
71-
ar = np.ndarray(shape,
72-
dtype=dtype, buffer=buffer,
73-
offset=offset, strides=strides,
74-
order=order)
84+
ar = np.ndarray(
85+
shape,
86+
dtype=dtype,
87+
buffer=buffer,
88+
offset=offset,
89+
strides=strides,
90+
order=order,
91+
)
7592
nbytes = int(ar.nbytes)
7693
buf = MemoryUSMShared(nbytes)
7794
new_obj = np.ndarray.__new__(
78-
subtype, shape, dtype=dtype,
79-
buffer=buf, offset=0,
80-
strides=strides, order=order)
81-
np.copyto(new_obj, ar, casting='no')
95+
subtype,
96+
shape,
97+
dtype=dtype,
98+
buffer=buf,
99+
offset=0,
100+
strides=strides,
101+
order=order,
102+
)
103+
np.copyto(new_obj, ar, casting="no")
82104
if hasattr(new_obj, array_interface_property):
83105
dprint("buffer None new_obj already has sycl_usm")
84106
else:
@@ -89,7 +111,8 @@ def __new__(subtype, shape,
89111
def __array_finalize__(self, obj):
90112
dprint("__array_finalize__:", obj, hex(id(obj)), type(obj))
91113
# When called from the explicit constructor, obj is None
92-
if obj is None: return
114+
if obj is None:
115+
return
93116
# When called in new-from-template, `obj` is another instance of our own
94117
# subclass, that we might use to update the new `self` instance.
95118
# However, when called from view casting, `obj` can be an instance of any
@@ -105,7 +128,11 @@ def __array_finalize__(self, obj):
105128
dprint("external_allocator:", hex(ea), type(ea))
106129
dprint("data:", hex(d), type(d))
107130
dppl_rt_allocator = numba.dppl._dppl_rt.get_external_allocator()
108-
dprint("dppl external_allocator:", hex(dppl_rt_allocator), type(dppl_rt_allocator))
131+
dprint(
132+
"dppl external_allocator:",
133+
hex(dppl_rt_allocator),
134+
type(dppl_rt_allocator),
135+
)
109136
dprint(dir(mobj))
110137
if ea == dppl_rt_allocator:
111138
return
@@ -118,24 +145,26 @@ def __array_finalize__(self, obj):
118145
if hasattr(obj, array_interface_property):
119146
return
120147
ob = ob.base
121-
148+
122149
# Just raise an exception since __array_ufunc__ makes all reasonable cases not
123150
# need the code below.
124-
raise ValueError("Non-USM allocated ndarray can not viewed as a USM-allocated one without a copy")
125-
151+
raise ValueError(
152+
"Non-USM allocated ndarray can not viewed as a USM-allocated one without a copy"
153+
)
154+
126155
# Tell Numba to not treat this type just like a NumPy ndarray but to propagate its type.
127156
# This way it will use the custom dparray allocator.
128157
__numba_no_subtype_ndarray__ = True
129158

130159
# Convert to a NumPy ndarray.
131160
def as_ndarray(self):
132-
return np.copy(self)
161+
return np.copy(self)
133162

134163
def __array__(self):
135164
return self
136165

137166
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
138-
if method == '__call__':
167+
if method == "__call__":
139168
N = None
140169
scalars = []
141170
typing = []
@@ -162,41 +191,43 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
162191
# USM memory. However, if kwarg has dparray-typed out then
163192
# array_ufunc is called recursively so we cast out as regular
164193
# NumPy ndarray (having a USM data pointer).
165-
if kwargs.get('out', None) is None:
194+
if kwargs.get("out", None) is None:
166195
# maybe copy?
167196
# deal with multiple returned arrays, so kwargs['out'] can be tuple
168197
res_type = np.result_type(*typing)
169198
out = empty(inputs[0].shape, dtype=res_type)
170199
out_as_np = np.ndarray(out.shape, out.dtype, out)
171-
kwargs['out'] = out_as_np
200+
kwargs["out"] = out_as_np
172201
else:
173202
# If they manually gave dparray as out kwarg then we have to also
174203
# cast as regular NumPy ndarray to avoid recursion.
175-
if isinstance(kwargs['out'], ndarray):
176-
out = kwargs['out']
177-
kwargs['out'] = np.ndarray(out.shape, out.dtype, out)
204+
if isinstance(kwargs["out"], ndarray):
205+
out = kwargs["out"]
206+
kwargs["out"] = np.ndarray(out.shape, out.dtype, out)
178207
else:
179-
out = kwargs['out']
208+
out = kwargs["out"]
180209
ret = ufunc(*scalars, **kwargs)
181210
return out
182211
else:
183212
return NotImplemented
184213

214+
185215
def isdef(x):
186216
try:
187217
eval(x)
188218
return True
189219
except NameError:
190220
return False
191221

222+
192223
for c in class_list:
193224
cname = c[0]
194225
if isdef(cname):
195226
continue
196227
# For now we do the simple thing and copy the types from NumPy module into dparray module.
197228
new_func = "%s = np.%s" % (cname, cname)
198229
try:
199-
the_code = compile(new_func, '__init__', 'exec')
230+
the_code = compile(new_func, "__init__", "exec")
200231
exec(the_code)
201232
except:
202233
print("Failed to exec type propagation", cname)
@@ -209,16 +240,18 @@ def isdef(x):
209240
for fname in functions_list:
210241
if isdef(fname):
211242
continue
212-
new_func = "def %s(*args, **kwargs):\n" % fname
243+
new_func = "def %s(*args, **kwargs):\n" % fname
213244
new_func += " ret = np.%s(*args, **kwargs)\n" % fname
214245
new_func += " if type(ret) == np.ndarray:\n"
215246
new_func += " ret = ndarray(ret.shape, ret.dtype, ret)\n"
216247
new_func += " return ret\n"
217-
the_code = compile(new_func, '__init__', 'exec')
248+
the_code = compile(new_func, "__init__", "exec")
218249
exec(the_code)
219250

251+
220252
def from_ndarray(x):
221253
return copy(x)
222254

255+
223256
def as_ndarray(x):
224-
return np.copy(x)
257+
return np.copy(x)

dpctl/tests/test_dparray.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,10 @@
33
import numpy
44

55

6-
def func_operation_with_const(dpctl_array):
7-
return dpctl_array * 2.0 + 13
8-
9-
10-
def multiply_func(np_array, dpcrtl_array):
11-
return np_array * dpcrtl_array
12-
13-
146
class TestOverloadList(unittest.TestCase):
15-
maxDiff = None
16-
17-
X = dparray.ndarray((256, 4), dtype='d')
18-
X.fill(1.0)
7+
def setUp(self):
8+
self.X = dparray.ndarray((256, 4), dtype="d")
9+
self.X.fill(1.0)
1910

2011
def test_dparray_type(self):
2112
self.assertIsInstance(self.X, dparray.ndarray)
@@ -38,13 +29,16 @@ def test_multiplication_dparray(self):
3829
self.assertIsInstance(C, dparray.ndarray)
3930

4031
def test_dparray_through_python_func(self):
32+
def func_operation_with_const(dpctl_array):
33+
return dpctl_array * 2.0 + 13
34+
4135
C = self.X * 5
4236
dp_func = func_operation_with_const(C)
4337
self.assertIsInstance(dp_func, dparray.ndarray)
4438

4539
def test_dparray_mixing_dpctl_and_numpy(self):
46-
dp_numpy = numpy.ones((256, 4), dtype='d')
47-
res = multiply_func(dp_numpy, self.X)
40+
dp_numpy = numpy.ones((256, 4), dtype="d")
41+
res = dp_numpy * self.X
4842
self.assertIsInstance(res, dparray.ndarray)
4943

5044
def test_dparray_shape(self):
@@ -55,6 +49,15 @@ def test_dparray_T(self):
5549
res = self.X.T
5650
self.assertEqual(res.shape, (4, 256))
5751

52+
def test_numpy_ravel_with_dparray(self):
53+
res = numpy.ravel(self.X)
54+
self.assertEqual(res.shape, (1024,))
55+
56+
@unittest.expectedFailure
57+
def test_numpy_sum_with_dparray(self):
58+
res = numpy.sum(self.X)
59+
self.assertEqual(res, 1024.0)
60+
5861

59-
if __name__ == '__main__':
62+
if __name__ == "__main__":
6063
unittest.main()

setup.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -148,9 +148,7 @@ def extensions():
148148
runtime_library_dirs = []
149149

150150
extension_args = {
151-
"depends": [
152-
dppl_sycl_interface_include,
153-
],
151+
"depends": [dppl_sycl_interface_include,],
154152
"include_dirs": [np.get_include(), dppl_sycl_interface_include],
155153
"extra_compile_args": eca
156154
+ get_other_cxxflags()
@@ -165,16 +163,12 @@ def extensions():
165163
extensions = [
166164
Extension(
167165
"dpctl._sycl_core",
168-
[
169-
os.path.join("dpctl", "_sycl_core.pyx"),
170-
],
166+
[os.path.join("dpctl", "_sycl_core.pyx"),],
171167
**extension_args
172168
),
173169
Extension(
174170
"dpctl.memory._memory",
175-
[
176-
os.path.join("dpctl", "memory", "_memory.pyx"),
177-
],
171+
[os.path.join("dpctl", "memory", "_memory.pyx"),],
178172
**extension_args
179173
),
180174
]
@@ -201,7 +195,6 @@ def _get_cmdclass():
201195
cmdclass["develop"] = develop
202196
return cmdclass
203197

204-
print("packages:", find_packages(include=["*"]))
205198

206199
setup(
207200
name="dpctl",

0 commit comments

Comments
 (0)