|
40 | 40 | using dpctl::tensor::kernels::alignment_utils::is_aligned; |
41 | 41 | using dpctl::tensor::kernels::alignment_utils::required_alignment; |
42 | 42 |
|
| 43 | +using sycl::ext::oneapi::experimental::group_load; |
| 44 | +using sycl::ext::oneapi::experimental::group_store; |
| 45 | + |
43 | 46 | template <typename T> |
44 | 47 | constexpr T dispatch_erf_op(T elem) |
45 | 48 | { |
@@ -522,41 +525,49 @@ static void func_map_init_elemwise_1arg_1type(func_map_t &fmap) |
522 | 525 | _DataType_input2, \ |
523 | 526 | _DataType_output>) \ |
524 | 527 | { \ |
525 | | - sycl::vec<_DataType_input1, vec_sz> x1 = \ |
526 | | - sg.load<vec_sz>(input1_multi_ptr); \ |
527 | | - sycl::vec<_DataType_input2, vec_sz> x2 = \ |
528 | | - sg.load<vec_sz>(input2_multi_ptr); \ |
| 528 | + sycl::vec<_DataType_input1, vec_sz> x1{}; \ |
| 529 | + sycl::vec<_DataType_input2, vec_sz> x2{}; \ |
| 530 | + \ |
| 531 | + group_load(sg, input1_multi_ptr, x1); \ |
| 532 | + group_load(sg, input2_multi_ptr, x2); \ |
529 | 533 | \ |
530 | 534 | res_vec = __vec_operation__; \ |
531 | 535 | } \ |
532 | 536 | else /* input types don't match result type, so \ |
533 | 537 | explicit casting is required */ \ |
534 | 538 | { \ |
| 539 | + sycl::vec<_DataType_input1, vec_sz> tmp_x1{}; \ |
| 540 | + sycl::vec<_DataType_input2, vec_sz> tmp_x2{}; \ |
| 541 | + \ |
| 542 | + group_load(sg, input1_multi_ptr, tmp_x1); \ |
| 543 | + group_load(sg, input2_multi_ptr, tmp_x2); \ |
| 544 | + \ |
535 | 545 | sycl::vec<_DataType_output, vec_sz> x1 = \ |
536 | 546 | dpnp_vec_cast<_DataType_output, \ |
537 | 547 | _DataType_input1, vec_sz>( \ |
538 | | - sg.load<vec_sz>(input1_multi_ptr)); \ |
| 548 | + tmp_x1); \ |
539 | 549 | sycl::vec<_DataType_output, vec_sz> x2 = \ |
540 | 550 | dpnp_vec_cast<_DataType_output, \ |
541 | 551 | _DataType_input2, vec_sz>( \ |
542 | | - sg.load<vec_sz>(input2_multi_ptr)); \ |
| 552 | + tmp_x2); \ |
543 | 553 | \ |
544 | 554 | res_vec = __vec_operation__; \ |
545 | 555 | } \ |
546 | 556 | } \ |
547 | 557 | else { \ |
548 | | - sycl::vec<_DataType_input1, vec_sz> x1 = \ |
549 | | - sg.load<vec_sz>(input1_multi_ptr); \ |
550 | | - sycl::vec<_DataType_input2, vec_sz> x2 = \ |
551 | | - sg.load<vec_sz>(input2_multi_ptr); \ |
| 558 | + sycl::vec<_DataType_input1, vec_sz> x1{}; \ |
| 559 | + sycl::vec<_DataType_input2, vec_sz> x2{}; \ |
| 560 | + \ |
| 561 | + group_load(sg, input1_multi_ptr, x1); \ |
| 562 | + group_load(sg, input2_multi_ptr, x2); \ |
552 | 563 | \ |
553 | 564 | for (size_t k = 0; k < vec_sz; ++k) { \ |
554 | 565 | const _DataType_output input1_elem = x1[k]; \ |
555 | 566 | const _DataType_output input2_elem = x2[k]; \ |
556 | 567 | res_vec[k] = __operation__; \ |
557 | 568 | } \ |
558 | 569 | } \ |
559 | | - sg.store<vec_sz>(result_multi_ptr, res_vec); \ |
| 570 | + group_store(sg, res_vec, result_multi_ptr); \ |
560 | 571 | } \ |
561 | 572 | else { \ |
562 | 573 | for (size_t k = start + sg.get_local_id()[0]; \ |
|
0 commit comments