Skip to content

Commit 1eaceaf

Browse files
authored
Move dpnp_matmul() to right place (#745)
1 parent 3a9aa7e commit 1eaceaf

File tree

2 files changed

+85
-85
lines changed

2 files changed

+85
-85
lines changed

dpnp/dpnp_algo/dpnp_algo.pyx

Lines changed: 0 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ __all__ = [
4848
"dpnp_astype",
4949
"dpnp_flatten",
5050
"dpnp_init_val",
51-
"dpnp_matmul",
5251
"dpnp_queue_initialize",
5352
"dpnp_queue_is_cpu"
5453
]
@@ -174,90 +173,6 @@ cpdef dparray dpnp_init_val(shape, dtype, value):
174173
return result
175174

176175

177-
cpdef dparray dpnp_matmul(dparray in_array1, dparray in_array2, dparray out=None):
178-
179-
cdef vector[Py_ssize_t] shape_result
180-
181-
cdef vector[Py_ssize_t] shape1 = in_array1.shape
182-
cdef vector[Py_ssize_t] shape2 = in_array2.shape
183-
184-
cdef size_t size_m = 0
185-
cdef size_t size_n = 0
186-
cdef size_t size_k = 0
187-
188-
# Calling this function on an empty container causes undefined behavior.
189-
if not shape1.empty():
190-
size_m = shape1.front()
191-
if not shape2.empty():
192-
size_n = shape2.back()
193-
if not shape1.empty():
194-
size_k = shape1.back()
195-
196-
cdef size_t ndim_max = max(in_array1.ndim, in_array2.ndim)
197-
198-
if in_array1.ndim < ndim_max or ndim_max == 1:
199-
"""
200-
shape1(2,), shape2(2,4)
201-
test: pytest tests/test_matmul.py::test_matmul[shape_pair4-types0] -v -s
202-
or
203-
shape1(2,), shape2(2,)
204-
test: pytest tests/test_matmul.py::test_matmul[shape_pair8-types0] -v -s
205-
"""
206-
size_m = 1
207-
208-
if in_array2.ndim < ndim_max or ndim_max == 1:
209-
"""
210-
shape1(5,2), shape2(2,)
211-
test: pytest tests/test_matmul.py::test_matmul[shape_pair6-types0] -v -s
212-
or
213-
shape1(3,), shape2(3,)
214-
test: pytest tests/test_matmul.py::test_matmul[shape_pair8-types0] -v -s
215-
"""
216-
size_n = 1
217-
218-
if ndim_max > 2:
219-
"""
220-
shape1(5, 3, 2) * shape2(5, 2, 4) -> result(5, 3, 4)
221-
test: pytest tests/test_matmul.py::test_matmul[shape_pair10-types0] -v -s
222-
"""
223-
shape_result = shape1[:-1] + [shape2.back()]
224-
else:
225-
"""
226-
shape1(5,2) * shape2(2,3) -> result(5,3)
227-
test: pytest tests/test_matmul.py::test_matmul[shape_pair0-types0] -v -s
228-
"""
229-
shape_result = shape1[:-1] + shape2[1:]
230-
231-
# convert string type names (dparray.dtype) to C enum DPNPFuncType
232-
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(in_array1.dtype)
233-
cdef DPNPFuncType param2_type = dpnp_dtype_to_DPNPFuncType(in_array2.dtype)
234-
235-
# get the FPTR data structure
236-
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_MATMUL, param1_type, param2_type)
237-
238-
result_type = dpnp_DPNPFuncType_to_dtype( < size_t > kernel_data.return_type)
239-
240-
cdef dparray result
241-
242-
if out is not None:
243-
if out.dtype != result_type:
244-
utils.checker_throw_value_error('matmul', 'out.dtype', out.dtype, result_type)
245-
if out.shape != shape_result:
246-
utils.checker_throw_value_error('matmul', 'out.shape', out.shape, shape_result)
247-
result = out
248-
else:
249-
result = dparray(shape_result, dtype=result_type)
250-
251-
if result.size == 0:
252-
return result
253-
254-
cdef fptr_blas_gemm_2in_1out_t func = <fptr_blas_gemm_2in_1out_t > kernel_data.ptr
255-
# call FPTR function
256-
func(in_array1.get_data(), in_array2.get_data(), result.get_data(), size_m, size_n, size_k)
257-
258-
return result
259-
260-
261176
cpdef dpnp_queue_initialize():
262177
"""
263178
Initialize SYCL queue which will be used for any library operations.

dpnp/dpnp_algo/dpnp_algo_linearalgebra.pyx

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ __all__ += [
3838
"dpnp_dot",
3939
"dpnp_inner",
4040
"dpnp_kron",
41+
"dpnp_matmul",
4142
"dpnp_outer"
4243
]
4344

@@ -196,6 +197,90 @@ cpdef dparray dpnp_kron(dparray in_array1, dparray in_array2):
196197
return result
197198

198199

200+
cpdef dparray dpnp_matmul(dparray in_array1, dparray in_array2, dparray out=None):
201+
202+
cdef vector[Py_ssize_t] shape_result
203+
204+
cdef vector[Py_ssize_t] shape1 = in_array1.shape
205+
cdef vector[Py_ssize_t] shape2 = in_array2.shape
206+
207+
cdef size_t size_m = 0
208+
cdef size_t size_n = 0
209+
cdef size_t size_k = 0
210+
211+
# Calling this function on an empty container causes undefined behavior.
212+
if not shape1.empty():
213+
size_m = shape1.front()
214+
if not shape2.empty():
215+
size_n = shape2.back()
216+
if not shape1.empty():
217+
size_k = shape1.back()
218+
219+
cdef size_t ndim_max = max(in_array1.ndim, in_array2.ndim)
220+
221+
if in_array1.ndim < ndim_max or ndim_max == 1:
222+
"""
223+
shape1(2,), shape2(2,4)
224+
test: pytest tests/test_matmul.py::test_matmul[shape_pair4-types0] -v -s
225+
or
226+
shape1(2,), shape2(2,)
227+
test: pytest tests/test_matmul.py::test_matmul[shape_pair8-types0] -v -s
228+
"""
229+
size_m = 1
230+
231+
if in_array2.ndim < ndim_max or ndim_max == 1:
232+
"""
233+
shape1(5,2), shape2(2,)
234+
test: pytest tests/test_matmul.py::test_matmul[shape_pair6-types0] -v -s
235+
or
236+
shape1(3,), shape2(3,)
237+
test: pytest tests/test_matmul.py::test_matmul[shape_pair8-types0] -v -s
238+
"""
239+
size_n = 1
240+
241+
if ndim_max > 2:
242+
"""
243+
shape1(5, 3, 2) * shape2(5, 2, 4) -> result(5, 3, 4)
244+
test: pytest tests/test_matmul.py::test_matmul[shape_pair10-types0] -v -s
245+
"""
246+
shape_result = shape1[:-1] + [shape2.back()]
247+
else:
248+
"""
249+
shape1(5,2) * shape2(2,3) -> result(5,3)
250+
test: pytest tests/test_matmul.py::test_matmul[shape_pair0-types0] -v -s
251+
"""
252+
shape_result = shape1[:-1] + shape2[1:]
253+
254+
# convert string type names (dparray.dtype) to C enum DPNPFuncType
255+
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(in_array1.dtype)
256+
cdef DPNPFuncType param2_type = dpnp_dtype_to_DPNPFuncType(in_array2.dtype)
257+
258+
# get the FPTR data structure
259+
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_MATMUL, param1_type, param2_type)
260+
261+
result_type = dpnp_DPNPFuncType_to_dtype( < size_t > kernel_data.return_type)
262+
263+
cdef dparray result
264+
265+
if out is not None:
266+
if out.dtype != result_type:
267+
utils.checker_throw_value_error('matmul', 'out.dtype', out.dtype, result_type)
268+
if out.shape != shape_result:
269+
utils.checker_throw_value_error('matmul', 'out.shape', out.shape, shape_result)
270+
result = out
271+
else:
272+
result = dparray(shape_result, dtype=result_type)
273+
274+
if result.size == 0:
275+
return result
276+
277+
cdef fptr_blas_gemm_2in_1out_t func = <fptr_blas_gemm_2in_1out_t > kernel_data.ptr
278+
# call FPTR function
279+
func(in_array1.get_data(), in_array2.get_data(), result.get_data(), size_m, size_n, size_k)
280+
281+
return result
282+
283+
199284
cpdef dparray dpnp_outer(dparray array1, dparray array2):
200285
cdef dparray_shape_type result_shape = (array1.size, array2.size)
201286
result_type = numpy.promote_types(array1.dtype, array1.dtype)

0 commit comments

Comments
 (0)