@@ -82,30 +82,66 @@ __all__ = [
82
82
83
83
cdef ERROR_PREFIX = " DPNP error:"
84
84
85
+ def convert_item (item ):
86
+ if getattr (item, " __sycl_usm_array_interface__" , False ):
87
+ item_converted = dpnp.asnumpy(item)
88
+ elif getattr (item, " __array_interface__" , False ): # detect if it is a container (TODO any better way?)
89
+ mod_name = getattr (item, " __module__" , ' none' )
90
+ if (mod_name != ' numpy' ):
91
+ item_converted = dpnp.asnumpy(item)
92
+ else :
93
+ item_converted = item
94
+ elif isinstance (item, list ):
95
+ item_converted = convert_list_args(item)
96
+ elif isinstance (item, tuple ):
97
+ item_converted = tuple (convert_list_args(item))
98
+ else :
99
+ item_converted = item
100
+
101
+ return item_converted
102
+
103
+ def convert_list_args (input_list ):
104
+ result_list = []
105
+ for item in input_list:
106
+ item_converted = convert_item(item)
107
+ result_list.append(item_converted)
85
108
109
+ return result_list
110
+
86
111
def call_origin (function , *args , **kwargs ):
87
112
"""
88
113
Call fallback function for unsupported cases
89
114
"""
90
115
91
- # print(f"DPNP call_origin(): Fallback called. \n\t function={function}, \n\t args={args}, \n\t kwargs={kwargs}")
116
+ dpnp_inplace = kwargs.pop(" dpnp_inplace" , False )
117
+ # print(f"DPNP call_origin(): Fallback called. \n\t function={function}, \n\t args={args}, \n\t kwargs={kwargs}, \n\t dpnp_inplace={dpnp_inplace}")
92
118
93
119
kwargs_out = kwargs.get(" out" , None )
94
120
if (kwargs_out is not None ):
95
- kwargs[" out" ] = dpnp.asnumpy(kwargs_out) if isinstance (kwargs_out, dparray) else kwargs_out
121
+ if isinstance (kwargs_out, numpy.ndarray):
122
+ kwargs[" out" ] = kwargs_out
123
+ else :
124
+ kwargs[" out" ] = dpnp.asnumpy(kwargs_out)
96
125
97
- args_new = []
98
- for arg in args:
99
- argx = dpnp.asnumpy(arg) if isinstance (arg, dparray) else arg
100
- args_new.append(argx)
126
+ if dpnp_inplace:
127
+ # TODO replacement of foreign containers is still needed
128
+ args_new = args
129
+ else :
130
+ args_new_list = []
131
+ for arg in args:
132
+ argx = convert_item(arg)
133
+ args_new_list.append(argx)
134
+ args_new = tuple (args_new_list)
101
135
102
136
kwargs_new = {}
103
137
for key, kwarg in kwargs.items():
104
- kwargx = dpnp.asnumpy (kwarg) if isinstance (kwarg, dparray) else kwarg
138
+ kwargx = convert_item (kwarg)
105
139
kwargs_new[key] = kwargx
106
140
107
- # TODO need to put dparray memory into NumPy call
141
+ # print(f"DPNP call_origin(): bakend called. \n\t function={function}, \n\t args_new={args_new}, \n\t kwargs_new={kwargs_new}, \n\t dpnp_inplace={dpnp_inplace}")
142
+ # TODO need to put array memory into NumPy call
108
143
result_origin = function(* args_new, ** kwargs_new)
144
+ # print(f"DPNP call_origin(): result from backend. \n\t result_origin={result_origin}, \n\t args_new={args_new}, \n\t kwargs_new={kwargs_new}, \n\t dpnp_inplace={dpnp_inplace}")
109
145
result = result_origin
110
146
if isinstance (result, numpy.ndarray):
111
147
if (kwargs_out is None ):
@@ -119,7 +155,8 @@ def call_origin(function, *args, **kwargs):
119
155
result = kwargs_out
120
156
121
157
for i in range (result.size):
122
- result._setitem_scalar(i, result_origin.item(i))
158
+ result.flat[i] = result_origin.item(i)
159
+
123
160
elif isinstance (result, tuple ):
124
161
# convert tuple(ndarray) to tuple(dparray)
125
162
result_list = []
@@ -128,7 +165,7 @@ def call_origin(function, *args, **kwargs):
128
165
if isinstance (res_origin, numpy.ndarray):
129
166
res = dparray(res_origin.shape, dtype = res_origin.dtype)
130
167
for i in range (res.size):
131
- res._setitem_scalar(i, res_origin.item(i) )
168
+ res.flat[i] = res_origin.item(i)
132
169
result_list.append(res)
133
170
134
171
result = tuple (result_list)
0 commit comments