@@ -275,6 +275,73 @@ void dpnp_inv_c(void* array1_in, void* result1, size_t* shape, size_t ndim)
275
275
return ;
276
276
}
277
277
278
+ template <typename _DataType1, typename _DataType2, typename _ResultType>
279
+ class dpnp_kron_c_kernel ;
280
+
281
+ template <typename _DataType1, typename _DataType2, typename _ResultType>
282
+ void dpnp_kron_c (void * array1_in,
283
+ void * array2_in,
284
+ void * result1,
285
+ size_t * in1_shape,
286
+ size_t * in2_shape,
287
+ size_t * res_shape,
288
+ size_t ndim)
289
+ {
290
+ _DataType1* array1 = reinterpret_cast <_DataType1*>(array1_in);
291
+ _DataType2* array2 = reinterpret_cast <_DataType2*>(array2_in);
292
+ _ResultType* result = reinterpret_cast <_ResultType*>(result1);
293
+
294
+ size_t size = 1 ;
295
+ for (size_t i = 0 ; i < ndim; ++i)
296
+ {
297
+ size *= res_shape[i];
298
+ }
299
+
300
+ size_t * _in1_shape = reinterpret_cast <size_t *>(dpnp_memory_alloc_c (ndim * sizeof (size_t )));
301
+ size_t * _in2_shape = reinterpret_cast <size_t *>(dpnp_memory_alloc_c (ndim * sizeof (size_t )));
302
+
303
+ dpnp_memory_memcpy_c (_in1_shape, in1_shape, ndim * sizeof (size_t ));
304
+ dpnp_memory_memcpy_c (_in2_shape, in2_shape, ndim * sizeof (size_t ));
305
+
306
+ size_t * in1_offsets = reinterpret_cast <size_t *>(dpnp_memory_alloc_c (ndim * sizeof (size_t )));
307
+ size_t * in2_offsets = reinterpret_cast <size_t *>(dpnp_memory_alloc_c (ndim * sizeof (size_t )));
308
+ size_t * res_offsets = reinterpret_cast <size_t *>(dpnp_memory_alloc_c (ndim * sizeof (size_t )));
309
+
310
+ get_shape_offsets_inkernel<size_t >(in1_shape, ndim, in1_offsets);
311
+ get_shape_offsets_inkernel<size_t >(in2_shape, ndim, in2_offsets);
312
+ get_shape_offsets_inkernel<size_t >(res_shape, ndim, res_offsets);
313
+
314
+ cl::sycl::range<1 > gws (size);
315
+ auto kernel_parallel_for_func = [=](cl::sycl::id<1 > global_id) {
316
+ const size_t idx = global_id[0 ];
317
+
318
+ size_t idx1 = 0 ;
319
+ size_t idx2 = 0 ;
320
+ size_t reminder = idx;
321
+ for (size_t axis = 0 ; axis < ndim; ++axis)
322
+ {
323
+ const size_t res_axis = reminder / res_offsets[axis];
324
+ reminder = reminder - res_axis * res_offsets[axis];
325
+
326
+ const size_t in1_axis = res_axis / _in2_shape[axis];
327
+ const size_t in2_axis = res_axis - in1_axis * _in2_shape[axis];
328
+
329
+ idx1 += in1_axis * in1_offsets[axis];
330
+ idx2 += in2_axis * in2_offsets[axis];
331
+ }
332
+
333
+ result[idx] = array1[idx1] * array2[idx2];
334
+ };
335
+
336
+ auto kernel_func = [&](cl::sycl::handler& cgh) {
337
+ cgh.parallel_for <class dpnp_kron_c_kernel <_DataType1, _DataType2, _ResultType>>(gws, kernel_parallel_for_func);
338
+ };
339
+
340
+ cl::sycl::event event = DPNP_QUEUE.submit (kernel_func);
341
+
342
+ event.wait ();
343
+ }
344
+
278
345
template <typename _DataType>
279
346
class dpnp_matrix_rank_c_kernel ;
280
347
@@ -379,6 +446,41 @@ void func_map_init_linalg_func(func_map_t& fmap)
379
446
fmap[DPNPFuncName::DPNP_FN_INV][eft_FLT][eft_FLT] = {eft_FLT, (void *)dpnp_inv_c<float >};
380
447
fmap[DPNPFuncName::DPNP_FN_INV][eft_DBL][eft_DBL] = {eft_DBL, (void *)dpnp_inv_c<double >};
381
448
449
+ fmap[DPNPFuncName::DPNP_FN_KRON][eft_INT][eft_INT] = {eft_INT, (void *)dpnp_kron_c<int , int , int >};
450
+ fmap[DPNPFuncName::DPNP_FN_KRON][eft_INT][eft_LNG] = {eft_LNG, (void *)dpnp_kron_c<int , long , long >};
451
+ fmap[DPNPFuncName::DPNP_FN_KRON][eft_INT][eft_FLT] = {eft_FLT, (void *)dpnp_kron_c<int , float , float >};
452
+ fmap[DPNPFuncName::DPNP_FN_KRON][eft_INT][eft_DBL] = {eft_DBL, (void *)dpnp_kron_c<int , double , double >};
453
+ // fmap[DPNPFuncName::DPNP_FN_KRON][eft_INT][eft_C128] = {
454
+ // eft_C128, (void*)dpnp_kron_c<int, std::complex<double>, std::complex<double>>};
455
+ fmap[DPNPFuncName::DPNP_FN_KRON][eft_LNG][eft_INT] = {eft_LNG, (void *)dpnp_kron_c<long , int , long >};
456
+ fmap[DPNPFuncName::DPNP_FN_KRON][eft_LNG][eft_LNG] = {eft_LNG, (void *)dpnp_kron_c<long , long , long >};
457
+ fmap[DPNPFuncName::DPNP_FN_KRON][eft_LNG][eft_FLT] = {eft_FLT, (void *)dpnp_kron_c<long , float , float >};
458
+ fmap[DPNPFuncName::DPNP_FN_KRON][eft_LNG][eft_DBL] = {eft_DBL, (void *)dpnp_kron_c<long , double , double >};
459
+ // fmap[DPNPFuncName::DPNP_FN_KRON][eft_LNG][eft_C128] = {
460
+ // eft_C128, (void*)dpnp_kron_c<long, std::complex<double>, std::complex<double>>};
461
+ fmap[DPNPFuncName::DPNP_FN_KRON][eft_FLT][eft_INT] = {eft_FLT, (void *)dpnp_kron_c<float , int , float >};
462
+ fmap[DPNPFuncName::DPNP_FN_KRON][eft_FLT][eft_LNG] = {eft_FLT, (void *)dpnp_kron_c<float , long , float >};
463
+ fmap[DPNPFuncName::DPNP_FN_KRON][eft_FLT][eft_FLT] = {eft_FLT, (void *)dpnp_kron_c<float , float , float >};
464
+ fmap[DPNPFuncName::DPNP_FN_KRON][eft_FLT][eft_DBL] = {eft_DBL, (void *)dpnp_kron_c<float , double , double >};
465
+ // fmap[DPNPFuncName::DPNP_FN_KRON][eft_FLT][eft_C128] = {
466
+ // eft_C128, (void*)dpnp_kron_c<float, std::complex<double>, std::complex<double>>};
467
+ fmap[DPNPFuncName::DPNP_FN_KRON][eft_DBL][eft_INT] = {eft_DBL, (void *)dpnp_kron_c<double , int , double >};
468
+ fmap[DPNPFuncName::DPNP_FN_KRON][eft_DBL][eft_LNG] = {eft_DBL, (void *)dpnp_kron_c<double , long , double >};
469
+ fmap[DPNPFuncName::DPNP_FN_KRON][eft_DBL][eft_FLT] = {eft_DBL, (void *)dpnp_kron_c<double , float , double >};
470
+ fmap[DPNPFuncName::DPNP_FN_KRON][eft_DBL][eft_DBL] = {eft_DBL, (void *)dpnp_kron_c<double , double , double >};
471
+ fmap[DPNPFuncName::DPNP_FN_KRON][eft_DBL][eft_C128] = {
472
+ eft_C128, (void *)dpnp_kron_c<double , std::complex<double >, std::complex<double >>};
473
+ // fmap[DPNPFuncName::DPNP_FN_KRON][eft_C128][eft_INT] = {
474
+ // eft_C128, (void*)dpnp_kron_c<std::complex<double>, int, std::complex<double>>};
475
+ // fmap[DPNPFuncName::DPNP_FN_KRON][eft_C128][eft_LNG] = {
476
+ // eft_C128, (void*)dpnp_kron_c<std::complex<double>, long, std::complex<double>>};
477
+ // fmap[DPNPFuncName::DPNP_FN_KRON][eft_C128][eft_FLT] = {
478
+ // eft_C128, (void*)dpnp_kron_c<std::complex<double>, float, std::complex<double>>};
479
+ fmap[DPNPFuncName::DPNP_FN_KRON][eft_C128][eft_DBL] = {
480
+ eft_C128, (void *)dpnp_kron_c<std::complex<double >, double , std::complex<double >>};
481
+ fmap[DPNPFuncName::DPNP_FN_KRON][eft_C128][eft_C128] = {
482
+ eft_C128, (void *)dpnp_kron_c<std::complex<double >, std::complex<double >, std::complex<double >>};
483
+
382
484
fmap[DPNPFuncName::DPNP_FN_MATRIX_RANK][eft_INT][eft_INT] = {eft_INT, (void *)dpnp_matrix_rank_c<int >};
383
485
fmap[DPNPFuncName::DPNP_FN_MATRIX_RANK][eft_LNG][eft_LNG] = {eft_LNG, (void *)dpnp_matrix_rank_c<long >};
384
486
fmap[DPNPFuncName::DPNP_FN_MATRIX_RANK][eft_FLT][eft_FLT] = {eft_FLT, (void *)dpnp_matrix_rank_c<float >};
0 commit comments