@@ -144,7 +144,7 @@ def get_include():
144
144
return dpnp_path
145
145
146
146
147
- def matmul (in_array1 , in_array2 , out = None ):
147
+ def matmul (in_array1 , in_array2 , out = None , ** kwargs ):
148
148
"""
149
149
Matrix product of two arrays.
150
150
@@ -154,7 +154,7 @@ def matmul(in_array1, in_array2, out=None):
154
154
-----------
155
155
Input arrays are supported as :obj:`dpnp.ndarray`.
156
156
Otherwise the function will be executed sequentially on CPU.
157
- Parameter ``out`` is supported only with default value ``None``.
157
+ Parameter ``out`` is supported as :obj:`dpnp.ndarray` and as default value ``None``.
158
158
Input array data types are limited by supported DPNP :ref:`Data types`.
159
159
160
160
See Also
@@ -180,33 +180,29 @@ def matmul(in_array1, in_array2, out=None):
180
180
181
181
"""
182
182
183
- is_dparray1 = isinstance (in_array1 , dparray )
184
- is_dparray2 = isinstance (in_array2 , dparray )
185
-
186
- if (not use_origin_backend (in_array1 ) and is_dparray1 and is_dparray2 ):
187
-
188
- if out is not None :
189
- checker_throw_value_error ("matmul" , "out" , type (out ), None )
190
-
191
- """
192
- Cost model checks
193
- """
194
- cost_size = 4096 # 2D array shape(64, 64)
195
- if ((in_array1 .dtype == numpy .float64 ) or (in_array1 .dtype == numpy .float32 )):
183
+ if not use_origin_backend (in_array1 ) and not kwargs :
184
+ if not isinstance (in_array1 , dparray ):
185
+ pass
186
+ elif not isinstance (in_array2 , dparray ):
187
+ pass
188
+ elif out is not None and not isinstance (out , dparray ):
189
+ pass
190
+ else :
196
191
"""
197
- Floating point types are handled via original math library better than SYCL math library
192
+ Cost model checks
198
193
"""
199
- cost_size = 262144 # 2D array shape(512, 512)
200
194
201
- dparray1_size = in_array1 .size
202
- dparray2_size = in_array2 .size
195
+ dparray1_size = in_array1 .size
196
+ dparray2_size = in_array2 .size
197
+ cost_size = 4096 # 2D array shape(64, 64)
203
198
204
- if (dparray1_size > cost_size ) and (dparray2_size > cost_size ):
205
- # print(f"dparray1_size={dparray1_size}")
206
- return dpnp_matmul (in_array1 , in_array2 )
199
+ if ((in_array1 .dtype == numpy .float64 ) or (in_array1 .dtype == numpy .float32 )):
200
+ """
201
+ Floating point types are handled via original math library better than SYCL math library
202
+ """
203
+ cost_size = 262144 # 2D array shape(512, 512)
207
204
208
- input1 = asnumpy ( in_array1 ) if is_dparray1 else in_array1
209
- input2 = asnumpy ( in_array2 ) if is_dparray2 else in_array2
205
+ if ( dparray1_size > cost_size ) and ( dparray2_size > cost_size ):
206
+ return dpnp_matmul ( in_array1 , in_array2 , out = out )
210
207
211
- # TODO need to return dparray instead ndarray
212
- return numpy .matmul (input1 , input2 , out = out )
208
+ return call_origin (numpy .matmul , in_array1 , in_array2 , out = out , ** kwargs )
0 commit comments