88#include " ec_rocm.h"
99#include " utils/ucc_math_op.h"
1010#include < inttypes.h>
11+ #include < hip/hip_complex.h>
1112
1213#define ROCM_REDUCE_WITH_OP_DEFAULT (NAME, _OP ) \
1314 template <typename _Type, typename _AlphaType> \
5455 } \
5556 }
5657
/*
 * Generates UCC_REDUCE_ROCM_DEFAULT_COMPLEX_<NAME>, an element-wise reduction
 * kernel for types whose combine step is a function call _OP(a, b) rather
 * than a built-in operator.
 *
 *   task  - reduce descriptor, passed by value; the kernel reads count,
 *           n_srcs, srcs (array of source pointers), dst and alpha.
 *   flags - when UCC_EEE_TASK_FLAG_REDUCE_WITH_ALPHA is set, every result is
 *           additionally multiplied by (_AlphaType)task.alpha.
 *
 * For each element i the sources are folded left to right:
 *   dst[i] = _OP(... _OP(_OP(srcs[0][i], srcs[1][i]), srcs[2][i]) ...)
 * srcs[0] and srcs[1] are always read, so n_srcs is expected to be >= 2.
 * Each thread starts at its global index and advances by the total number of
 * launched threads, so any 1-D launch configuration covers all elements.
 */
#define ROCM_REDUCE_WITH_COMPLEX_PRODUCT_DEFAULT(NAME, _OP)                    \
    template <typename _Type, typename _AlphaType>                             \
    __global__ void UCC_REDUCE_ROCM_DEFAULT_COMPLEX_##NAME(                    \
        ucc_eee_task_reduce_t task, uint16_t flags)                            \
    {                                                                          \
        /* widen to size_t before multiplying: the builtin coordinates are */  \
        /* 32-bit, so the product could otherwise wrap on large launches   */  \
        const size_t     start  = (size_t)blockIdx.x * blockDim.x +            \
                                  threadIdx.x;                                 \
        const size_t     step   = (size_t)blockDim.x * gridDim.x;              \
        const size_t     count  = task.count;                                  \
        const int        n_srcs = task.n_srcs;                                 \
        const _Type    **s      = (const _Type **)task.srcs;                   \
        _Type           *d      = (_Type *)task.dst;                           \
        const bool       scaled =                                              \
            (flags & UCC_EEE_TASK_FLAG_REDUCE_WITH_ALPHA) != 0;                \
        const _AlphaType alpha  = (_AlphaType)task.alpha;                      \
                                                                               \
        for (size_t i = start; i < count; i += step) {                         \
            /* accumulate locally: all sources of element i are read    */     \
            /* before d[i] is written, and d[i] is stored exactly once  */     \
            _Type acc = _OP(s[0][i], s[1][i]);                                 \
            /* runs zero times when n_srcs == 2 */                             \
            for (int j = 2; j < n_srcs; j++) {                                 \
                acc = _OP(acc, s[j][i]);                                       \
            }                                                                  \
            if (scaled) {                                                      \
                acc = acc * alpha;                                             \
            }                                                                  \
            d[i] = acc;                                                        \
        }                                                                      \
    }
92+
5793#define ROCM_REDUCE_WITH_OP_STRIDED (NAME, _OP ) \
5894 template <typename _Type, typename _AlphaType> \
5995 __global__ void UCC_REDUCE_ROCM_STRIDED_##NAME( \
99135 } \
100136 }
101137
/*
 * Generates UCC_REDUCE_ROCM_STRIDED_COMPLEX_<NAME>, the strided counterpart
 * of the complex reduction kernel; the combine step is the function call
 * _OP(a, b).
 *
 *   s1         - first source vector (count elements are read).
 *   s2         - base of the remaining sources; source j is read at
 *                s2[i + j * (stride / sizeof(_Type))] for j in [0, n_src2).
 *   d          - destination vector (count elements are written).
 *   count      - number of elements to reduce.
 *   stride     - distance in bytes between consecutive s2 sources; asserted
 *                to be a multiple of sizeof(_Type).
 *   n_src2     - number of sources reachable through s2. s2[i] is always
 *                read, so at least one is expected.
 *   with_alpha - when true, every result is multiplied by (_AlphaType)alpha.
 *   alpha      - scale factor, used only when with_alpha is true.
 *
 * For each element i: d[i] = _OP(... _OP(s1[i], s2 source 0) ..., s2 source
 * n_src2 - 1), folded left to right. Each thread starts at its global index
 * and advances by the total number of launched threads.
 */
#define ROCM_REDUCE_WITH_COMPLEX_PRODUCT_STRIDED(NAME, _OP)                    \
    template <typename _Type, typename _AlphaType>                             \
    __global__ void UCC_REDUCE_ROCM_STRIDED_COMPLEX_##NAME(                    \
        const _Type *s1, const _Type *s2, _Type *d, size_t count,              \
        size_t stride, uint16_t n_src2, const bool with_alpha,                 \
        const double alpha)                                                    \
    {                                                                          \
        /* widen to size_t before multiplying: the builtin coordinates are */  \
        /* 32-bit, so the product could otherwise wrap on large launches   */  \
        const size_t     start = (size_t)blockIdx.x * blockDim.x +             \
                                 threadIdx.x;                                  \
        const size_t     step  = (size_t)blockDim.x * gridDim.x;               \
        /* stride is given in bytes; ld is the same distance in elements */    \
        const size_t     ld    = stride / sizeof(_Type);                       \
        const _AlphaType scale = (_AlphaType)alpha;                            \
                                                                               \
        ucc_assert_system(stride % sizeof(_Type) == 0);                        \
        for (size_t i = start; i < count; i += step) {                         \
            /* accumulate locally so d[i] is stored exactly once */            \
            _Type acc = _OP(s1[i], s2[i]);                                     \
            /* runs zero times when n_src2 == 1 */                             \
            for (size_t j = 1; j < n_src2; j++) {                              \
                acc = _OP(acc, s2[i + j * ld]);                                \
            }                                                                  \
            if (with_alpha) {                                                  \
                acc = acc * scale;                                             \
            }                                                                  \
            d[i] = acc;                                                        \
        }                                                                      \
    }
172+
102173ROCM_REDUCE_WITH_OP_DEFAULT (SUM, DO_OP_SUM);
103174ROCM_REDUCE_WITH_OP_DEFAULT (PROD, DO_OP_PROD);
175+ ROCM_REDUCE_WITH_COMPLEX_PRODUCT_DEFAULT (PROD_DOUBLE, hipCmul);
176+ ROCM_REDUCE_WITH_COMPLEX_PRODUCT_DEFAULT (PROD_FLOAT, hipCmulf);
104177ROCM_REDUCE_WITH_OP_DEFAULT (MIN, DO_OP_MIN);
105178ROCM_REDUCE_WITH_OP_DEFAULT (MAX, DO_OP_MAX);
106179ROCM_REDUCE_WITH_OP_DEFAULT (LAND, DO_OP_LAND);
@@ -112,6 +185,8 @@ ROCM_REDUCE_WITH_OP_DEFAULT(BXOR, DO_OP_BXOR);
112185
113186ROCM_REDUCE_WITH_OP_STRIDED (SUM, DO_OP_SUM);
114187ROCM_REDUCE_WITH_OP_STRIDED (PROD, DO_OP_PROD);
188+ ROCM_REDUCE_WITH_COMPLEX_PRODUCT_STRIDED (PROD_DOUBLE, hipCmul);
189+ ROCM_REDUCE_WITH_COMPLEX_PRODUCT_STRIDED (PROD_FLOAT, hipCmulf);
115190ROCM_REDUCE_WITH_OP_STRIDED (MIN, DO_OP_MIN);
116191ROCM_REDUCE_WITH_OP_STRIDED (MAX, DO_OP_MAX);
117192ROCM_REDUCE_WITH_OP_STRIDED (LAND, DO_OP_LAND);
@@ -136,6 +211,21 @@ ROCM_REDUCE_WITH_OP_STRIDED(BXOR, DO_OP_BXOR);
136211 } \
137212 } while (0 )
138213
/*
 * Launches the complex reduction kernel selected by NAME with b blocks of
 * t threads on stream s, using no dynamic shared memory.
 *
 *   NAME       - suffix of the UCC_REDUCE_ROCM_{DEFAULT,STRIDED}_COMPLEX_
 *                kernel to launch.
 *   type       - element type the kernel is instantiated with.
 *   _AlphaType - scalar type the kernel casts alpha to.
 *   _task      - pointer to the executor task; task_type selects the variant.
 *
 * UCC_EE_EXECUTOR_TASK_REDUCE uses the DEFAULT kernel with _task->reduce;
 * every other task type is handled as a strided reduce through
 * _task->reduce_strided. The macro does not check the launch result.
 */
#define LAUNCH_KERNEL_B(NAME, type, _AlphaType, _task, s, b, t)                \
    do {                                                                       \
        /* _task is parenthesised so expression arguments expand safely */     \
        if ((_task)->task_type == UCC_EE_EXECUTOR_TASK_REDUCE) {               \
            UCC_REDUCE_ROCM_DEFAULT_COMPLEX_##NAME<type, _AlphaType>           \
                <<<(b), (t), 0, (s)>>>((_task)->reduce, (_task)->flags);       \
        } else {                                                               \
            /* underscore-prefixed local: avoids capturing a caller's name */  \
            ucc_eee_task_reduce_strided_t *_trs = &(_task)->reduce_strided;    \
            UCC_REDUCE_ROCM_STRIDED_COMPLEX_##NAME<type, _AlphaType>           \
                <<<(b), (t), 0, (s)>>>(                                        \
                    (const type *)_trs->src1, (const type *)_trs->src2,        \
                    (type *)_trs->dst, _trs->count, _trs->stride,              \
                    _trs->n_src2,                                              \
                    (bool)((_task)->flags &                                    \
                           UCC_EEE_TASK_FLAG_REDUCE_WITH_ALPHA),               \
                    _trs->alpha);                                              \
        }                                                                      \
    } while (0)
228+
/*
 * Shorthand for LAUNCH_KERNEL_A where the alpha type is the same as the
 * element type: `type` is forwarded for both template arguments.
 */
#define LAUNCH_KERNEL(NAME, type, _task, s, b, t)                              \
    LAUNCH_KERNEL_A(NAME, type, type, _task, s, b, t)
141231
@@ -207,15 +297,15 @@ ROCM_REDUCE_WITH_OP_STRIDED(BXOR, DO_OP_BXOR);
207297 } \
208298 } while (0 )
209299
210- #define DT_REDUCE_FLOAT_COMPLEX (type, _alphaType, _task, _op, s, b, t ) \
300+ #define DT_REDUCE_FLOAT_COMPLEX (NAME, type, _alphaType, _task, _op, s, b, t ) \
211301 do { \
212302 switch (_op) { \
213303 case UCC_OP_AVG: \
214304 case UCC_OP_SUM: \
215- LAUNCH_KERNEL_A (SUM, type , _alphaType, _task, s, b, t); \
305+ LAUNCH_KERNEL_A (SUM, type , _alphaType, _task, s, b, t); \
216306 break ; \
217307 case UCC_OP_PROD: \
218- LAUNCH_KERNEL_A (PROD , type, _alphaType, _task, s, b, t); \
308+ LAUNCH_KERNEL_B (NAME , type, _alphaType, _task, s, b, t); \
219309 break ; \
220310 default : \
221311 ec_error (&ucc_ec_rocm.super , \
@@ -299,10 +389,10 @@ ucc_status_t ucc_ec_rocm_reduce(ucc_ee_executor_task_args_t *task,
299389 return UCC_ERR_NOT_SUPPORTED;
300390#endif
301391 case UCC_DT_FLOAT32_COMPLEX:
302- DT_REDUCE_FLOAT_COMPLEX (hipFloatComplex, float , task, op, stream, bk, th);
392+ DT_REDUCE_FLOAT_COMPLEX (PROD_FLOAT, hipFloatComplex, float , task, op, stream, bk, th);
303393 break ;
304394 case UCC_DT_FLOAT64_COMPLEX:
305- DT_REDUCE_FLOAT_COMPLEX (hipDoubleComplex, double , task, op, stream, bk, th);
395+ DT_REDUCE_FLOAT_COMPLEX (PROD_DOUBLE, hipDoubleComplex, double , task, op, stream, bk, th);
306396 break ;
307397 case UCC_DT_BFLOAT16:
308398 ucc_assert (2 == sizeof (hip_bfloat16));
0 commit comments