Skip to content

Commit 856136e

Browse files
authored
Merge pull request open-mpi#12347 from wenduwan/v5.0.x_backport_alltoall_fixes
[v5.0.x] Backport alltoallv fixes
2 parents 2f2ddaf + 4d135f8 commit 856136e

File tree

2 files changed

+85
-43
lines changed

2 files changed

+85
-43
lines changed

ompi/mca/coll/base/coll_base_alltoallv.c

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,9 @@ ompi_coll_base_alltoallv_intra_pairwise(const void *sbuf, const int *scounts, co
199199
mca_coll_base_module_t *module)
200200
{
201201
int line = -1, err = 0, rank, size, step = 0, sendto, recvfrom;
202+
size_t sdtype_size, rdtype_size;
202203
void *psnd, *prcv;
204+
ompi_request_t *req;
203205
ptrdiff_t sext, rext;
204206

205207
if (MPI_IN_PLACE == sbuf) {
@@ -213,11 +215,15 @@ ompi_coll_base_alltoallv_intra_pairwise(const void *sbuf, const int *scounts, co
213215
OPAL_OUTPUT((ompi_coll_base_framework.framework_output,
214216
"coll:base:alltoallv_intra_pairwise rank %d", rank));
215217

218+
ompi_datatype_type_size(sdtype, &sdtype_size);
219+
ompi_datatype_type_size(rdtype, &rdtype_size);
220+
216221
ompi_datatype_type_extent(sdtype, &sext);
217222
ompi_datatype_type_extent(rdtype, &rext);
218223

219224
/* Perform pairwise exchange starting from 1 since local exchange is done */
220225
for (step = 0; step < size; step++) {
226+
req = MPI_REQUEST_NULL;
221227

222228
/* Determine sender and receiver for this step. */
223229
sendto = (rank + step) % size;
@@ -228,12 +234,31 @@ ompi_coll_base_alltoallv_intra_pairwise(const void *sbuf, const int *scounts, co
228234
prcv = (char*)rbuf + (ptrdiff_t)rdisps[recvfrom] * rext;
229235

230236
/* send and receive */
231-
err = ompi_coll_base_sendrecv( psnd, scounts[sendto], sdtype, sendto,
232-
MCA_COLL_BASE_TAG_ALLTOALLV,
233-
prcv, rcounts[recvfrom], rdtype, recvfrom,
234-
MCA_COLL_BASE_TAG_ALLTOALLV,
235-
comm, MPI_STATUS_IGNORE, rank);
236-
if (MPI_SUCCESS != err) { line = __LINE__; goto err_hndl; }
237+
if (0 < rcounts[recvfrom] && 0 < rdtype_size) {
238+
err = MCA_PML_CALL(irecv(prcv, rcounts[recvfrom], rdtype, recvfrom,
239+
MCA_COLL_BASE_TAG_ALLTOALLV, comm, &req));
240+
if (MPI_SUCCESS != err) {
241+
line = __LINE__;
242+
goto err_hndl;
243+
}
244+
}
245+
246+
if (0 < scounts[sendto] && 0 < sdtype_size) {
247+
err = MCA_PML_CALL(send(psnd, scounts[sendto], sdtype, sendto,
248+
MCA_COLL_BASE_TAG_ALLTOALLV, MCA_PML_BASE_SEND_STANDARD, comm));
249+
if (MPI_SUCCESS != err) {
250+
line = __LINE__;
251+
goto err_hndl;
252+
}
253+
}
254+
255+
if (MPI_REQUEST_NULL != req) {
256+
err = ompi_request_wait(&req, MPI_STATUS_IGNORE);
257+
if (MPI_SUCCESS != err) {
258+
line = __LINE__;
259+
goto err_hndl;
260+
}
261+
}
237262
}
238263

239264
return MPI_SUCCESS;
@@ -263,6 +288,7 @@ ompi_coll_base_alltoallv_intra_basic_linear(const void *sbuf, const int *scounts
263288
mca_coll_base_module_t *module)
264289
{
265290
int i, size, rank, err, nreqs;
291+
size_t sdtype_size = 0, rdtype_size = 0;
266292
char *psnd, *prcv;
267293
ptrdiff_t sext, rext;
268294
ompi_request_t **preq, **reqs;
@@ -280,13 +306,16 @@ ompi_coll_base_alltoallv_intra_basic_linear(const void *sbuf, const int *scounts
280306
OPAL_OUTPUT((ompi_coll_base_framework.framework_output,
281307
"coll:base:alltoallv_intra_basic_linear rank %d", rank));
282308

309+
ompi_datatype_type_size(rdtype, &rdtype_size);
310+
ompi_datatype_type_size(sdtype, &sdtype_size);
311+
283312
ompi_datatype_type_extent(sdtype, &sext);
284313
ompi_datatype_type_extent(rdtype, &rext);
285314

286315
/* Simple optimization - handle send to self first */
287316
psnd = ((char *) sbuf) + (ptrdiff_t)sdisps[rank] * sext;
288317
prcv = ((char *) rbuf) + (ptrdiff_t)rdisps[rank] * rext;
289-
if (0 != scounts[rank]) {
318+
if (0 < scounts[rank] && 0 < sdtype_size) {
290319
err = ompi_datatype_sndrcv(psnd, scounts[rank], sdtype,
291320
prcv, rcounts[rank], rdtype);
292321
if (MPI_SUCCESS != err) {
@@ -310,7 +339,7 @@ ompi_coll_base_alltoallv_intra_basic_linear(const void *sbuf, const int *scounts
310339
continue;
311340
}
312341

313-
if (rcounts[i] > 0) {
342+
if (0 < rcounts[i] && 0 < rdtype_size) {
314343
++nreqs;
315344
prcv = ((char *) rbuf) + (ptrdiff_t)rdisps[i] * rext;
316345
err = MCA_PML_CALL(irecv_init(prcv, rcounts[i], rdtype,
@@ -326,7 +355,7 @@ ompi_coll_base_alltoallv_intra_basic_linear(const void *sbuf, const int *scounts
326355
continue;
327356
}
328357

329-
if (scounts[i] > 0) {
358+
if (0 < scounts[i] && 0 < sdtype_size) {
330359
++nreqs;
331360
psnd = ((char *) sbuf) + (ptrdiff_t)sdisps[i] * sext;
332361
err = MCA_PML_CALL(isend_init(psnd, scounts[i], sdtype,

ompi/mca/coll/libnbc/nbc_ialltoallv.c

Lines changed: 47 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,19 @@
2323

2424
static inline int a2av_sched_linear(int rank, int p, NBC_Schedule *schedule,
2525
const void *sendbuf, const int *sendcounts,
26-
const int *sdispls, MPI_Aint sndext, MPI_Datatype sendtype,
26+
const int *sdispls, MPI_Aint sndext, MPI_Datatype sendtype, const size_t sdtype_size,
2727
void *recvbuf, const int *recvcounts,
28-
const int *rdispls, MPI_Aint rcvext, MPI_Datatype recvtype);
28+
const int *rdispls, MPI_Aint rcvext, MPI_Datatype recvtype, const size_t rdtype_size);
2929

3030
static inline int a2av_sched_pairwise(int rank, int p, NBC_Schedule *schedule,
3131
const void *sendbuf, const int *sendcounts, const int *sdispls,
32-
MPI_Aint sndext, MPI_Datatype sendtype,
32+
MPI_Aint sndext, MPI_Datatype sendtype, const size_t sdtype_size,
3333
void *recvbuf, const int *recvcounts, const int *rdispls,
34-
MPI_Aint rcvext, MPI_Datatype recvtype);
34+
MPI_Aint rcvext, MPI_Datatype recvtype, const size_t rdtype_size);
3535

3636
static inline int a2av_sched_inplace(int rank, int p, NBC_Schedule *schedule,
3737
void *buf, const int *counts, const int *displs,
38-
MPI_Aint ext, MPI_Datatype type, ptrdiff_t gap);
38+
MPI_Aint ext, MPI_Datatype type, const size_t dtype_size, ptrdiff_t gap);
3939

4040
/* an alltoallv schedule can not be cached easily because the contents
4141
* of the recvcounts array may change, so a comparison of the address
@@ -48,6 +48,7 @@ static int nbc_alltoallv_init(const void* sendbuf, const int *sendcounts, const
4848
mca_coll_base_module_t *module, bool persistent)
4949
{
5050
int rank, p, res;
51+
size_t sdtype_size, rdtype_size;
5152
MPI_Aint sndext, rcvext;
5253
NBC_Schedule *schedule;
5354
char *rbuf, *sbuf, inplace;
@@ -60,6 +61,7 @@ static int nbc_alltoallv_init(const void* sendbuf, const int *sendcounts, const
6061
rank = ompi_comm_rank (comm);
6162
p = ompi_comm_size (comm);
6263

64+
ompi_datatype_type_size(recvtype, &rdtype_size);
6365
res = ompi_datatype_type_extent (recvtype, &rcvext);
6466
if (MPI_SUCCESS != res) {
6567
NBC_Error("MPI Error in ompi_datatype_type_extent() (%i)", res);
@@ -92,7 +94,9 @@ static int nbc_alltoallv_init(const void* sendbuf, const int *sendcounts, const
9294
sendcounts = recvcounts;
9395
sdispls = rdispls;
9496
sndext = rcvext;
97+
sdtype_size = rdtype_size;
9598
} else {
99+
ompi_datatype_type_size(sendtype, &sdtype_size);
96100
res = ompi_datatype_type_extent (sendtype, &sndext);
97101
if (MPI_SUCCESS != res) {
98102
NBC_Error("MPI Error in ompi_datatype_type_extent() (%i)", res);
@@ -106,8 +110,7 @@ static int nbc_alltoallv_init(const void* sendbuf, const int *sendcounts, const
106110
return OMPI_ERR_OUT_OF_RESOURCE;
107111
}
108112

109-
110-
if (!inplace && sendcounts[rank] != 0) {
113+
if (!inplace && 0 < sendcounts[rank] && 0 < sdtype_size) {
111114
rbuf = (char *) recvbuf + rdispls[rank] * rcvext;
112115
sbuf = (char *) sendbuf + sdispls[rank] * sndext;
113116
res = NBC_Sched_copy (sbuf, false, sendcounts[rank], sendtype,
@@ -119,12 +122,12 @@ static int nbc_alltoallv_init(const void* sendbuf, const int *sendcounts, const
119122
}
120123

121124
if (inplace) {
122-
res = a2av_sched_inplace(rank, p, schedule, recvbuf, recvcounts,
123-
rdispls, rcvext, recvtype, gap);
125+
res = a2av_sched_inplace(rank, p, schedule, recvbuf, recvcounts, rdispls, rcvext, recvtype,
126+
rdtype_size, gap);
124127
} else {
125128
res = a2av_sched_linear(rank, p, schedule,
126-
sendbuf, sendcounts, sdispls, sndext, sendtype,
127-
recvbuf, recvcounts, rdispls, rcvext, recvtype);
129+
sendbuf, sendcounts, sdispls, sndext, sendtype, sdtype_size,
130+
recvbuf, recvcounts, rdispls, rcvext, recvtype, rdtype_size);
128131
}
129132
if (OPAL_UNLIKELY(OMPI_SUCCESS != res)) {
130133
OBJ_RELEASE(schedule);
@@ -177,10 +180,13 @@ static int nbc_alltoallv_inter_init (const void* sendbuf, const int *sendcounts,
177180
mca_coll_base_module_t *module, bool persistent)
178181
{
179182
int res, rsize;
183+
size_t sdtype_size, rdtype_size;
180184
MPI_Aint sndext, rcvext;
181185
NBC_Schedule *schedule;
182186
ompi_coll_libnbc_module_t *libnbc_module = (ompi_coll_libnbc_module_t*) module;
183187

188+
ompi_datatype_type_size(sendtype, &sdtype_size);
189+
ompi_datatype_type_size(recvtype, &rdtype_size);
184190

185191
res = ompi_datatype_type_extent(sendtype, &sndext);
186192
if (MPI_SUCCESS != res) {
@@ -203,7 +209,7 @@ static int nbc_alltoallv_inter_init (const void* sendbuf, const int *sendcounts,
203209

204210
for (int i = 0; i < rsize; i++) {
205211
/* post all sends */
206-
if (sendcounts[i] != 0) {
212+
if (0 < sendcounts[i] && 0 < sdtype_size) {
207213
char *sbuf = (char *) sendbuf + sdispls[i] * sndext;
208214
res = NBC_Sched_send (sbuf, false, sendcounts[i], sendtype, i, schedule, false);
209215
if (OPAL_UNLIKELY(OMPI_SUCCESS != res)) {
@@ -212,7 +218,7 @@ static int nbc_alltoallv_inter_init (const void* sendbuf, const int *sendcounts,
212218
}
213219
}
214220
/* post all receives */
215-
if (recvcounts[i] != 0) {
221+
if (0 < recvcounts[i] && 0 < rdtype_size) {
216222
char *rbuf = (char *) recvbuf + rdispls[i] * rcvext;
217223
res = NBC_Sched_recv (rbuf, false, recvcounts[i], recvtype, i, schedule, false);
218224
if (OPAL_UNLIKELY(OMPI_SUCCESS != res)) {
@@ -261,9 +267,9 @@ int ompi_coll_libnbc_ialltoallv_inter (const void* sendbuf, const int *sendcount
261267
__opal_attribute_unused__
262268
static inline int a2av_sched_linear(int rank, int p, NBC_Schedule *schedule,
263269
const void *sendbuf, const int *sendcounts, const int *sdispls,
264-
MPI_Aint sndext, MPI_Datatype sendtype,
270+
MPI_Aint sndext, MPI_Datatype sendtype, const size_t sdtype_size,
265271
void *recvbuf, const int *recvcounts, const int *rdispls,
266-
MPI_Aint rcvext, MPI_Datatype recvtype) {
272+
MPI_Aint rcvext, MPI_Datatype recvtype, const size_t rdtype_size) {
267273
int res;
268274

269275
for (int i = 0 ; i < p ; ++i) {
@@ -272,7 +278,7 @@ static inline int a2av_sched_linear(int rank, int p, NBC_Schedule *schedule,
272278
}
273279

274280
/* post send */
275-
if (sendcounts[i] != 0) {
281+
if (0 < sendcounts[i] && 0 < sdtype_size) {
276282
char *sbuf = ((char *) sendbuf) + (sdispls[i] * sndext);
277283
res = NBC_Sched_send(sbuf, false, sendcounts[i], sendtype, i, schedule, false);
278284
if (OPAL_UNLIKELY(OMPI_SUCCESS != res)) {
@@ -281,7 +287,7 @@ static inline int a2av_sched_linear(int rank, int p, NBC_Schedule *schedule,
281287
}
282288

283289
/* post receive */
284-
if (recvcounts[i] != 0) {
290+
if (0 < recvcounts[i] && 0 < rdtype_size) {
285291
char *rbuf = ((char *) recvbuf) + (rdispls[i] * rcvext);
286292
res = NBC_Sched_recv(rbuf, false, recvcounts[i], recvtype, i, schedule, false);
287293
if (OPAL_UNLIKELY(OMPI_SUCCESS != res)) {
@@ -296,17 +302,17 @@ static inline int a2av_sched_linear(int rank, int p, NBC_Schedule *schedule,
296302
__opal_attribute_unused__
297303
static inline int a2av_sched_pairwise(int rank, int p, NBC_Schedule *schedule,
298304
const void *sendbuf, const int *sendcounts, const int *sdispls,
299-
MPI_Aint sndext, MPI_Datatype sendtype,
305+
MPI_Aint sndext, MPI_Datatype sendtype, const size_t sdtype_size,
300306
void *recvbuf, const int *recvcounts, const int *rdispls,
301-
MPI_Aint rcvext, MPI_Datatype recvtype) {
307+
MPI_Aint rcvext, MPI_Datatype recvtype, const size_t rdtype_size) {
302308
int res;
303309

304310
for (int i = 1 ; i < p ; ++i) {
305311
int sndpeer = (rank + i) % p;
306312
int rcvpeer = (rank + p - i) %p;
307313

308314
/* post send */
309-
if (sendcounts[sndpeer] != 0) {
315+
if (0 < sendcounts[sndpeer] && 0 < sdtype_size) {
310316
char *sbuf = ((char *) sendbuf) + (sdispls[sndpeer] * sndext);
311317
res = NBC_Sched_send(sbuf, false, sendcounts[sndpeer], sendtype, sndpeer, schedule, false);
312318
if (OPAL_UNLIKELY(OMPI_SUCCESS != res)) {
@@ -315,7 +321,7 @@ static inline int a2av_sched_pairwise(int rank, int p, NBC_Schedule *schedule,
315321
}
316322

317323
/* post receive */
318-
if (recvcounts[rcvpeer] != 0) {
324+
if (0 < recvcounts[rcvpeer] && 0 < rdtype_size) {
319325
char *rbuf = ((char *) recvbuf) + (rdispls[rcvpeer] * rcvext);
320326
res = NBC_Sched_recv(rbuf, false, recvcounts[rcvpeer], recvtype, rcvpeer, schedule, true);
321327
if (OPAL_UNLIKELY(OMPI_SUCCESS != res)) {
@@ -329,7 +335,7 @@ static inline int a2av_sched_pairwise(int rank, int p, NBC_Schedule *schedule,
329335

330336
static inline int a2av_sched_inplace(int rank, int p, NBC_Schedule *schedule,
331337
void *buf, const int *counts, const int *displs,
332-
MPI_Aint ext, MPI_Datatype type, ptrdiff_t gap) {
338+
MPI_Aint ext, MPI_Datatype type, const size_t dtype_size, ptrdiff_t gap) {
333339
int res;
334340

335341
for (int i = 1; i < (p+1)/2; i++) {
@@ -338,34 +344,39 @@ static inline int a2av_sched_inplace(int rank, int p, NBC_Schedule *schedule,
338344
char *sbuf = (char *) buf + displs[speer] * ext;
339345
char *rbuf = (char *) buf + displs[rpeer] * ext;
340346

341-
if (0 != counts[rpeer]) {
347+
if (0 == dtype_size) {
348+
/* Nothing to exchange */
349+
return OMPI_SUCCESS;
350+
}
351+
352+
if (0 < counts[rpeer]) {
342353
res = NBC_Sched_copy (rbuf, false, counts[rpeer], type,
343354
(void *)(-gap), true, counts[rpeer], type,
344355
schedule, true);
345356
if (OPAL_UNLIKELY(OMPI_SUCCESS != res)) {
346357
return res;
347358
}
348359
}
349-
if (0 != counts[speer]) {
360+
if (0 < counts[speer]) {
350361
res = NBC_Sched_send (sbuf, false , counts[speer], type, speer, schedule, false);
351362
if (OPAL_UNLIKELY(OMPI_SUCCESS != res)) {
352363
return res;
353364
}
354365
}
355-
if (0 != counts[rpeer]) {
366+
if (0 < counts[rpeer]) {
356367
res = NBC_Sched_recv (rbuf, false , counts[rpeer], type, rpeer, schedule, true);
357368
if (OPAL_UNLIKELY(OMPI_SUCCESS != res)) {
358369
return res;
359370
}
360371
}
361372

362-
if (0 != counts[rpeer]) {
373+
if (0 < counts[rpeer]) {
363374
res = NBC_Sched_send ((void *)(-gap), true, counts[rpeer], type, rpeer, schedule, false);
364375
if (OPAL_UNLIKELY(OMPI_SUCCESS != res)) {
365376
return res;
366377
}
367378
}
368-
if (0 != counts[speer]) {
379+
if (0 < counts[speer]) {
369380
res = NBC_Sched_recv (sbuf, false, counts[speer], type, speer, schedule, true);
370381
if (OPAL_UNLIKELY(OMPI_SUCCESS != res)) {
371382
return res;
@@ -374,15 +385,17 @@ static inline int a2av_sched_inplace(int rank, int p, NBC_Schedule *schedule,
374385
}
375386
if (0 == (p%2)) {
376387
int peer = (rank + p/2) % p;
377-
378388
char *tbuf = (char *) buf + displs[peer] * ext;
379-
res = NBC_Sched_copy (tbuf, false, counts[peer], type,
380-
(void *)(-gap), true, counts[peer], type,
381-
schedule, true);
382-
if (OPAL_UNLIKELY(OMPI_SUCCESS != res)) {
383-
return res;
389+
390+
if (0 < counts[peer]) {
391+
res = NBC_Sched_copy(tbuf, false, counts[peer], type, (void *) (-gap), true, counts[peer],
392+
type, schedule, true);
393+
if (OPAL_UNLIKELY(OMPI_SUCCESS != res)) {
394+
return res;
395+
}
384396
}
385-
if (0 != counts[peer]) {
397+
398+
if (0 < counts[peer]) {
386399
res = NBC_Sched_send ((void *)(-gap), true , counts[peer], type, peer, schedule, false);
387400
if (OPAL_UNLIKELY(OMPI_SUCCESS != res)) {
388401
return res;

0 commit comments

Comments
 (0)