Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion contrib/platform/mellanox/optimized
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
enable_mca_no_build=coll-ml,btl-uct
enable_debug_symbols=yes
enable_orterun_prefix_by_default=yes
enable_mpirun_prefix_by_default=yes
with_devel_headers=yes
enable_oshmem=yes
enable_oshmem_fortran=yes
Expand Down
5 changes: 5 additions & 0 deletions ompi/mca/coll/ucc/coll_ucc.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ BEGIN_C_DECLS
"iallgatherv,ireduce,igather,igatherv,ireduce_scatter_block,"\
"ireduce_scatter,iscatterv,iscatter"

#define mca_coll_ucc_call_previous(__api, ucc_module, ...) \
(ucc_module->previous_ ## __api == NULL ? \
UCC_ERR_NOT_SUPPORTED : \
ucc_module->previous_ ## __api (__VA_ARGS__, ucc_module->previous_ ## __api ## _module))

typedef struct mca_coll_ucc_req {
ompi_request_t super;
ucc_coll_req_h ucc_req;
Expand Down
10 changes: 6 additions & 4 deletions ompi/mca/coll/ucc/coll_ucc_allgather.c
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,9 @@ int mca_coll_ucc_allgather(const void *sbuf, int scount, struct ompi_datatype_t
return OMPI_SUCCESS;
fallback:
UCC_VERBOSE(3, "running fallback allgather");
return ucc_module->previous_allgather(sbuf, scount, sdtype, rbuf, rcount, rdtype,
comm, ucc_module->previous_allgather_module);

return mca_coll_ucc_call_previous(allgather, ucc_module,
sbuf, scount, sdtype, rbuf, rcount, rdtype, comm);
}

int mca_coll_ucc_iallgather(const void *sbuf, int scount, struct ompi_datatype_t *sdtype,
Expand All @@ -104,6 +105,7 @@ int mca_coll_ucc_iallgather(const void *sbuf, int scount, struct ompi_datatype_t
if (coll_req) {
mca_coll_ucc_req_free((ompi_request_t **)&coll_req);
}
return ucc_module->previous_iallgather(sbuf, scount, sdtype, rbuf, rcount, rdtype,
comm, request, ucc_module->previous_iallgather_module);

return mca_coll_ucc_call_previous(iallgather, ucc_module,
sbuf, scount, sdtype, rbuf, rcount, rdtype, comm, request);
}
10 changes: 4 additions & 6 deletions ompi/mca/coll/ucc/coll_ucc_allgatherv.c
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,8 @@ int mca_coll_ucc_allgatherv(const void *sbuf, int scount,
return OMPI_SUCCESS;
fallback:
UCC_VERBOSE(3, "running fallback allgatherv");
return ucc_module->previous_allgatherv(sbuf, scount, sdtype,
rbuf, rcounts, rdisps, rdtype,
comm, ucc_module->previous_allgatherv_module);
return mca_coll_ucc_call_previous(allgatherv, ucc_module,
sbuf, scount, sdtype, rbuf, rcounts, rdisps, rdtype, comm);
}

int mca_coll_ucc_iallgatherv(const void *sbuf, int scount,
Expand Down Expand Up @@ -108,7 +107,6 @@ int mca_coll_ucc_iallgatherv(const void *sbuf, int scount,
if (coll_req) {
mca_coll_ucc_req_free((ompi_request_t **)&coll_req);
}
return ucc_module->previous_iallgatherv(sbuf, scount, sdtype,
rbuf, rcounts, rdisps, rdtype,
comm, request, ucc_module->previous_iallgatherv_module);
return mca_coll_ucc_call_previous(iallgatherv, ucc_module,
sbuf, scount, sdtype, rbuf, rcounts, rdisps, rdtype, comm, request);
}
8 changes: 4 additions & 4 deletions ompi/mca/coll/ucc/coll_ucc_allreduce.c
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ int mca_coll_ucc_allreduce(const void *sbuf, void *rbuf, int count,
return OMPI_SUCCESS;
fallback:
UCC_VERBOSE(3, "running fallback allreduce");
return ucc_module->previous_allreduce(sbuf, rbuf, count, dtype, op,
comm, ucc_module->previous_allreduce_module);
return mca_coll_ucc_call_previous(allreduce, ucc_module,
sbuf, rbuf, count, dtype, op, comm);
}

int mca_coll_ucc_iallreduce(const void *sbuf, void *rbuf, int count,
Expand All @@ -100,6 +100,6 @@ int mca_coll_ucc_iallreduce(const void *sbuf, void *rbuf, int count,
if (coll_req) {
mca_coll_ucc_req_free((ompi_request_t **)&coll_req);
}
return ucc_module->previous_iallreduce(sbuf, rbuf, count, dtype, op,
comm, request, ucc_module->previous_iallreduce_module);
return mca_coll_ucc_call_previous(iallreduce, ucc_module,
sbuf, rbuf, count, dtype, op, comm, request);
}
8 changes: 4 additions & 4 deletions ompi/mca/coll/ucc/coll_ucc_alltoall.c
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ int mca_coll_ucc_alltoall(const void *sbuf, int scount, struct ompi_datatype_t *
return OMPI_SUCCESS;
fallback:
UCC_VERBOSE(3, "running fallback alltoall");
return ucc_module->previous_alltoall(sbuf, scount, sdtype, rbuf, rcount, rdtype,
comm, ucc_module->previous_alltoall_module);
return mca_coll_ucc_call_previous(alltoall, ucc_module,
sbuf, scount, sdtype, rbuf, rcount, rdtype, comm);
}

int mca_coll_ucc_ialltoall(const void *sbuf, int scount, struct ompi_datatype_t *sdtype,
Expand All @@ -104,6 +104,6 @@ int mca_coll_ucc_ialltoall(const void *sbuf, int scount, struct ompi_datatype_t
if (coll_req) {
mca_coll_ucc_req_free((ompi_request_t **)&coll_req);
}
return ucc_module->previous_ialltoall(sbuf, scount, sdtype, rbuf, rcount, rdtype,
comm, request, ucc_module->previous_ialltoall_module);
return mca_coll_ucc_call_previous(ialltoall, ucc_module,
sbuf, scount, sdtype, rbuf, rcount, rdtype, comm, request);
}
10 changes: 4 additions & 6 deletions ompi/mca/coll/ucc/coll_ucc_alltoallv.c
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,8 @@ int mca_coll_ucc_alltoallv(const void *sbuf, const int *scounts,
return OMPI_SUCCESS;
fallback:
UCC_VERBOSE(3, "running fallback alltoallv");
return ucc_module->previous_alltoallv(sbuf, scounts, sdisps, sdtype,
rbuf, rcounts, rdisps, rdtype,
comm, ucc_module->previous_alltoallv_module);
return mca_coll_ucc_call_previous(alltoallv, ucc_module,
sbuf, scounts, sdisps, sdtype, rbuf, rcounts, rdisps, rdtype, comm);
}

int mca_coll_ucc_ialltoallv(const void *sbuf, const int *scounts,
Expand Down Expand Up @@ -109,7 +108,6 @@ int mca_coll_ucc_ialltoallv(const void *sbuf, const int *scounts,
if (coll_req) {
mca_coll_ucc_req_free((ompi_request_t **)&coll_req);
}
return ucc_module->previous_ialltoallv(sbuf, scounts, sdisps, sdtype,
rbuf, rcounts, rdisps, rdtype,
comm, request, ucc_module->previous_ialltoallv_module);
return mca_coll_ucc_call_previous(ialltoallv, ucc_module,
sbuf, scounts, sdisps, sdtype, rbuf, rcounts, rdisps, rdtype, comm, request);
}
5 changes: 2 additions & 3 deletions ompi/mca/coll/ucc/coll_ucc_barrier.c
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ int mca_coll_ucc_barrier(struct ompi_communicator_t *comm,
return OMPI_SUCCESS;
fallback:
UCC_VERBOSE(3, "running fallback barrier");
return ucc_module->previous_barrier(comm, ucc_module->previous_barrier_module);
return mca_coll_ucc_call_previous(barrier, ucc_module, comm);
}

int mca_coll_ucc_ibarrier(struct ompi_communicator_t *comm,
Expand All @@ -58,6 +58,5 @@ int mca_coll_ucc_ibarrier(struct ompi_communicator_t *comm,
if (coll_req) {
mca_coll_ucc_req_free((ompi_request_t **)&coll_req);
}
return ucc_module->previous_ibarrier(comm, request,
ucc_module->previous_ibarrier_module);
return mca_coll_ucc_call_previous(ibarrier, ucc_module, comm, request);
}
9 changes: 5 additions & 4 deletions ompi/mca/coll/ucc/coll_ucc_bcast.c
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,9 @@ int mca_coll_ucc_bcast(void *buf, int count, struct ompi_datatype_t *dtype,
return OMPI_SUCCESS;
fallback:
UCC_VERBOSE(3, "running fallback bcast");
return ucc_module->previous_bcast(buf, count, dtype, root,
comm, ucc_module->previous_bcast_module);
return mca_coll_ucc_call_previous(bcast, ucc_module,
buf, count, dtype, root, comm);

}

int mca_coll_ucc_ibcast(void *buf, int count, struct ompi_datatype_t *dtype,
Expand All @@ -76,6 +77,6 @@ int mca_coll_ucc_ibcast(void *buf, int count, struct ompi_datatype_t *dtype,
if (coll_req) {
mca_coll_ucc_req_free((ompi_request_t **)&coll_req);
}
return ucc_module->previous_ibcast(buf, count, dtype, root,
comm, request, ucc_module->previous_ibcast_module);
return mca_coll_ucc_call_previous(ibcast, ucc_module,
buf, count, dtype, root, comm, request);
}
10 changes: 4 additions & 6 deletions ompi/mca/coll/ucc/coll_ucc_gather.c
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,8 @@ int mca_coll_ucc_gather(const void *sbuf, int scount, struct ompi_datatype_t *sd
return OMPI_SUCCESS;
fallback:
UCC_VERBOSE(3, "running fallback gather");
return ucc_module->previous_gather(sbuf, scount, sdtype, rbuf, rcount,
rdtype, root, comm,
ucc_module->previous_gather_module);
return mca_coll_ucc_call_previous(gather, ucc_module,
sbuf, scount, sdtype, rbuf, rcount, rdtype, root, comm);
}

int mca_coll_ucc_igather(const void *sbuf, int scount, struct ompi_datatype_t *sdtype,
Expand All @@ -119,7 +118,6 @@ int mca_coll_ucc_igather(const void *sbuf, int scount, struct ompi_datatype_t *s
if (coll_req) {
mca_coll_ucc_req_free((ompi_request_t **)&coll_req);
}
return ucc_module->previous_igather(sbuf, scount, sdtype, rbuf, rcount,
rdtype, root, comm, request,
ucc_module->previous_igather_module);
return mca_coll_ucc_call_previous(igather, ucc_module,
sbuf, scount, sdtype, rbuf, rcount, rdtype, root, comm, request);
}
10 changes: 4 additions & 6 deletions ompi/mca/coll/ucc/coll_ucc_gatherv.c
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,8 @@ int mca_coll_ucc_gatherv(const void *sbuf, int scount, struct ompi_datatype_t *s
return OMPI_SUCCESS;
fallback:
UCC_VERBOSE(3, "running fallback gatherv");
return ucc_module->previous_gatherv(sbuf, scount, sdtype, rbuf, rcounts,
disps, rdtype, root, comm,
ucc_module->previous_gatherv_module);
return mca_coll_ucc_call_previous(gatherv, ucc_module,
sbuf, scount, sdtype, rbuf, rcounts, disps, rdtype, root, comm);
}

int mca_coll_ucc_igatherv(const void *sbuf, int scount, struct ompi_datatype_t *sdtype,
Expand All @@ -115,7 +114,6 @@ int mca_coll_ucc_igatherv(const void *sbuf, int scount, struct ompi_datatype_t *
if (coll_req) {
mca_coll_ucc_req_free((ompi_request_t **)&coll_req);
}
return ucc_module->previous_igatherv(sbuf, scount, sdtype, rbuf, rcounts,
disps, rdtype, root, comm, request,
ucc_module->previous_igatherv_module);
return mca_coll_ucc_call_previous(igatherv, ucc_module,
sbuf, scount, sdtype, rbuf, rcounts, disps, rdtype, root, comm, request);
}
16 changes: 6 additions & 10 deletions ompi/mca/coll/ucc/coll_ucc_module.c
Original file line number Diff line number Diff line change
Expand Up @@ -141,13 +141,12 @@ static void mca_coll_ucc_module_destruct(mca_coll_ucc_module_t *ucc_module)
#define SAVE_PREV_COLL_API(__api) do { \
ucc_module->previous_ ## __api = comm->c_coll->coll_ ## __api; \
ucc_module->previous_ ## __api ## _module = comm->c_coll->coll_ ## __api ## _module; \
if (!comm->c_coll->coll_ ## __api || !comm->c_coll->coll_ ## __api ## _module) { \
return OMPI_ERROR; \
if (comm->c_coll->coll_ ## __api && comm->c_coll->coll_ ## __api ## _module) { \
OBJ_RETAIN(ucc_module->previous_ ## __api ## _module); \
} \
OBJ_RETAIN(ucc_module->previous_ ## __api ## _module); \
} while(0)

static int mca_coll_ucc_save_coll_handlers(mca_coll_ucc_module_t *ucc_module)
static void mca_coll_ucc_save_coll_handlers(mca_coll_ucc_module_t *ucc_module)
{
ompi_communicator_t *comm = ucc_module->comm;
SAVE_PREV_COLL_API(allreduce);
Expand Down Expand Up @@ -178,7 +177,6 @@ static int mca_coll_ucc_save_coll_handlers(mca_coll_ucc_module_t *ucc_module)
SAVE_PREV_COLL_API(iscatterv);
SAVE_PREV_COLL_API(scatter);
SAVE_PREV_COLL_API(iscatter);
return OMPI_SUCCESS;
}

/*
Expand Down Expand Up @@ -470,11 +468,6 @@ static int mca_coll_ucc_module_enable(mca_coll_base_module_t *module,
(void*)comm, (long long unsigned)team_params.id,
ompi_comm_size(comm));

if (OMPI_SUCCESS != mca_coll_ucc_save_coll_handlers(ucc_module)){
UCC_ERROR("mca_coll_ucc_save_coll_handlers failed");
goto err;
}

if (UCC_OK != ucc_team_create_post(&cm->ucc_context, 1,
&team_params, &ucc_module->ucc_team)) {
UCC_ERROR("ucc_team_create_post failed");
Expand All @@ -495,6 +488,9 @@ static int mca_coll_ucc_module_enable(mca_coll_base_module_t *module,
UCC_ERROR("ucc ompi_attr_set_c failed");
goto err;
}

mca_coll_ucc_save_coll_handlers(ucc_module);

return OMPI_SUCCESS;

err:
Expand Down
8 changes: 4 additions & 4 deletions ompi/mca/coll/ucc/coll_ucc_reduce.c
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ int mca_coll_ucc_reduce(const void *sbuf, void* rbuf, int count,
return OMPI_SUCCESS;
fallback:
UCC_VERBOSE(3, "running fallback reduce");
return ucc_module->previous_reduce(sbuf, rbuf, count, dtype, op, root,
comm, ucc_module->previous_reduce_module);
return mca_coll_ucc_call_previous(reduce, ucc_module,
sbuf, rbuf, count, dtype, op, root, comm);
}

int mca_coll_ucc_ireduce(const void *sbuf, void* rbuf, int count,
Expand All @@ -103,6 +103,6 @@ int mca_coll_ucc_ireduce(const void *sbuf, void* rbuf, int count,
if (coll_req) {
mca_coll_ucc_req_free((ompi_request_t **)&coll_req);
}
return ucc_module->previous_ireduce(sbuf, rbuf, count, dtype, op, root,
comm, request, ucc_module->previous_ireduce_module);
return mca_coll_ucc_call_previous(ireduce, ucc_module,
sbuf, rbuf, count, dtype, op, root, comm, request);
}
10 changes: 4 additions & 6 deletions ompi/mca/coll/ucc/coll_ucc_reduce_scatter.c
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,8 @@ int mca_coll_ucc_reduce_scatter(const void *sbuf, void *rbuf, const int *rcounts
return OMPI_SUCCESS;
fallback:
UCC_VERBOSE(3, "running fallback reduce_scatter");
return ucc_module->previous_reduce_scatter(sbuf, rbuf, rcounts, dtype, op,
comm,
ucc_module->previous_reduce_scatter_module);
return mca_coll_ucc_call_previous(reduce_scatter, ucc_module,
sbuf, rbuf, rcounts, dtype, op, comm);
}

int mca_coll_ucc_ireduce_scatter(const void *sbuf, void *rbuf, const int *rcounts,
Expand All @@ -115,7 +114,6 @@ int mca_coll_ucc_ireduce_scatter(const void *sbuf, void *rbuf, const int *rcount
if (coll_req) {
mca_coll_ucc_req_free((ompi_request_t **)&coll_req);
}
return ucc_module->previous_ireduce_scatter(sbuf, rbuf, rcounts, dtype, op,
comm, request,
ucc_module->previous_ireduce_scatter_module);
return mca_coll_ucc_call_previous(ireduce_scatter, ucc_module,
sbuf, rbuf, rcounts, dtype, op, comm, request);
}
10 changes: 4 additions & 6 deletions ompi/mca/coll/ucc/coll_ucc_reduce_scatter_block.c
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,8 @@ int mca_coll_ucc_reduce_scatter_block(const void *sbuf, void *rbuf, int rcount,
return OMPI_SUCCESS;
fallback:
UCC_VERBOSE(3, "running fallback reduce_scatter_block");
return ucc_module->previous_reduce_scatter_block(sbuf, rbuf, rcount, dtype,
op, comm,
ucc_module->previous_reduce_scatter_block_module);
return mca_coll_ucc_call_previous(reduce_scatter_block, ucc_module,
sbuf, rbuf, rcount, dtype, op, comm);
}

int mca_coll_ucc_ireduce_scatter_block(const void *sbuf, void *rbuf, int rcount,
Expand All @@ -111,7 +110,6 @@ int mca_coll_ucc_ireduce_scatter_block(const void *sbuf, void *rbuf, int rcount,
if (coll_req) {
mca_coll_ucc_req_free((ompi_request_t **)&coll_req);
}
return ucc_module->previous_ireduce_scatter_block(sbuf, rbuf, rcount, dtype,
op, comm, request,
ucc_module->previous_ireduce_scatter_block_module);
return mca_coll_ucc_call_previous(ireduce_scatter_block, ucc_module,
sbuf, rbuf, rcount, dtype, op, comm, request);
}
11 changes: 4 additions & 7 deletions ompi/mca/coll/ucc/coll_ucc_scatter.c
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,8 @@ int mca_coll_ucc_scatter(const void *sbuf, int scount,
return OMPI_SUCCESS;
fallback:
UCC_VERBOSE(3, "running fallback scatter");
return ucc_module->previous_scatter(sbuf, scount, sdtype, rbuf, rcount,
rdtype, root, comm,
ucc_module->previous_scatter_module);

return mca_coll_ucc_call_previous(scatter, ucc_module,
sbuf, scount, sdtype, rbuf, rcount, rdtype, root, comm);
}

int mca_coll_ucc_iscatter(const void *sbuf, int scount,
Expand All @@ -117,7 +115,6 @@ int mca_coll_ucc_iscatter(const void *sbuf, int scount,
if (coll_req) {
mca_coll_ucc_req_free((ompi_request_t **)&coll_req);
}
return ucc_module->previous_iscatter(sbuf, scount, sdtype, rbuf, rcount,
rdtype, root, comm, request,
ucc_module->previous_iscatter_module);
return mca_coll_ucc_call_previous(iscatter, ucc_module,
sbuf, scount, sdtype, rbuf, rcount, rdtype, root, comm, request);
}
10 changes: 4 additions & 6 deletions ompi/mca/coll/ucc/coll_ucc_scatterv.c
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,8 @@ int mca_coll_ucc_scatterv(const void *sbuf, const int *scounts,
return OMPI_SUCCESS;
fallback:
UCC_VERBOSE(3, "running fallback scatterv");
return ucc_module->previous_scatterv(sbuf, scounts, disps, sdtype, rbuf,
rcount, rdtype, root, comm,
ucc_module->previous_scatterv_module);
return mca_coll_ucc_call_previous(scatterv, ucc_module,
sbuf, scounts, disps, sdtype, rbuf, rcount, rdtype, root, comm);
}

int mca_coll_ucc_iscatterv(const void *sbuf, const int *scounts,
Expand Down Expand Up @@ -120,7 +119,6 @@ int mca_coll_ucc_iscatterv(const void *sbuf, const int *scounts,
if (coll_req) {
mca_coll_ucc_req_free((ompi_request_t **)&coll_req);
}
return ucc_module->previous_iscatterv(sbuf, scounts, disps, sdtype, rbuf,
rcount, rdtype, root, comm, request,
ucc_module->previous_iscatterv_module);
return mca_coll_ucc_call_previous(iscatterv, ucc_module,
sbuf, scounts, disps, sdtype, rbuf, rcount, rdtype, root, comm, request);
}