Skip to content

Commit db1e4f0

Browse files
committed
Use group_load_store compiler extension
1 parent a42a87f commit db1e4f0

File tree

1 file changed

+22
-11
lines changed

1 file changed

+22
-11
lines changed

dpnp/backend/kernels/dpnp_krnl_elemwise.cpp

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@
4040
using dpctl::tensor::kernels::alignment_utils::is_aligned;
4141
using dpctl::tensor::kernels::alignment_utils::required_alignment;
4242

43+
using sycl::ext::oneapi::experimental::group_load;
44+
using sycl::ext::oneapi::experimental::group_store;
45+
4346
template <typename T>
4447
constexpr T dispatch_erf_op(T elem)
4548
{
@@ -522,41 +525,49 @@ static void func_map_init_elemwise_1arg_1type(func_map_t &fmap)
522525
_DataType_input2, \
523526
_DataType_output>) \
524527
{ \
525-
sycl::vec<_DataType_input1, vec_sz> x1 = \
526-
sg.load<vec_sz>(input1_multi_ptr); \
527-
sycl::vec<_DataType_input2, vec_sz> x2 = \
528-
sg.load<vec_sz>(input2_multi_ptr); \
528+
sycl::vec<_DataType_input1, vec_sz> x1{}; \
529+
sycl::vec<_DataType_input2, vec_sz> x2{}; \
530+
\
531+
group_load(sg, input1_multi_ptr, x1); \
532+
group_load(sg, input2_multi_ptr, x2); \
529533
\
530534
res_vec = __vec_operation__; \
531535
} \
532536
else /* input types don't match result type, so \
533537
explicit casting is required */ \
534538
{ \
539+
sycl::vec<_DataType_input1, vec_sz> tmp_x1{}; \
540+
sycl::vec<_DataType_input2, vec_sz> tmp_x2{}; \
541+
\
542+
group_load(sg, input1_multi_ptr, tmp_x1); \
543+
group_load(sg, input2_multi_ptr, tmp_x2); \
544+
\
535545
sycl::vec<_DataType_output, vec_sz> x1 = \
536546
dpnp_vec_cast<_DataType_output, \
537547
_DataType_input1, vec_sz>( \
538-
sg.load<vec_sz>(input1_multi_ptr)); \
548+
tmp_x1); \
539549
sycl::vec<_DataType_output, vec_sz> x2 = \
540550
dpnp_vec_cast<_DataType_output, \
541551
_DataType_input2, vec_sz>( \
542-
sg.load<vec_sz>(input2_multi_ptr)); \
552+
tmp_x2); \
543553
\
544554
res_vec = __vec_operation__; \
545555
} \
546556
} \
547557
else { \
548-
sycl::vec<_DataType_input1, vec_sz> x1 = \
549-
sg.load<vec_sz>(input1_multi_ptr); \
550-
sycl::vec<_DataType_input2, vec_sz> x2 = \
551-
sg.load<vec_sz>(input2_multi_ptr); \
558+
sycl::vec<_DataType_input1, vec_sz> x1{}; \
559+
sycl::vec<_DataType_input2, vec_sz> x2{}; \
560+
\
561+
group_load(sg, input1_multi_ptr, x1); \
562+
group_load(sg, input2_multi_ptr, x2); \
552563
\
553564
for (size_t k = 0; k < vec_sz; ++k) { \
554565
const _DataType_output input1_elem = x1[k]; \
555566
const _DataType_output input2_elem = x2[k]; \
556567
res_vec[k] = __operation__; \
557568
} \
558569
} \
559-
sg.store<vec_sz>(result_multi_ptr, res_vec); \
570+
group_store(sg, res_vec, result_multi_ptr); \
560571
} \
561572
else { \
562573
for (size_t k = start + sg.get_local_id()[0]; \

0 commit comments

Comments
 (0)