|
41 | 41 | using dpctl::tensor::kernels::alignment_utils::is_aligned; |
42 | 42 | using dpctl::tensor::kernels::alignment_utils::required_alignment; |
43 | 43 |
|
| 44 | +using sycl::ext::oneapi::experimental::group_load; |
| 45 | +using sycl::ext::oneapi::experimental::group_store; |
| 46 | + |
44 | 47 | template <typename T> |
45 | 48 | constexpr T dispatch_erf_op(T elem) |
46 | 49 | { |
@@ -523,41 +526,49 @@ static void func_map_init_elemwise_1arg_1type(func_map_t &fmap) |
523 | 526 | _DataType_input2, \ |
524 | 527 | _DataType_output>) \ |
525 | 528 | { \ |
526 | | - sycl::vec<_DataType_input1, vec_sz> x1 = \ |
527 | | - sg.load<vec_sz>(input1_multi_ptr); \ |
528 | | - sycl::vec<_DataType_input2, vec_sz> x2 = \ |
529 | | - sg.load<vec_sz>(input2_multi_ptr); \ |
| 529 | + sycl::vec<_DataType_input1, vec_sz> x1{}; \ |
| 530 | + sycl::vec<_DataType_input2, vec_sz> x2{}; \ |
| 531 | + \ |
| 532 | + group_load(sg, input1_multi_ptr, x1); \ |
| 533 | + group_load(sg, input2_multi_ptr, x2); \ |
530 | 534 | \ |
531 | 535 | res_vec = __vec_operation__; \ |
532 | 536 | } \ |
533 | 537 | else /* input types don't match result type, so \ |
534 | 538 | explicit casting is required */ \ |
535 | 539 | { \ |
| 540 | + sycl::vec<_DataType_input1, vec_sz> tmp_x1{}; \ |
| 541 | + sycl::vec<_DataType_input2, vec_sz> tmp_x2{}; \ |
| 542 | + \ |
| 543 | + group_load(sg, input1_multi_ptr, tmp_x1); \ |
| 544 | + group_load(sg, input2_multi_ptr, tmp_x2); \ |
| 545 | + \ |
536 | 546 | sycl::vec<_DataType_output, vec_sz> x1 = \ |
537 | 547 | dpnp_vec_cast<_DataType_output, \ |
538 | 548 | _DataType_input1, vec_sz>( \ |
539 | | - sg.load<vec_sz>(input1_multi_ptr)); \ |
| 549 | + tmp_x1); \ |
540 | 550 | sycl::vec<_DataType_output, vec_sz> x2 = \ |
541 | 551 | dpnp_vec_cast<_DataType_output, \ |
542 | 552 | _DataType_input2, vec_sz>( \ |
543 | | - sg.load<vec_sz>(input2_multi_ptr)); \ |
| 553 | + tmp_x2); \ |
544 | 554 | \ |
545 | 555 | res_vec = __vec_operation__; \ |
546 | 556 | } \ |
547 | 557 | } \ |
548 | 558 | else { \ |
549 | | - sycl::vec<_DataType_input1, vec_sz> x1 = \ |
550 | | - sg.load<vec_sz>(input1_multi_ptr); \ |
551 | | - sycl::vec<_DataType_input2, vec_sz> x2 = \ |
552 | | - sg.load<vec_sz>(input2_multi_ptr); \ |
| 559 | + sycl::vec<_DataType_input1, vec_sz> x1{}; \ |
| 560 | + sycl::vec<_DataType_input2, vec_sz> x2{}; \ |
| 561 | + \ |
| 562 | + group_load(sg, input1_multi_ptr, x1); \ |
| 563 | + group_load(sg, input2_multi_ptr, x2); \ |
553 | 564 | \ |
554 | 565 | for (size_t k = 0; k < vec_sz; ++k) { \ |
555 | 566 | const _DataType_output input1_elem = x1[k]; \ |
556 | 567 | const _DataType_output input2_elem = x2[k]; \ |
557 | 568 | res_vec[k] = __operation__; \ |
558 | 569 | } \ |
559 | 570 | } \ |
560 | | - sg.store<vec_sz>(result_multi_ptr, res_vec); \ |
| 571 | + group_store(sg, res_vec, result_multi_ptr); \ |
561 | 572 | } \ |
562 | 573 | else { \ |
563 | 574 | for (size_t k = start + sg.get_local_id()[0]; \ |
|
0 commit comments