Skip to content

Commit 67a101c

Browse files
update backend interfaces for dot and matmul functions (#926)
1 parent f959765 commit 67a101c

File tree

4 files changed

+137
-37
lines changed

4 files changed

+137
-37
lines changed

dpnp/backend/include/dpnp_iface.hpp

Lines changed: 49 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -193,16 +193,40 @@ INP_DLLEXPORT void dpnp_full_like_c(void* array_in, void* result, size_t size);
193193
* @ingroup BACKEND_API
194194
* @brief Matrix multiplication.
195195
*
196-
* Matrix multiplication procedure. Works with 2-D matrices
196+
* Matrix multiplication procedure.
197197
*
198-
* @param [in] array1 Input array.
199-
* @param [in] array2 Input array.
200-
* @param [out] result1 Output array.
201-
* @param [in] size Number of elements in input arrays.
198+
* @param [out] result_out Output array.
199+
* @param [in] result_size Size of output array.
200+
* @param [in] result_ndim Number of output array dimensions.
201+
* @param [in] result_shape Shape of output array.
202+
* @param [in] result_strides Strides of output array.
203+
* @param [in] input1_in First input array.
204+
* @param [in] input1_size Size of first input array.
205+
* @param [in] input1_ndim Number of first input array dimensions.
206+
* @param [in] input1_shape Shape of first input array.
207+
* @param [in] input1_strides Strides of first input array.
208+
* @param [in] input2_in Second input array.
209+
* @param [in] input2_size Size of second input array.
210+
* @param [in] input2_ndim Number of second input array dimensions.
211+
* @param [in] input2_shape Shape of second input array.
212+
* @param [in] input2_strides Strides of second input array.
202213
*/
203214
template <typename _DataType>
204-
INP_DLLEXPORT void
205-
dpnp_matmul_c(void* array1, void* array2, void* result1, size_t size_m, size_t size_n, size_t size_k);
215+
INP_DLLEXPORT void dpnp_matmul_c(void* result_out,
216+
const size_t result_size,
217+
const size_t result_ndim,
218+
const size_t* result_shape,
219+
const size_t* result_strides,
220+
const void* input1_in,
221+
const size_t input1_size,
222+
const size_t input1_ndim,
223+
const size_t* input1_shape,
224+
const size_t* input1_strides,
225+
const void* input2_in,
226+
const size_t input2_size,
227+
const size_t input2_ndim,
228+
const size_t* input2_shape,
229+
const size_t* input2_strides);
206230

207231
/**
208232
* @ingroup BACKEND_API
@@ -248,29 +272,37 @@ INP_DLLEXPORT void dpnp_elemwise_absolute_c(const void* input1_in, void* result1
248272
* @brief Custom implementation of dot function
249273
*
250274
* @param [out] result_out Output array.
275+
* @param [in] result_size Size of output array.
276+
* @param [in] result_ndim Number of output array dimensions.
277+
* @param [in] result_shape Shape of output array.
278+
* @param [in] result_strides Strides of output array.
251279
* @param [in] input1_in First input array.
252280
* @param [in] input1_size Size of first input array.
281+
* @param [in] input1_ndim Number of first input array dimensions.
253282
* @param [in] input1_shape Shape of first input array.
254-
* @param [in] input1_shape_ndim Number of first array dimensions.
283+
* @param [in] input1_strides Strides of first input array.
255284
* @param [in] input2_in Second input array.
256-
* @param [in] input2_size Shape of second input array.
257-
* @param [in] input2_shape Shape of first input array.
258-
* @param [in] input2_shape_ndim Number of second array dimensions.
259-
* @param [in] where Mask array.
260-
* @param [out] result1 Output array.
261-
* @param [in] size Number of elements in input arrays.
285+
* @param [in] input2_size Size of second input array.
286+
* @param [in] input2_ndim Number of second input array dimensions.
287+
* @param [in] input2_shape Shape of second input array.
288+
* @param [in] input2_strides Strides of second input array.
262289
*/
263290
template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
264291
INP_DLLEXPORT void dpnp_dot_c(void* result_out,
292+
const size_t result_size,
293+
const size_t result_ndim,
294+
const size_t* result_shape,
295+
const size_t* result_strides,
265296
const void* input1_in,
266297
const size_t input1_size,
298+
const size_t input1_ndim,
267299
const size_t* input1_shape,
268-
const size_t input1_shape_ndim,
300+
const size_t* input1_strides,
269301
const void* input2_in,
270302
const size_t input2_size,
303+
const size_t input2_ndim,
271304
const size_t* input2_shape,
272-
const size_t input2_shape_ndim,
273-
const size_t* where);
305+
const size_t* input2_strides);
274306

275307
/**
276308
* @ingroup BACKEND_API

dpnp/backend/kernels/dpnp_krnl_common.cpp

Lines changed: 50 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -77,21 +77,32 @@ class dpnp_dot_c_kernel;
7777

7878
template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
7979
void dpnp_dot_c(void* result_out,
80+
const size_t result_size,
81+
const size_t result_ndim,
82+
const size_t* result_shape,
83+
const size_t* result_strides,
8084
const void* input1_in,
8185
const size_t input1_size,
86+
const size_t input1_ndim,
8287
const size_t* input1_shape,
83-
const size_t input1_shape_ndim,
88+
const size_t* input1_strides,
8489
const void* input2_in,
8590
const size_t input2_size,
91+
const size_t input2_ndim,
8692
const size_t* input2_shape,
87-
const size_t input2_shape_ndim,
88-
const size_t* where)
93+
const size_t* input2_strides)
8994
{
9095
(void)input1_shape;
91-
(void)input1_shape_ndim;
96+
(void)input1_ndim;
9297
(void)input2_shape;
93-
(void)input2_shape_ndim;
94-
(void)where;
98+
(void)input2_ndim;
99+
100+
(void)result_size;
101+
(void)result_ndim;
102+
(void)result_shape;
103+
(void)result_strides;
104+
(void)input1_strides;
105+
(void)input2_strides;
95106

96107
cl::sycl::event event;
97108
DPNPC_ptr_adapter<_DataType_input1> input1_ptr(input1_in, input1_size);
@@ -308,17 +319,46 @@ template <typename _KernelNameSpecialization>
308319
class dpnp_matmul_c_kernel;
309320

310321
template <typename _DataType>
311-
void dpnp_matmul_c(void* array1_in, void* array2_in, void* result1, size_t size_m, size_t size_n, size_t size_k)
322+
void dpnp_matmul_c(void* result_out,
323+
const size_t result_size,
324+
const size_t result_ndim,
325+
const size_t* result_shape,
326+
const size_t* result_strides,
327+
const void* input1_in,
328+
const size_t input1_size,
329+
const size_t input1_ndim,
330+
const size_t* input1_shape,
331+
const size_t* input1_strides,
332+
const void* input2_in,
333+
const size_t input2_size,
334+
const size_t input2_ndim,
335+
const size_t* input2_shape,
336+
const size_t* input2_strides)
312337
{
338+
(void)result_size;
339+
(void)result_ndim;
340+
(void)result_shape;
341+
(void)result_strides;
342+
(void)input1_size;
343+
(void)input1_ndim;
344+
(void)input1_strides;
345+
(void)input2_size;
346+
(void)input2_ndim;
347+
(void)input2_strides;
348+
349+
size_t size_m = input1_shape[0];
350+
size_t size_n = input2_shape[1];
351+
size_t size_k = input1_shape[1];
352+
313353
if (!size_m || !size_n || !size_k)
314354
{
315355
return;
316356
}
317357

318358
cl::sycl::event event;
319-
DPNPC_ptr_adapter<_DataType> input1_ptr(array1_in, size_m * size_k, true);
320-
DPNPC_ptr_adapter<_DataType> input2_ptr(array2_in, size_k * size_n, true);
321-
DPNPC_ptr_adapter<_DataType> result_ptr(result1, size_m * size_n, true, true);
359+
DPNPC_ptr_adapter<_DataType> input1_ptr(input1_in, size_m * size_k, true);
360+
DPNPC_ptr_adapter<_DataType> input2_ptr(input2_in, size_k * size_n, true);
361+
DPNPC_ptr_adapter<_DataType> result_ptr(result_out, size_m * size_n, true, true);
322362
_DataType* array_1 = input1_ptr.get_ptr();
323363
_DataType* array_2 = input2_ptr.get_ptr();
324364
_DataType* result = result_ptr.get_ptr();

dpnp/backend/kernels/dpnp_krnl_statistics.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,23 @@ void dpnp_correlate_c(void* result_out,
4949
const size_t input2_shape_ndim,
5050
const size_t* where)
5151
{
52+
(void)where;
53+
5254
dpnp_dot_c<_DataType_output, _DataType_input1, _DataType_input2>(result_out,
55+
42, // dummy result_size
56+
42, // dummy result_ndim
57+
NULL, // dummy result_shape
58+
NULL, // dummy result_strides
5359
input1_in,
5460
input1_size,
55-
input1_shape,
5661
input1_shape_ndim,
62+
input1_shape,
63+
NULL, // dummy input1_strides
5764
input2_in,
5865
input2_size,
59-
input2_shape,
6066
input2_shape_ndim,
61-
where);
67+
input2_shape,
68+
NULL); // dummy input2_strides
6269

6370
return;
6471
}

dpnp/dpnp_algo/dpnp_algo_linearalgebra.pyx

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,9 @@ __all__ += [
4545

4646
# C function pointer to the C library template functions
4747
ctypedef void(*fptr_2in_1out_shapes_t)(void *, void * , void * , size_t * , size_t * , size_t * , size_t)
48-
48+
ctypedef void(*fptr_2in_1out_dot_t)(void * , const size_t, const size_t, const long * , const long * ,
49+
void * , const size_t, const size_t, const long * , const long * ,
50+
void * , const size_t, const size_t, const long * , const long * )
4951

5052
cpdef utils.dpnp_descriptor dpnp_dot(utils.dpnp_descriptor in_array1, utils.dpnp_descriptor in_array2):
5153

@@ -88,18 +90,23 @@ cpdef utils.dpnp_descriptor dpnp_dot(utils.dpnp_descriptor in_array1, utils.dpnp
8890
cdef shape_type_c result_shape = (1,)
8991
cdef utils.dpnp_descriptor result = utils.create_output_descriptor(result_shape, kernel_data.return_type, None)
9092

91-
cdef fptr_2in_1out_t func = <fptr_2in_1out_t > kernel_data.ptr
93+
cdef fptr_2in_1out_dot_t func = <fptr_2in_1out_dot_t > kernel_data.ptr
9294
# call FPTR function
9395
func(result.get_data(),
96+
result.size,
97+
result.ndim,
98+
result_shape.data(),
99+
NULL, # result_strides
94100
in_array1.get_data(),
95101
in_array1.size,
102+
in_array1.ndim,
96103
shape1.data(),
97-
shape1.size(),
104+
NULL, # in_array1_strides
98105
in_array2.get_data(),
99106
in_array2.size,
107+
in_array2.ndim,
100108
shape2.data(),
101-
shape2.size(),
102-
NULL)
109+
NULL) # in_array2_strides
103110

104111
return result
105112

@@ -258,9 +265,23 @@ cpdef utils.dpnp_descriptor dpnp_matmul(utils.dpnp_descriptor in_array1, utils.d
258265
if result.size == 0:
259266
return result
260267

261-
cdef fptr_blas_gemm_2in_1out_t func = <fptr_blas_gemm_2in_1out_t > kernel_data.ptr
268+
cdef fptr_2in_1out_dot_t func = <fptr_2in_1out_dot_t > kernel_data.ptr
262269
# call FPTR function
263-
func(in_array1.get_data(), in_array2.get_data(), result.get_data(), size_m, size_n, size_k)
270+
func(result.get_data(),
271+
result.size,
272+
result.ndim,
273+
NULL, # result_shape
274+
NULL, # result_strides
275+
in_array1.get_data(),
276+
in_array1.size,
277+
in_array1.ndim,
278+
shape1.data(),
279+
NULL, # in_array1_strides
280+
in_array2.get_data(),
281+
in_array2.size,
282+
in_array2.ndim,
283+
shape2.data(),
284+
NULL) # in_array2_strides
264285

265286
return result
266287

0 commit comments

Comments
 (0)