@@ -38,6 +38,7 @@ __all__ += [
38
38
" dpnp_dot" ,
39
39
" dpnp_inner" ,
40
40
" dpnp_kron" ,
41
+ " dpnp_matmul" ,
41
42
" dpnp_outer"
42
43
]
43
44
@@ -196,6 +197,90 @@ cpdef dparray dpnp_kron(dparray in_array1, dparray in_array2):
196
197
return result
197
198
198
199
200
+ cpdef dparray dpnp_matmul(dparray in_array1, dparray in_array2, dparray out = None ):
201
+
202
+ cdef vector[Py_ssize_t] shape_result
203
+
204
+ cdef vector[Py_ssize_t] shape1 = in_array1.shape
205
+ cdef vector[Py_ssize_t] shape2 = in_array2.shape
206
+
207
+ cdef size_t size_m = 0
208
+ cdef size_t size_n = 0
209
+ cdef size_t size_k = 0
210
+
211
+ # Calling this function on an empty container causes undefined behavior.
212
+ if not shape1.empty():
213
+ size_m = shape1.front()
214
+ if not shape2.empty():
215
+ size_n = shape2.back()
216
+ if not shape1.empty():
217
+ size_k = shape1.back()
218
+
219
+ cdef size_t ndim_max = max (in_array1.ndim, in_array2.ndim)
220
+
221
+ if in_array1.ndim < ndim_max or ndim_max == 1 :
222
+ """
223
+ shape1(2,), shape2(2,4)
224
+ test: pytest tests/test_matmul.py::test_matmul[shape_pair4-types0] -v -s
225
+ or
226
+ shape1(2,), shape2(2,)
227
+ test: pytest tests/test_matmul.py::test_matmul[shape_pair8-types0] -v -s
228
+ """
229
+ size_m = 1
230
+
231
+ if in_array2.ndim < ndim_max or ndim_max == 1 :
232
+ """
233
+ shape1(5,2), shape2(2,)
234
+ test: pytest tests/test_matmul.py::test_matmul[shape_pair6-types0] -v -s
235
+ or
236
+ shape1(3,), shape2(3,)
237
+ test: pytest tests/test_matmul.py::test_matmul[shape_pair8-types0] -v -s
238
+ """
239
+ size_n = 1
240
+
241
+ if ndim_max > 2 :
242
+ """
243
+ shape1(5, 3, 2) * shape2(5, 2, 4) -> result(5, 3, 4)
244
+ test: pytest tests/test_matmul.py::test_matmul[shape_pair10-types0] -v -s
245
+ """
246
+ shape_result = shape1[:- 1 ] + [shape2.back()]
247
+ else :
248
+ """
249
+ shape1(5,2) * shape2(2,3) -> result(5,3)
250
+ test: pytest tests/test_matmul.py::test_matmul[shape_pair0-types0] -v -s
251
+ """
252
+ shape_result = shape1[:- 1 ] + shape2[1 :]
253
+
254
+ # convert string type names (dparray.dtype) to C enum DPNPFuncType
255
+ cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(in_array1.dtype)
256
+ cdef DPNPFuncType param2_type = dpnp_dtype_to_DPNPFuncType(in_array2.dtype)
257
+
258
+ # get the FPTR data structure
259
+ cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_MATMUL, param1_type, param2_type)
260
+
261
+ result_type = dpnp_DPNPFuncType_to_dtype( < size_t > kernel_data.return_type)
262
+
263
+ cdef dparray result
264
+
265
+ if out is not None :
266
+ if out.dtype != result_type:
267
+ utils.checker_throw_value_error(' matmul' , ' out.dtype' , out.dtype, result_type)
268
+ if out.shape != shape_result:
269
+ utils.checker_throw_value_error(' matmul' , ' out.shape' , out.shape, shape_result)
270
+ result = out
271
+ else :
272
+ result = dparray(shape_result, dtype = result_type)
273
+
274
+ if result.size == 0 :
275
+ return result
276
+
277
+ cdef fptr_blas_gemm_2in_1out_t func = < fptr_blas_gemm_2in_1out_t > kernel_data.ptr
278
+ # call FPTR function
279
+ func(in_array1.get_data(), in_array2.get_data(), result.get_data(), size_m, size_n, size_k)
280
+
281
+ return result
282
+
283
+
199
284
cpdef dparray dpnp_outer(dparray array1, dparray array2):
200
285
cdef dparray_shape_type result_shape = (array1.size, array2.size)
201
286
result_type = numpy.promote_types(array1.dtype, array1.dtype)
0 commit comments