Skip to content

Commit 6b7e5d9

Browse files
authored
Merge pull request #12897 from Sergei-Lebedev/topiv/v4.1.x_ucc_rs_block
v4.1.x: coll/ucc: add reduce scatter block
2 parents cbf1a6d + 5e05bae commit 6b7e5d9

File tree

5 files changed

+226
-79
lines changed

5 files changed

+226
-79
lines changed

ompi/mca/coll/ucc/Makefile.am

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,22 @@
1212

1313
AM_CPPFLAGS = $(coll_ucc_CPPFLAGS)
1414

15-
coll_ucc_sources = \
16-
coll_ucc.h \
17-
coll_ucc_debug.h \
18-
coll_ucc_dtypes.h \
19-
coll_ucc_common.h \
20-
coll_ucc_module.c \
21-
coll_ucc_component.c \
22-
coll_ucc_barrier.c \
23-
coll_ucc_bcast.c \
24-
coll_ucc_allreduce.c \
25-
coll_ucc_reduce.c \
26-
coll_ucc_alltoall.c \
27-
coll_ucc_alltoallv.c \
28-
coll_ucc_allgather.c \
29-
coll_ucc_allgatherv.c
15+
coll_ucc_sources = \
16+
coll_ucc.h \
17+
coll_ucc_debug.h \
18+
coll_ucc_dtypes.h \
19+
coll_ucc_common.h \
20+
coll_ucc_module.c \
21+
coll_ucc_component.c \
22+
coll_ucc_barrier.c \
23+
coll_ucc_bcast.c \
24+
coll_ucc_allreduce.c \
25+
coll_ucc_reduce.c \
26+
coll_ucc_alltoall.c \
27+
coll_ucc_alltoallv.c \
28+
coll_ucc_allgather.c \
29+
coll_ucc_allgatherv.c \
30+
coll_ucc_reduce_scatter_block.c
3031

3132
# Make the output library in this directory, and name it either
3233
# mca_<type>_<name>.la (for DSO builds) or libmca_<type>_<name>.la

ompi/mca/coll/ucc/coll_ucc.h

Lines changed: 57 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,11 @@ BEGIN_C_DECLS
2727
#define COLL_UCC_CTS (UCC_COLL_TYPE_BARRIER | UCC_COLL_TYPE_BCAST | \
2828
UCC_COLL_TYPE_ALLREDUCE | UCC_COLL_TYPE_ALLTOALL | \
2929
UCC_COLL_TYPE_ALLTOALLV | UCC_COLL_TYPE_ALLGATHER | \
30-
UCC_COLL_TYPE_REDUCE | UCC_COLL_TYPE_ALLGATHERV)
30+
UCC_COLL_TYPE_REDUCE | UCC_COLL_TYPE_ALLGATHERV | \
31+
UCC_COLL_TYPE_REDUCE_SCATTER)
3132

32-
#define COLL_UCC_CTS_STR "barrier,bcast,allreduce,alltoall,alltoallv,allgather,allgatherv,reduce," \
33-
"ibarrier,ibcast,iallreduce,ialltoall,ialltoallv,iallgather,iallgatherv,ireduce"
33+
#define COLL_UCC_CTS_STR "barrier,bcast,allreduce,alltoall,alltoallv,allgather,allgatherv,reduce,reduce_scatter_block," \
34+
"ibarrier,ibcast,iallreduce,ialltoall,ialltoallv,iallgather,iallgatherv,ireduce,ireduce_scatter_block"
3435

3536
typedef struct mca_coll_ucc_req {
3637
ompi_request_t super;
@@ -64,42 +65,46 @@ OMPI_MODULE_DECLSPEC extern mca_coll_ucc_component_t mca_coll_ucc_component;
6465
* UCC enabled communicator
6566
*/
6667
struct mca_coll_ucc_module_t {
67-
mca_coll_base_module_t super;
68-
ompi_communicator_t* comm;
69-
int rank;
70-
ucc_team_h ucc_team;
71-
mca_coll_base_module_allreduce_fn_t previous_allreduce;
72-
mca_coll_base_module_t* previous_allreduce_module;
73-
mca_coll_base_module_iallreduce_fn_t previous_iallreduce;
74-
mca_coll_base_module_t* previous_iallreduce_module;
75-
mca_coll_base_module_reduce_fn_t previous_reduce;
76-
mca_coll_base_module_t* previous_reduce_module;
77-
mca_coll_base_module_ireduce_fn_t previous_ireduce;
78-
mca_coll_base_module_t* previous_ireduce_module;
79-
mca_coll_base_module_barrier_fn_t previous_barrier;
80-
mca_coll_base_module_t* previous_barrier_module;
81-
mca_coll_base_module_ibarrier_fn_t previous_ibarrier;
82-
mca_coll_base_module_t* previous_ibarrier_module;
83-
mca_coll_base_module_bcast_fn_t previous_bcast;
84-
mca_coll_base_module_t* previous_bcast_module;
85-
mca_coll_base_module_ibcast_fn_t previous_ibcast;
86-
mca_coll_base_module_t* previous_ibcast_module;
87-
mca_coll_base_module_alltoall_fn_t previous_alltoall;
88-
mca_coll_base_module_t* previous_alltoall_module;
89-
mca_coll_base_module_ialltoall_fn_t previous_ialltoall;
90-
mca_coll_base_module_t* previous_ialltoall_module;
91-
mca_coll_base_module_alltoallv_fn_t previous_alltoallv;
92-
mca_coll_base_module_t* previous_alltoallv_module;
93-
mca_coll_base_module_ialltoallv_fn_t previous_ialltoallv;
94-
mca_coll_base_module_t* previous_ialltoallv_module;
95-
mca_coll_base_module_allgather_fn_t previous_allgather;
96-
mca_coll_base_module_t* previous_allgather_module;
97-
mca_coll_base_module_iallgather_fn_t previous_iallgather;
98-
mca_coll_base_module_t* previous_iallgather_module;
99-
mca_coll_base_module_allgatherv_fn_t previous_allgatherv;
100-
mca_coll_base_module_t* previous_allgatherv_module;
101-
mca_coll_base_module_iallgatherv_fn_t previous_iallgatherv;
102-
mca_coll_base_module_t* previous_iallgatherv_module;
68+
mca_coll_base_module_t super;
69+
ompi_communicator_t* comm;
70+
int rank;
71+
ucc_team_h ucc_team;
72+
mca_coll_base_module_allreduce_fn_t previous_allreduce;
73+
mca_coll_base_module_t* previous_allreduce_module;
74+
mca_coll_base_module_iallreduce_fn_t previous_iallreduce;
75+
mca_coll_base_module_t* previous_iallreduce_module;
76+
mca_coll_base_module_reduce_fn_t previous_reduce;
77+
mca_coll_base_module_t* previous_reduce_module;
78+
mca_coll_base_module_ireduce_fn_t previous_ireduce;
79+
mca_coll_base_module_t* previous_ireduce_module;
80+
mca_coll_base_module_barrier_fn_t previous_barrier;
81+
mca_coll_base_module_t* previous_barrier_module;
82+
mca_coll_base_module_ibarrier_fn_t previous_ibarrier;
83+
mca_coll_base_module_t* previous_ibarrier_module;
84+
mca_coll_base_module_bcast_fn_t previous_bcast;
85+
mca_coll_base_module_t* previous_bcast_module;
86+
mca_coll_base_module_ibcast_fn_t previous_ibcast;
87+
mca_coll_base_module_t* previous_ibcast_module;
88+
mca_coll_base_module_alltoall_fn_t previous_alltoall;
89+
mca_coll_base_module_t* previous_alltoall_module;
90+
mca_coll_base_module_ialltoall_fn_t previous_ialltoall;
91+
mca_coll_base_module_t* previous_ialltoall_module;
92+
mca_coll_base_module_alltoallv_fn_t previous_alltoallv;
93+
mca_coll_base_module_t* previous_alltoallv_module;
94+
mca_coll_base_module_ialltoallv_fn_t previous_ialltoallv;
95+
mca_coll_base_module_t* previous_ialltoallv_module;
96+
mca_coll_base_module_allgather_fn_t previous_allgather;
97+
mca_coll_base_module_t* previous_allgather_module;
98+
mca_coll_base_module_iallgather_fn_t previous_iallgather;
99+
mca_coll_base_module_t* previous_iallgather_module;
100+
mca_coll_base_module_allgatherv_fn_t previous_allgatherv;
101+
mca_coll_base_module_t* previous_allgatherv_module;
102+
mca_coll_base_module_iallgatherv_fn_t previous_iallgatherv;
103+
mca_coll_base_module_t* previous_iallgatherv_module;
104+
mca_coll_base_module_reduce_scatter_block_fn_t previous_reduce_scatter_block;
105+
mca_coll_base_module_t* previous_reduce_scatter_block_module;
106+
mca_coll_base_module_ireduce_scatter_block_fn_t previous_ireduce_scatter_block;
107+
mca_coll_base_module_t* previous_ireduce_scatter_block_module;
103108
};
104109
typedef struct mca_coll_ucc_module_t mca_coll_ucc_module_t;
105110
OBJ_CLASS_DECLARATION(mca_coll_ucc_module_t);
@@ -195,5 +200,18 @@ int mca_coll_ucc_iallgatherv(const void *sbuf, int scount, struct ompi_datatype_
195200
ompi_request_t** request,
196201
mca_coll_base_module_t *module);
197202

203+
int mca_coll_ucc_reduce_scatter_block(const void *sbuf, void *rbuf, int rcount,
204+
struct ompi_datatype_t *dtype,
205+
struct ompi_op_t *op,
206+
struct ompi_communicator_t *comm,
207+
mca_coll_base_module_t *module);
208+
209+
int mca_coll_ucc_ireduce_scatter_block(const void *sbuf, void *rbuf, int rcount,
210+
struct ompi_datatype_t *dtype,
211+
struct ompi_op_t *op,
212+
struct ompi_communicator_t *comm,
213+
ompi_request_t** request,
214+
mca_coll_base_module_t *module);
215+
198216
END_C_DECLS
199217
#endif

ompi/mca/coll/ucc/coll_ucc_component.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,8 @@ static ucc_coll_type_t mca_coll_ucc_str_to_type(const char *str)
120120
return UCC_COLL_TYPE_ALLGATHERV;
121121
} else if (0 == strcasecmp(str, "reduce")) {
122122
return UCC_COLL_TYPE_REDUCE;
123+
} else if (0 == strcasecmp(str, "reduce_scatter_block")) {
124+
return UCC_COLL_TYPE_REDUCE_SCATTER;
123125
}
124126
UCC_ERROR("incorrect value for cts: %s, allowed: %s",
125127
str, COLL_UCC_CTS_STR);

ompi/mca/coll/ucc/coll_ucc_module.c

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -29,23 +29,27 @@ int mca_coll_ucc_init_query(bool enable_progress_threads, bool enable_mpi_thread
2929

3030
static void mca_coll_ucc_module_clear(mca_coll_ucc_module_t *ucc_module)
3131
{
32-
ucc_module->ucc_team = NULL;
33-
ucc_module->previous_allreduce = NULL;
34-
ucc_module->previous_iallreduce = NULL;
35-
ucc_module->previous_barrier = NULL;
36-
ucc_module->previous_ibarrier = NULL;
37-
ucc_module->previous_bcast = NULL;
38-
ucc_module->previous_ibcast = NULL;
39-
ucc_module->previous_alltoall = NULL;
40-
ucc_module->previous_ialltoall = NULL;
41-
ucc_module->previous_alltoallv = NULL;
42-
ucc_module->previous_ialltoallv = NULL;
43-
ucc_module->previous_allgather = NULL;
44-
ucc_module->previous_iallgather = NULL;
45-
ucc_module->previous_allgatherv = NULL;
46-
ucc_module->previous_iallgatherv = NULL;
47-
ucc_module->previous_reduce = NULL;
48-
ucc_module->previous_ireduce = NULL;
32+
ucc_module->ucc_team = NULL;
33+
ucc_module->previous_allreduce = NULL;
34+
ucc_module->previous_iallreduce = NULL;
35+
ucc_module->previous_barrier = NULL;
36+
ucc_module->previous_ibarrier = NULL;
37+
ucc_module->previous_bcast = NULL;
38+
ucc_module->previous_ibcast = NULL;
39+
ucc_module->previous_alltoall = NULL;
40+
ucc_module->previous_ialltoall = NULL;
41+
ucc_module->previous_alltoallv = NULL;
42+
ucc_module->previous_ialltoallv = NULL;
43+
ucc_module->previous_allgather = NULL;
44+
ucc_module->previous_iallgather = NULL;
45+
ucc_module->previous_allgatherv = NULL;
46+
ucc_module->previous_iallgatherv = NULL;
47+
ucc_module->previous_reduce = NULL;
48+
ucc_module->previous_ireduce = NULL;
49+
ucc_module->previous_reduce_scatter_block = NULL;
50+
ucc_module->previous_reduce_scatter_block_module = NULL;
51+
ucc_module->previous_ireduce_scatter_block = NULL;
52+
ucc_module->previous_ireduce_scatter_block_module = NULL;
4953
}
5054

5155
static void mca_coll_ucc_module_construct(mca_coll_ucc_module_t *ucc_module)
@@ -82,6 +86,8 @@ static void mca_coll_ucc_module_destruct(mca_coll_ucc_module_t *ucc_module)
8286
OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_iallgatherv_module);
8387
OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_reduce_module);
8488
OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_ireduce_module);
89+
OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_reduce_scatter_block_module);
90+
OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_ireduce_scatter_block_module);
8591
mca_coll_ucc_module_clear(ucc_module);
8692
}
8793

@@ -113,6 +119,8 @@ static int mca_coll_ucc_save_coll_handlers(mca_coll_ucc_module_t *ucc_module)
113119
SAVE_PREV_COLL_API(iallgatherv);
114120
SAVE_PREV_COLL_API(reduce);
115121
SAVE_PREV_COLL_API(ireduce);
122+
SAVE_PREV_COLL_API(reduce_scatter_block);
123+
SAVE_PREV_COLL_API(ireduce_scatter_block);
116124
return OMPI_SUCCESS;
117125
}
118126

@@ -491,14 +499,15 @@ mca_coll_ucc_comm_query(struct ompi_communicator_t *comm, int *priority)
491499
ucc_module->comm = comm;
492500
ucc_module->super.coll_module_enable = mca_coll_ucc_module_enable;
493501
*priority = cm->ucc_priority;
494-
SET_COLL_PTR(ucc_module, BARRIER, barrier);
495-
SET_COLL_PTR(ucc_module, BCAST, bcast);
496-
SET_COLL_PTR(ucc_module, ALLREDUCE, allreduce);
497-
SET_COLL_PTR(ucc_module, ALLTOALL, alltoall);
498-
SET_COLL_PTR(ucc_module, ALLTOALLV, alltoallv);
499-
SET_COLL_PTR(ucc_module, REDUCE, reduce);
500-
SET_COLL_PTR(ucc_module, ALLGATHER, allgather);
501-
SET_COLL_PTR(ucc_module, ALLGATHERV, allgatherv);
502+
SET_COLL_PTR(ucc_module, BARRIER, barrier);
503+
SET_COLL_PTR(ucc_module, BCAST, bcast);
504+
SET_COLL_PTR(ucc_module, ALLREDUCE, allreduce);
505+
SET_COLL_PTR(ucc_module, ALLTOALL, alltoall);
506+
SET_COLL_PTR(ucc_module, ALLTOALLV, alltoallv);
507+
SET_COLL_PTR(ucc_module, REDUCE, reduce);
508+
SET_COLL_PTR(ucc_module, ALLGATHER, allgather);
509+
SET_COLL_PTR(ucc_module, ALLGATHERV, allgatherv);
510+
SET_COLL_PTR(ucc_module, REDUCE_SCATTER, reduce_scatter_block);
502511
return &ucc_module->super;
503512
}
504513

ompi/mca/coll/ucc/coll_ucc_reduce_scatter_block.c

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
/**
2+
* Copyright (c) 2021 Mellanox Technologies. All rights reserved.
3+
* Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
4+
* $COPYRIGHT$
5+
*
6+
* Additional copyrights may follow
7+
*
8+
*/
9+
10+
#include "coll_ucc_common.h"
11+
12+
static inline
13+
ucc_status_t mca_coll_ucc_reduce_scatter_block_init(const void *sbuf, void *rbuf,
14+
size_t rcount,
15+
struct ompi_datatype_t *dtype,
16+
struct ompi_op_t *op,
17+
mca_coll_ucc_module_t *ucc_module,
18+
ucc_coll_req_h *req,
19+
mca_coll_ucc_req_t *coll_req)
20+
{
21+
ucc_datatype_t ucc_dt;
22+
ucc_reduction_op_t ucc_op;
23+
int comm_size = ompi_comm_size(ucc_module->comm);
24+
25+
if (MPI_IN_PLACE == sbuf) {
26+
/* TODO: UCC defines inplace differently:
27+
data in rbuf of rank R is shifted by R * rcount */
28+
UCC_VERBOSE(5, "inplace reduce_scatter_block is not supported");
29+
return UCC_ERR_NOT_SUPPORTED;
30+
}
31+
ucc_dt = ompi_dtype_to_ucc_dtype(dtype);
32+
ucc_op = ompi_op_to_ucc_op(op);
33+
if (OPAL_UNLIKELY(COLL_UCC_DT_UNSUPPORTED == ucc_dt)) {
34+
UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s",
35+
dtype->super.name);
36+
goto fallback;
37+
}
38+
if (OPAL_UNLIKELY(COLL_UCC_OP_UNSUPPORTED == ucc_op)) {
39+
UCC_VERBOSE(5, "ompi_op is not supported: op = %s",
40+
op->o_name);
41+
goto fallback;
42+
}
43+
ucc_coll_args_t coll = {
44+
.mask = 0,
45+
.flags = 0,
46+
.coll_type = UCC_COLL_TYPE_REDUCE_SCATTER,
47+
.src.info = {
48+
.buffer = (void*)sbuf,
49+
.count = ((size_t)rcount) * comm_size,
50+
.datatype = ucc_dt,
51+
.mem_type = UCC_MEMORY_TYPE_UNKNOWN
52+
},
53+
.dst.info = {
54+
.buffer = rbuf,
55+
.count = rcount,
56+
.datatype = ucc_dt,
57+
.mem_type = UCC_MEMORY_TYPE_UNKNOWN
58+
},
59+
.op = ucc_op,
60+
};
61+
COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module);
62+
return UCC_OK;
63+
fallback:
64+
return UCC_ERR_NOT_SUPPORTED;
65+
}
66+
67+
int mca_coll_ucc_reduce_scatter_block(const void *sbuf, void *rbuf, int rcount,
68+
struct ompi_datatype_t *dtype,
69+
struct ompi_op_t *op,
70+
struct ompi_communicator_t *comm,
71+
mca_coll_base_module_t *module)
72+
{
73+
mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t*)module;
74+
ucc_coll_req_h req;
75+
76+
UCC_VERBOSE(3, "running ucc reduce scatter block");
77+
COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_block_init(sbuf, rbuf, rcount,
78+
dtype, op, ucc_module,
79+
&req, NULL));
80+
COLL_UCC_POST_AND_CHECK(req);
81+
COLL_UCC_CHECK(coll_ucc_req_wait(req));
82+
return OMPI_SUCCESS;
83+
fallback:
84+
UCC_VERBOSE(3, "running fallback reduce_scatter_block");
85+
return ucc_module->previous_reduce_scatter_block(sbuf, rbuf, rcount, dtype,
86+
op, comm,
87+
ucc_module->previous_reduce_scatter_block_module);
88+
}
89+
90+
int mca_coll_ucc_ireduce_scatter_block(const void *sbuf, void *rbuf, int rcount,
91+
struct ompi_datatype_t *dtype,
92+
struct ompi_op_t *op,
93+
struct ompi_communicator_t *comm,
94+
ompi_request_t** request,
95+
mca_coll_base_module_t *module)
96+
{
97+
mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t*)module;
98+
ucc_coll_req_h req;
99+
mca_coll_ucc_req_t *coll_req = NULL;
100+
101+
UCC_VERBOSE(3, "running ucc ireduce_scatter_block");
102+
COLL_UCC_GET_REQ(coll_req);
103+
COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_block_init(sbuf, rbuf, rcount,
104+
dtype, op, ucc_module,
105+
&req, coll_req));
106+
COLL_UCC_POST_AND_CHECK(req);
107+
*request = &coll_req->super;
108+
return OMPI_SUCCESS;
109+
fallback:
110+
UCC_VERBOSE(3, "running fallback ireduce_scatter_block");
111+
if (coll_req) {
112+
mca_coll_ucc_req_free((ompi_request_t **)&coll_req);
113+
}
114+
return ucc_module->previous_ireduce_scatter_block(sbuf, rbuf, rcount, dtype,
115+
op, comm, request,
116+
ucc_module->previous_ireduce_scatter_block_module);
117+
}

0 commit comments

Comments
 (0)