Skip to content

Commit de49888

Browse files
KsanaKozlovashssf
andauthored
support parameter out for matmul func (#722)
* add out to matmul function Co-authored-by: Sergey Shalnov <[email protected]>
1 parent b082fd3 commit de49888

File tree

3 files changed

+54
-29
lines changed

3 files changed

+54
-29
lines changed

dpnp/dpnp_algo/dpnp_algo.pyx

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,8 @@ cpdef dparray dpnp_init_val(shape, dtype, value):
174174
return result
175175

176176

177-
cpdef dparray dpnp_matmul(dparray in_array1, dparray in_array2):
177+
cpdef dparray dpnp_matmul(dparray in_array1, dparray in_array2, dparray out=None):
178+
178179
cdef vector[Py_ssize_t] shape_result
179180

180181
cdef vector[Py_ssize_t] shape1 = in_array1.shape
@@ -235,8 +236,18 @@ cpdef dparray dpnp_matmul(dparray in_array1, dparray in_array2):
235236
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_MATMUL, param1_type, param2_type)
236237

237238
result_type = dpnp_DPNPFuncType_to_dtype( < size_t > kernel_data.return_type)
238-
# ceate result array with type given by FPTR data
239-
cdef dparray result = dparray(shape_result, dtype=result_type)
239+
240+
cdef dparray result
241+
242+
if out is not None:
243+
if out.dtype != result_type:
244+
checker_throw_value_error('matmul', 'out.dtype', out.dtype, result_type)
245+
if out.shape != shape_result:
246+
checker_throw_value_error('matmul', 'out.shape', out.shape, shape_result)
247+
result = out
248+
else:
249+
result = dparray(shape_result, dtype=result_type)
250+
240251
if result.size == 0:
241252
return result
242253

dpnp/dpnp_iface.py

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def get_include():
144144
return dpnp_path
145145

146146

147-
def matmul(in_array1, in_array2, out=None):
147+
def matmul(in_array1, in_array2, out=None, **kwargs):
148148
"""
149149
Matrix product of two arrays.
150150
@@ -154,7 +154,7 @@ def matmul(in_array1, in_array2, out=None):
154154
-----------
155155
Input arrays are supported as :obj:`dpnp.ndarray`.
156156
Otherwise the function will be executed sequentially on CPU.
157-
Parameter ``out`` is supported only with default value ``None``.
157+
Parameter ``out`` is supported as :obj:`dpnp.ndarray` and as default value ``None``.
158158
Input array data types are limited by supported DPNP :ref:`Data types`.
159159
160160
See Also
@@ -180,33 +180,29 @@ def matmul(in_array1, in_array2, out=None):
180180
181181
"""
182182

183-
is_dparray1 = isinstance(in_array1, dparray)
184-
is_dparray2 = isinstance(in_array2, dparray)
185-
186-
if (not use_origin_backend(in_array1) and is_dparray1 and is_dparray2):
187-
188-
if out is not None:
189-
checker_throw_value_error("matmul", "out", type(out), None)
190-
191-
"""
192-
Cost model checks
193-
"""
194-
cost_size = 4096 # 2D array shape(64, 64)
195-
if ((in_array1.dtype == numpy.float64) or (in_array1.dtype == numpy.float32)):
183+
if not use_origin_backend(in_array1) and not kwargs:
184+
if not isinstance(in_array1, dparray):
185+
pass
186+
elif not isinstance(in_array2, dparray):
187+
pass
188+
elif out is not None and not isinstance(out, dparray):
189+
pass
190+
else:
196191
"""
197-
Floating point types are handled via original math library better than SYCL math library
192+
Cost model checks
198193
"""
199-
cost_size = 262144 # 2D array shape(512, 512)
200194

201-
dparray1_size = in_array1.size
202-
dparray2_size = in_array2.size
195+
dparray1_size = in_array1.size
196+
dparray2_size = in_array2.size
197+
cost_size = 4096 # 2D array shape(64, 64)
203198

204-
if (dparray1_size > cost_size) and (dparray2_size > cost_size):
205-
# print(f"dparray1_size={dparray1_size}")
206-
return dpnp_matmul(in_array1, in_array2)
199+
if ((in_array1.dtype == numpy.float64) or (in_array1.dtype == numpy.float32)):
200+
"""
201+
Floating point types are handled via original math library better than SYCL math library
202+
"""
203+
cost_size = 262144 # 2D array shape(512, 512)
207204

208-
input1 = asnumpy(in_array1) if is_dparray1 else in_array1
209-
input2 = asnumpy(in_array2) if is_dparray2 else in_array2
205+
if (dparray1_size > cost_size) and (dparray2_size > cost_size):
206+
return dpnp_matmul(in_array1, in_array2, out=out)
210207

211-
# TODO need to return dparray instead ndarray
212-
return numpy.matmul(input1, input2, out=out)
208+
return call_origin(numpy.matmul, in_array1, in_array2, out=out, **kwargs)

tests/test_mixins.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,24 @@ def test_matmul2(self):
4646

4747
numpy.testing.assert_array_equal(expected, result)
4848

49+
def test_matmul3(self):
50+
array_data1 = numpy.full((513, 513), 5)
51+
array_data2 = numpy.full((513, 513), 2)
52+
out = numpy.empty((513, 513), dtype=numpy.float64)
53+
54+
# DPNP
55+
array1 = inp.array(array_data1, dtype=inp.float64)
56+
array2 = inp.array(array_data2, dtype=inp.float64)
57+
out1 = inp.array(out, dtype=inp.float64)
58+
result = inp.matmul(array1, array2, out=out1)
59+
60+
# original
61+
array_1 = numpy.array(array_data1, dtype=numpy.float64)
62+
array_2 = numpy.array(array_data2, dtype=numpy.float64)
63+
expected = numpy.matmul(array_1, array_2, out=out)
64+
65+
numpy.testing.assert_array_equal(expected, result)
66+
4967

5068
if __name__ == '__main__':
5169
unittest.main()

0 commit comments

Comments
 (0)