Skip to content

Commit 7dbc689

Browse files
authored
dpnp.dot() to desc (#854)
1 parent 73445f6 commit 7dbc689

File tree

3 files changed

+16
-10
lines changed

3 files changed

+16
-10
lines changed

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ cpdef dpnp_descriptor dpnp_not_equal(dpnp_descriptor input1, dpnp_descriptor inp
297297
"""
298298
Linear algebra
299299
"""
300-
cpdef dparray dpnp_dot(dpnp_descriptor in_array1, dpnp_descriptor in_array2)
300+
cpdef dpnp_descriptor dpnp_dot(dpnp_descriptor in_array1, dpnp_descriptor in_array2)
301301
cpdef dpnp_descriptor dpnp_matmul(dpnp_descriptor in_array1, dpnp_descriptor in_array2, dpnp_descriptor out=*)
302302

303303

dpnp/dpnp_algo/dpnp_algo_linearalgebra.pyx

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ __all__ += [
4747
ctypedef void(*fptr_2in_1out_shapes_t)(void *, void * , void * , size_t * , size_t * , size_t * , size_t)
4848

4949

50-
cpdef dparray dpnp_dot(dpnp_descriptor in_array1, dpnp_descriptor in_array2):
50+
cpdef utils.dpnp_descriptor dpnp_dot(utils.dpnp_descriptor in_array1, utils.dpnp_descriptor in_array2):
5151

5252
cdef shape_type_c shape1, shape2
5353

@@ -63,9 +63,7 @@ cpdef dparray dpnp_dot(dpnp_descriptor in_array1, dpnp_descriptor in_array2):
6363

6464
# scalar
6565
if dim1 == 0 or dim2 == 0:
66-
x1_desc = dpnp.get_dpnp_descriptor(in_array1)
67-
x2_desc = dpnp.get_dpnp_descriptor(in_array2)
68-
return dpnp_multiply(x1_desc, x2_desc).get_pyobj()
66+
return dpnp_multiply(in_array1, in_array2)
6967

7068
cdef size_t size1 = 0
7169
cdef size_t size2 = 0
@@ -86,14 +84,22 @@ cpdef dparray dpnp_dot(dpnp_descriptor in_array1, dpnp_descriptor in_array2):
8684
# get the FPTR data structure
8785
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_DOT, param1_type, param2_type)
8886

89-
result_type = dpnp_DPNPFuncType_to_dtype(< size_t > kernel_data.return_type)
9087
# ceate result array with type given by FPTR data
91-
cdef dparray result = dparray((1,), dtype=result_type)
88+
cdef shape_type_c result_shape = (1,)
89+
cdef utils.dpnp_descriptor result = utils.create_output_descriptor(result_shape, kernel_data.return_type, None)
9290

9391
cdef fptr_2in_1out_t func = <fptr_2in_1out_t > kernel_data.ptr
9492
# call FPTR function
95-
func(result.get_data(), in_array1.get_data(), in_array1.size, shape1.data(), shape1.size(),
96-
in_array2.get_data(), in_array2.size, shape2.data(), shape2.size(), NULL)
93+
func(result.get_data(),
94+
in_array1.get_data(),
95+
in_array1.size,
96+
shape1.data(),
97+
shape1.size(),
98+
in_array2.get_data(),
99+
in_array2.size,
100+
shape2.data(),
101+
shape2.size(),
102+
NULL)
97103

98104
return result
99105

dpnp/dpnp_iface_linearalgebra.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def dot(x1, x2, **kwargs):
9898
dim2 = x2_desc.ndim
9999

100100
if not (dim1 >= 2 and dim2 == 1) and not (dim1 >= 2 and dim2 >= 2) and (x1_desc.dtype == x2_desc.dtype):
101-
result_obj = dpnp_dot(x1_desc, x2_desc)
101+
result_obj = dpnp_dot(x1_desc, x2_desc).get_pyobj()
102102
result = dpnp.convert_single_elem_array_to_scalar(result_obj)
103103

104104
return result

0 commit comments

Comments
 (0)