1212
1313#include <ucc/api/ucc.h>
1414
15- static inline ucc_status_t mca_scoll_ucc_alltoall_init (const void * sbuf , void * rbuf ,
16- int count ,
17- mca_scoll_ucc_module_t * ucc_module ,
18- ucc_coll_req_h * req )
15+ static inline ucc_status_t mca_scoll_ucc_alltoall_init (const void * sbuf , void * rbuf ,
16+ int count , size_t elem_size ,
17+ mca_scoll_ucc_module_t * ucc_module ,
18+ ucc_coll_req_h * req )
1919{
20+ ucc_datatype_t dt ;
21+
22+ if (elem_size == 8 ) {
23+ dt = UCC_DT_INT64 ;
24+ } else if (elem_size == 4 ) {
25+ dt = UCC_DT_INT32 ;
26+ } else {
27+ dt = UCC_DT_INT8 ;
28+ }
29+
2030 ucc_coll_args_t coll = {
2131 .mask = 0 ,
2232 .coll_type = UCC_COLL_TYPE_ALLTOALL ,
2333 .src .info = {
2434 .buffer = (void * )sbuf ,
25- .count = count ,
26- .datatype = UCC_DT_UINT8 ,
35+ .count = count * ucc_module -> group -> proc_count ,
36+ .datatype = dt ,
2737 .mem_type = UCC_MEMORY_TYPE_UNKNOWN
2838 },
2939 .dst .info = {
3040 .buffer = rbuf ,
31- .count = count ,
32- .datatype = UCC_DT_UINT8 ,
41+ .count = count * ucc_module -> group -> proc_count ,
42+ .datatype = dt ,
3343 .mem_type = UCC_MEMORY_TYPE_UNKNOWN
3444 },
3545 };
@@ -56,14 +66,15 @@ int mca_scoll_ucc_alltoall(struct oshmem_group_t *group,
5666
5767 UCC_VERBOSE (3 , "running ucc alltoall" );
5868 ucc_module = (mca_scoll_ucc_module_t * ) group -> g_scoll .scoll_alltoall_module ;
59- count = nelems * element_size ;
69+ count = nelems ;
6070
6171 /* Do nothing on zero-length request */
6272 if (OPAL_UNLIKELY (!nelems )) {
6373 return OSHMEM_SUCCESS ;
6474 }
6575
66- SCOLL_UCC_CHECK (mca_scoll_ucc_alltoall_init (source , target , count , ucc_module , & req ));
76+ SCOLL_UCC_CHECK (mca_scoll_ucc_alltoall_init (source , target , count ,
77+ element_size , ucc_module , & req ));
6778 SCOLL_UCC_CHECK (ucc_collective_post (req ));
6879 SCOLL_UCC_CHECK (scoll_ucc_req_wait (req ));
6980 return OSHMEM_SUCCESS ;
0 commit comments