@@ -346,50 +346,80 @@ void dpnp_remainder_c(void* result_out,
346
346
const size_t input2_shape_ndim,
347
347
const size_t * where)
348
348
{
349
- (void )input1_shape;
350
- (void )input1_shape_ndim;
351
- (void )input2_size;
352
- (void )input2_shape;
353
- (void )input2_shape_ndim;
354
349
(void )where;
355
350
356
- cl::sycl::event event;
357
- _DataType_input1* input1 = reinterpret_cast <_DataType_input1*>(const_cast <void *>(input1_in));
358
- _DataType_input2* input2 = reinterpret_cast <_DataType_input2*>(const_cast <void *>(input2_in));
351
+ if (!input1_size || !input2_size)
352
+ {
353
+ return ;
354
+ }
355
+
356
+ _DataType_input1* input1_data = reinterpret_cast <_DataType_input1*>(const_cast <void *>(input1_in));
357
+ _DataType_input2* input2_data = reinterpret_cast <_DataType_input2*>(const_cast <void *>(input2_in));
359
358
_DataType_output* result = reinterpret_cast <_DataType_output*>(result_out);
360
359
361
- if constexpr ((std::is_same<_DataType_input1, double >::value || std::is_same<_DataType_input1, float >::value) &&
362
- std::is_same<_DataType_input2, _DataType_input1>::value)
360
+ std::vector<size_t > result_shape = get_result_shape (input1_shape, input1_shape_ndim,
361
+ input2_shape, input2_shape_ndim);
362
+
363
+ DPNPC_id<_DataType_input1>* input1_it;
364
+ const size_t input1_it_size_in_bytes = sizeof (DPNPC_id<_DataType_input1>);
365
+ input1_it = reinterpret_cast <DPNPC_id<_DataType_input1>*>(dpnp_memory_alloc_c (input1_it_size_in_bytes));
366
+ new (input1_it) DPNPC_id<_DataType_input1>(input1_data, input1_shape, input1_shape_ndim);
367
+
368
+ input1_it->broadcast_to_shape (result_shape);
369
+
370
+ DPNPC_id<_DataType_input2>* input2_it;
371
+ const size_t input2_it_size_in_bytes = sizeof (DPNPC_id<_DataType_input2>);
372
+ input2_it = reinterpret_cast <DPNPC_id<_DataType_input2>*>(dpnp_memory_alloc_c (input2_it_size_in_bytes));
373
+ new (input2_it) DPNPC_id<_DataType_input2>(input2_data, input2_shape, input2_shape_ndim);
374
+
375
+ input2_it->broadcast_to_shape (result_shape);
376
+
377
+ const size_t result_size = input1_it->get_output_size ();
378
+
379
+ cl::sycl::range<1 > gws (result_size);
380
+ auto kernel_parallel_for_func = [=](cl::sycl::id<1 > global_id) {
381
+ const size_t i = global_id[0 ];
382
+ const _DataType_output input1_elem = (*input1_it)[i];
383
+ const _DataType_output input2_elem = (*input2_it)[i];
384
+ double fmod_res = cl::sycl::fmod ((double )input1_elem, (double )input2_elem);
385
+ double add = fmod_res + input2_elem;
386
+ result[i] = cl::sycl::fmod (add, (double )input2_elem);
387
+
388
+ };
389
+ auto kernel_func = [&](cl::sycl::handler& cgh) {
390
+ cgh.parallel_for <class dpnp_remainder_c_kernel <_DataType_output, _DataType_input1, _DataType_input2>>(
391
+ gws, kernel_parallel_for_func);
392
+ };
393
+
394
+ cl::sycl::event event;
395
+
396
+ if (input1_size == input2_size)
363
397
{
364
- event = oneapi::mkl::vm::fmod (DPNP_QUEUE, input1_size, input1, input2, result);
365
- event.wait ();
366
- event = oneapi::mkl::vm::add (DPNP_QUEUE, input1_size, result, input2, result);
367
- event.wait ();
368
- event = oneapi::mkl::vm::fmod (DPNP_QUEUE, input1_size, result, input2, result);
398
+ if constexpr ((std::is_same<_DataType_input1, double >::value ||
399
+ std::is_same<_DataType_input1, float >::value) &&
400
+ std::is_same<_DataType_input2, _DataType_input1>::value)
401
+ {
402
+ event = oneapi::mkl::vm::fmod (DPNP_QUEUE, input1_size, input1_data, input2_data, result);
403
+ event.wait ();
404
+ event = oneapi::mkl::vm::add (DPNP_QUEUE, input1_size, result, input2_data, result);
405
+ event.wait ();
406
+ event = oneapi::mkl::vm::fmod (DPNP_QUEUE, input1_size, result, input2_data, result);
407
+ }
408
+ else
409
+ {
410
+ event = DPNP_QUEUE.submit (kernel_func);
411
+ }
369
412
}
370
413
else
371
414
{
372
- cl::sycl::range<1 > gws (input1_size);
373
- auto kernel_parallel_for_func = [=](cl::sycl::id<1 > global_id) {
374
- size_t i = global_id[0 ]; /* for (size_t i = 0; i < size; ++i)*/
375
- {
376
- _DataType_input1 input_elem1 = input1[i];
377
- _DataType_input2 input_elem2 = input2[i];
378
- double fmod = cl::sycl::fmod ((double )input_elem1, (double )input_elem2);
379
- double add = fmod + input_elem2;
380
- result[i] = cl::sycl::fmod (add, (double )input_elem2);
381
- }
382
- };
383
-
384
- auto kernel_func = [&](cl::sycl::handler& cgh) {
385
- cgh.parallel_for <class dpnp_remainder_c_kernel <_DataType_input1, _DataType_input2, _DataType_output>>(
386
- gws, kernel_parallel_for_func);
387
- };
388
-
389
415
event = DPNP_QUEUE.submit (kernel_func);
390
416
}
391
417
392
418
event.wait ();
419
+
420
+ input1_it->~DPNPC_id ();
421
+ input2_it->~DPNPC_id ();
422
+
393
423
}
394
424
395
425
template <typename _KernelNameSpecialization1, typename _KernelNameSpecialization2, typename _KernelNameSpecialization3>
0 commit comments