Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 12 additions & 14 deletions ompi/mca/op/aarch64/op_aarch64_functions.c
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
* reserved.
* Copyright (c) 2019 Arm Ltd. All rights reserved.
* Copyright (c) 2024 NVIDIA Corporation. All rights reserved.
* Copyright (c) 2024 Research Organization for Information Science
* and Technology (RIST). All rights reserved.
*
* $COPYRIGHT$
*
Expand Down Expand Up @@ -140,20 +142,18 @@ _Generic((*(out)), \
struct ompi_datatype_t **dtype, \
struct ompi_op_base_module_1_0_0_t *module) \
{ \
int types_per_step = svcnt(*((type##type_size##_t *) _in)); \
size_t idx = 0, left_over = *count; \
const int types_per_step = svcnt(*((type##type_size##_t *) _in)); \
const int cnt = *count; \
type##type_size##_t *in = (type##type_size##_t *) _in, \
*out = (type##type_size##_t *) _out; \
OP_CONCAT(OMPI_OP_TYPE_PREPEND, type##type_size##_t) vsrc, vdst; \
svbool_t pred = svwhilelt_b##type_size(idx, left_over); \
do { \
for (int idx=0; idx < cnt; idx += types_per_step) { \
svbool_t pred = svwhilelt_b##type_size(idx, cnt); \
vsrc = svld1(pred, &in[idx]); \
vdst = svld1(pred, &out[idx]); \
vdst = OP_CONCAT(OMPI_OP_OP_PREPEND, op##_x)(pred, vdst, vsrc); \
OP_CONCAT(OMPI_OP_OP_PREPEND, st1)(pred, &out[idx], vdst); \
idx += types_per_step; \
pred = svwhilelt_b##type_size(idx, left_over); \
} while (svptest_any(svptrue_b##type_size(), pred)); \
} \
}
#endif

Expand Down Expand Up @@ -308,21 +308,19 @@ static void OP_CONCAT(ompi_op_aarch64_3buff_##name##_##type##type_size##_t, APPE
struct ompi_datatype_t **dtype, \
struct ompi_op_base_module_1_0_0_t *module) \
{ \
int types_per_step = svcnt(*((type##type_size##_t *) _in1)); \
const int types_per_step = svcnt(*((type##type_size##_t *) _in1)); \
type##type_size##_t *in1 = (type##type_size##_t *) _in1, \
*in2 = (type##type_size##_t *) _in2, \
*out = (type##type_size##_t *) _out; \
size_t idx = 0, left_over = *count; \
const int cnt = *count; \
OP_CONCAT(OMPI_OP_TYPE_PREPEND, type##type_size##_t) vsrc, vdst; \
svbool_t pred = svwhilelt_b##type_size(idx, left_over); \
do { \
for (int idx=0; idx < cnt; idx += types_per_step) { \
svbool_t pred = svwhilelt_b##type_size(idx, cnt); \
vsrc = svld1(pred, &in1[idx]); \
vdst = svld1(pred, &in2[idx]); \
vdst = OP_CONCAT(OMPI_OP_OP_PREPEND, op##_x)(pred, vdst, vsrc); \
OP_CONCAT(OMPI_OP_OP_PREPEND, st1)(pred, &out[idx], vdst); \
idx += types_per_step; \
pred = svwhilelt_b##type_size(idx, left_over); \
} while (svptest_any(svptrue_b##type_size(), pred)); \
} \
}
#endif /* defined(GENERATE_SVE_CODE) */

Expand Down