40
40
"""
41
41
42
42
43
- import numpy
44
-
45
43
from dpnp .dpnp_algo import *
46
- from dpnp .dparray import dparray
47
44
from dpnp .dpnp_utils import *
48
45
import dpnp
49
46
import dpnp .config as config
50
47
48
+ import numpy
49
+
51
50
52
51
__all__ = [
53
52
"dot" ,
@@ -92,15 +91,14 @@ def dot(x1, x2, **kwargs):
92
91
93
92
"""
94
93
95
- is_x1_dparray = isinstance (x1 , dparray )
96
- is_x2_dparray = isinstance (x2 , dparray )
97
-
98
- if (not use_origin_backend (x1 ) and is_x1_dparray and is_x2_dparray and not kwargs ):
99
- dim1 = x1 .ndim
100
- dim2 = x2 .ndim
94
+ x1_desc = dpnp .get_dpnp_descriptor (x1 )
95
+ x2_desc = dpnp .get_dpnp_descriptor (x2 )
96
+ if x1_desc and x2_desc and not kwargs :
97
+ dim1 = x1_desc .ndim
98
+ dim2 = x2_desc .ndim
101
99
102
- if not (dim1 >= 2 and dim2 == 1 ) and not (dim1 >= 2 and dim2 >= 2 ) and (x1 .dtype == x2 .dtype ):
103
- result = dpnp_dot (x1 , x2 )
100
+ if not (dim1 >= 2 and dim2 == 1 ) and not (dim1 >= 2 and dim2 >= 2 ) and (x1_desc .dtype == x2_desc .dtype ):
101
+ result = dpnp_dot (x1_desc , x2_desc )
104
102
105
103
# scalar returned
106
104
if result .shape == (1 ,):
@@ -186,16 +184,15 @@ def inner(x1, x2, **kwargs):
186
184
187
185
"""
188
186
189
- is_x1_dparray = isinstance (x1 , dparray )
190
- is_x2_dparray = isinstance (x2 , dparray )
191
-
192
- if (not use_origin_backend (x1 ) and is_x1_dparray and is_x2_dparray and not kwargs ):
193
- return dpnp_inner (x1 , x2 )
187
+ x1_desc = dpnp .get_dpnp_descriptor (x1 )
188
+ x2_desc = dpnp .get_dpnp_descriptor (x2 )
189
+ if 0 and x1_desc and x2_desc and not kwargs :
190
+ return dpnp_inner (x1_desc , x2_desc )
194
191
195
192
return call_origin (numpy .inner , x1 , x2 , ** kwargs )
196
193
197
194
198
- def kron (a , b ):
195
+ def kron (x1 , x2 ):
199
196
"""
200
197
Returns the kronecker product of two arrays.
201
198
@@ -205,23 +202,15 @@ def kron(a, b):
205
202
206
203
"""
207
204
208
- if not use_origin_backend (a ):
209
- if dpnp .isscalar (a ):
210
- a = dpnp .array (a )
211
- if dpnp .isscalar (b ):
212
- b = dpnp .array (b )
213
-
214
- if not isinstance (a , dparray ):
215
- pass
216
- elif not isinstance (b , dparray ):
217
- pass
218
- else :
219
- return dpnp_kron (a , b )
205
+ x1_desc = dpnp .get_dpnp_descriptor (x1 )
206
+ x2_desc = dpnp .get_dpnp_descriptor (x2 )
207
+ if x1_desc and x2_desc :
208
+ return dpnp_kron (x1_desc , x2_desc )
220
209
221
- return call_origin (numpy .kron , a , b )
210
+ return call_origin (numpy .kron , x1 , x2 )
222
211
223
212
224
- def matmul (in_array1 , in_array2 , out = None , ** kwargs ):
213
+ def matmul (x1 , x2 , out = None , ** kwargs ):
225
214
"""
226
215
Matrix product of two arrays.
227
216
@@ -257,32 +246,42 @@ def matmul(in_array1, in_array2, out=None, **kwargs):
257
246
258
247
"""
259
248
260
- if not use_origin_backend (in_array1 ) and not kwargs :
261
- if not isinstance (in_array1 , dparray ):
249
+ x1_desc = dpnp .get_dpnp_descriptor (x1 )
250
+ x2_desc = dpnp .get_dpnp_descriptor (x2 )
251
+ out_desc = dpnp .get_dpnp_descriptor (x2 )
252
+ if x1_desc and x2_desc and out_desc and not kwargs :
253
+ if x1_desc .size != x2_desc .size :
254
+ pass
255
+ elif not x1_desc .ndim :
256
+ pass
257
+ elif not x2_desc .ndim :
262
258
pass
263
- elif not isinstance ( in_array2 , dparray ) :
259
+ elif not x1_desc . size :
264
260
pass
265
- elif out is not None and not isinstance ( out , dparray ) :
261
+ elif not x2_desc . size :
266
262
pass
267
263
else :
268
- """
269
- Cost model checks
270
- """
271
-
272
- dparray1_size = in_array1 .size
273
- dparray2_size = in_array2 .size
274
- cost_size = 4096 # 2D array shape(64, 64)
275
-
276
- if ((in_array1 .dtype == numpy .float64 ) or (in_array1 .dtype == numpy .float32 )):
264
+ if 0 :
277
265
"""
278
- Floating point types are handled via original math library better than SYCL math library
266
+ Cost model checks
279
267
"""
280
- cost_size = 262144 # 2D array shape(512, 512)
281
268
282
- if (dparray1_size > cost_size ) and (dparray2_size > cost_size ):
283
- return dpnp_matmul (in_array1 , in_array2 , out = out )
269
+ dparray1_size = x1_desc .size
270
+ dparray2_size = x2_desc .size
271
+ cost_size = 4096 # 2D array shape(64, 64)
284
272
285
- return call_origin (numpy .matmul , in_array1 , in_array2 , out = out , ** kwargs )
273
+ if ((x1_desc .dtype == numpy .float64 ) or (x1_desc .dtype == numpy .float32 )):
274
+ """
275
+ Floating point types are handled via original math library better than SYCL math library
276
+ """
277
+ cost_size = 262144 # 2D array shape(512, 512)
278
+
279
+ if (dparray1_size > cost_size ) and (dparray2_size > cost_size ):
280
+ return dpnp_matmul (x1_desc , x2_desc , out )
281
+ else :
282
+ return dpnp_matmul (x1_desc , x2_desc , out )
283
+
284
+ return call_origin (numpy .matmul , x1 , x2 , out = out , ** kwargs )
286
285
287
286
288
287
def outer (x1 , x2 , ** kwargs ):
@@ -314,11 +313,10 @@ def outer(x1, x2, **kwargs):
314
313
315
314
"""
316
315
317
- is_x1_dparray = isinstance (x1 , dparray )
318
- is_x2_dparray = isinstance (x2 , dparray )
319
-
320
- if (not use_origin_backend (x1 ) and is_x1_dparray and is_x2_dparray and not kwargs ):
321
- return dpnp_outer (x1 , x2 )
316
+ x1_desc = dpnp .get_dpnp_descriptor (x1 )
317
+ x2_desc = dpnp .get_dpnp_descriptor (x2 )
318
+ if 0 and x1_desc and x2_desc and not kwargs :
319
+ return dpnp_outer (x1_desc , x2_desc )
322
320
323
321
return call_origin (numpy .outer , x1 , x2 , ** kwargs )
324
322
@@ -353,11 +351,10 @@ def tensordot(x1, x2, axes=2):
353
351
354
352
"""
355
353
356
- is_x1_dparray = isinstance (x1 , dparray )
357
- is_x2_dparray = isinstance (x2 , dparray )
358
-
359
- if (not use_origin_backend (x1 ) and is_x1_dparray and is_x2_dparray and (axes == 1 )):
360
- return dpnp_tensordot (x1 , x2 ) # dpnp_matmul
354
+ x1_desc = dpnp .get_dpnp_descriptor (x1 )
355
+ x2_desc = dpnp .get_dpnp_descriptor (x2 )
356
+ if x1_desc and x2_desc and (axes == 1 ):
357
+ return dpnp_tensordot_not_implemented (x1_desc , x2_desc ) # dpnp_matmul
361
358
362
359
return call_origin (numpy .tensordot , x1 , x2 , axes )
363
360
0 commit comments