@@ -15,7 +15,7 @@ def dprint(*args):
15
15
sys .stdout .flush ()
16
16
17
17
import dpctl
18
- from dpctl ._memory import MemoryUSMShared
18
+ from dpctl .memory import MemoryUSMShared
19
19
20
20
functions_list = [o [0 ] for o in getmembers (np ) if isfunction (o [1 ]) or isbuiltin (o [1 ])]
21
21
class_list = [o for o in getmembers (np ) if isclass (o [1 ])]
@@ -38,7 +38,8 @@ def __new__(subtype, shape,
38
38
nelems = np .prod (shape )
39
39
dt = np .dtype (dtype )
40
40
isz = dt .itemsize
41
- buf = MemoryUSMShared (nbytes = isz * max (1 ,nelems ))
41
+ nbytes = int (isz * max (1 , nelems ))
42
+ buf = MemoryUSMShared (nbytes )
42
43
new_obj = np .ndarray .__new__ (
43
44
subtype , shape , dtype = dt ,
44
45
buffer = buf , offset = 0 ,
@@ -71,7 +72,8 @@ def __new__(subtype, shape,
71
72
dtype = dtype , buffer = buffer ,
72
73
offset = offset , strides = strides ,
73
74
order = order )
74
- buf = MemoryUSMShared (nbytes = ar .nbytes )
75
+ nbytes = int (ar .nbytes )
76
+ buf = MemoryUSMShared (nbytes )
75
77
new_obj = np .ndarray .__new__ (
76
78
subtype , shape , dtype = dtype ,
77
79
buffer = buf , offset = 0 ,
0 commit comments