Skip to content

Commit 82299a9

Browse files
committed
coll: reduce_scatter_block: add recursive halving algorithm
Signed-off-by: Mikhail Kurnosov <[email protected]>
1 parent 93930a2 commit 82299a9

File tree

3 files changed

+215
-3
lines changed

3 files changed

+215
-3
lines changed

ompi/mca/coll/base/coll_base_functions.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,7 @@ int ompi_coll_base_reduce_scatter_intra_ring(REDUCESCATTER_ARGS);
252252
/* Reduce_scatter_block */
253253
int ompi_coll_base_reduce_scatter_block_basic(REDUCESCATTERBLOCK_ARGS);
254254
int ompi_coll_base_reduce_scatter_block_intra_recursivedoubling(REDUCESCATTERBLOCK_ARGS);
255+
int ompi_coll_base_reduce_scatter_block_intra_recursivehalving(REDUCESCATTERBLOCK_ARGS);
255256

256257
/* Scan */
257258
int ompi_coll_base_scan_intra_recursivedoubling(SCAN_ARGS);

ompi/mca/coll/base/coll_base_reduce_scatter_block.c

Lines changed: 211 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,15 @@
3232
#include "ompi/datatype/ompi_datatype.h"
3333
#include "ompi/communicator/communicator.h"
3434
#include "ompi/mca/coll/coll.h"
35-
#include "ompi/mca/coll/base/coll_tags.h"
36-
#include "ompi/mca/coll/base/coll_base_functions.h"
35+
#include "ompi/mca/coll/basic/coll_basic.h"
3736
#include "ompi/mca/pml/pml.h"
3837
#include "ompi/op/op.h"
39-
#include "ompi/mca/coll/base/coll_base_functions.h"
38+
#include "coll_tags.h"
39+
#include "coll_base_functions.h"
4040
#include "coll_base_topo.h"
4141
#include "coll_base_util.h"
4242

43+
4344
/*
4445
* ompi_reduce_scatter_block_basic
4546
*
@@ -303,3 +304,210 @@ ompi_coll_base_reduce_scatter_block_intra_recursivedoubling(
303304
free(tmprecv_raw);
304305
return err;
305306
}
307+
308+
/*
309+
* ompi_range_sum: Returns sum of elems in intersection of [a, b] and [0, r]
310+
* index: 0 1 2 3 4 ... r r+1 r+2 ... nproc_pof2
311+
* value: 2 2 2 2 2 ... 2 1 1 ... 1
312+
*/
313+
static int ompi_range_sum(int a, int b, int r)
314+
{
315+
if (r < a)
316+
return b - a + 1;
317+
else if (r > b)
318+
return 2 * (b - a + 1);
319+
return 2 * (r - a + 1) + b - r;
320+
}
321+
322+
/*
323+
* ompi_coll_base_reduce_scatter_block_intra_recursivehalving
324+
*
325+
* Function: Recursive halving algorithm for reduce_scatter_block
326+
* Accepts: Same as MPI_Reduce_scatter_block
327+
* Returns: MPI_SUCCESS or error code
328+
*
329+
* Description: Implements recursive halving algorithm for MPI_Reduce_scatter_block.
330+
* The algorithm can be used by commutative operations only.
331+
*
332+
* Limitations: commutative operations only
333+
* Memory requirements (per process): 2 * rcount * comm_size * typesize
334+
*/
335+
int
336+
ompi_coll_base_reduce_scatter_block_intra_recursivehalving(
337+
const void *sbuf, void *rbuf, int rcount, struct ompi_datatype_t *dtype,
338+
struct ompi_op_t *op, struct ompi_communicator_t *comm,
339+
mca_coll_base_module_t *module)
340+
{
341+
char *tmprecv_raw = NULL, *tmpbuf_raw = NULL, *tmprecv, *tmpbuf;
342+
ptrdiff_t span, gap, totalcount, extent;
343+
int err = MPI_SUCCESS;
344+
int comm_size = ompi_comm_size(comm);
345+
int rank = ompi_comm_rank(comm);
346+
347+
OPAL_OUTPUT((ompi_coll_base_framework.framework_output,
348+
"coll:base:reduce_scatter_block_intra_recursivehalving: rank %d/%d",
349+
rank, comm_size));
350+
if (rcount == 0 || comm_size < 2)
351+
return MPI_SUCCESS;
352+
353+
if (!ompi_op_is_commute(op)) {
354+
OPAL_OUTPUT((ompi_coll_base_framework.framework_output,
355+
"coll:base:reduce_scatter_block_intra_recursivehalving: rank %d/%d "
356+
"switching to basic reduce_scatter_block", rank, comm_size));
357+
return ompi_coll_base_reduce_scatter_block_basic(sbuf, rbuf, rcount, dtype,
358+
op, comm, module);
359+
}
360+
totalcount = comm_size * rcount;
361+
ompi_datatype_type_extent(dtype, &extent);
362+
span = opal_datatype_span(&dtype->super, totalcount, &gap);
363+
tmpbuf_raw = malloc(span);
364+
tmprecv_raw = malloc(span);
365+
if (NULL == tmpbuf_raw || NULL == tmprecv_raw) {
366+
err = OMPI_ERR_OUT_OF_RESOURCE;
367+
goto cleanup_and_return;
368+
}
369+
tmpbuf = tmpbuf_raw - gap;
370+
tmprecv = tmprecv_raw - gap;
371+
372+
if (sbuf != MPI_IN_PLACE) {
373+
err = ompi_datatype_copy_content_same_ddt(dtype, totalcount, tmpbuf, (char *)sbuf);
374+
if (MPI_SUCCESS != err) { goto cleanup_and_return; }
375+
} else {
376+
err = ompi_datatype_copy_content_same_ddt(dtype, totalcount, tmpbuf, rbuf);
377+
if (MPI_SUCCESS != err) { goto cleanup_and_return; }
378+
}
379+
380+
/*
381+
* Step 1. Reduce the number of processes to the nearest lower power of two
382+
* p' = 2^{\floor{\log_2 p}} by removing r = p - p' processes.
383+
* In the first 2r processes (ranks 0 to 2r - 1), all the even ranks send
384+
* the input vector to their neighbor (rank + 1) and all the odd ranks recv
385+
* the input vector and perform local reduction.
386+
* The odd ranks (0 to 2r - 1) contain the reduction with the input
387+
* vector on their neighbors (the even ranks). The first r odd
388+
* processes and the p - 2r last processes are renumbered from
389+
* 0 to 2^{\floor{\log_2 p}} - 1. Even ranks do not participate in the
390+
* rest of the algorithm.
391+
*/
392+
393+
/* Find nearest power-of-two less than or equal to comm_size */
394+
int nprocs_pof2 = opal_next_poweroftwo(comm_size);
395+
nprocs_pof2 >>= 1;
396+
int nprocs_rem = comm_size - nprocs_pof2;
397+
398+
int vrank = -1;
399+
if (rank < 2 * nprocs_rem) {
400+
if ((rank % 2) == 0) {
401+
/* Even process */
402+
err = MCA_PML_CALL(send(tmpbuf, totalcount, dtype, rank + 1,
403+
MCA_COLL_BASE_TAG_REDUCE_SCATTER_BLOCK,
404+
MCA_PML_BASE_SEND_STANDARD, comm));
405+
if (OMPI_SUCCESS != err) { goto cleanup_and_return; }
406+
/* This process does not pariticipate in the rest of the algorithm */
407+
vrank = -1;
408+
} else {
409+
/* Odd process */
410+
err = MCA_PML_CALL(recv(tmprecv, totalcount, dtype, rank - 1,
411+
MCA_COLL_BASE_TAG_REDUCE_SCATTER_BLOCK,
412+
comm, MPI_STATUS_IGNORE));
413+
if (OMPI_SUCCESS != err) { goto cleanup_and_return; }
414+
ompi_op_reduce(op, tmprecv, tmpbuf, totalcount, dtype);
415+
/* Adjust rank to be the bottom "remain" ranks */
416+
vrank = rank / 2;
417+
}
418+
} else {
419+
/* Adjust rank to show that the bottom "even remain" ranks dropped out */
420+
vrank = rank - nprocs_rem;
421+
}
422+
423+
if (vrank != -1) {
424+
/*
425+
* Step 2. Recursive vector halving. We have p' = 2^{\floor{\log_2 p}}
426+
* power-of-two number of processes with new ranks (vrank) and partial
427+
* result in tmpbuf.
428+
* All processes then compute the reduction between the local
429+
* buffer and the received buffer. In the next \log_2(p') - 1 steps, the
430+
* buffers are recursively halved. At the end, each of the p' processes
431+
* has 1 / p' of the total reduction result.
432+
*/
433+
int send_index = 0, recv_index = 0, last_index = nprocs_pof2;
434+
for (int mask = nprocs_pof2 >> 1; mask > 0; mask >>= 1) {
435+
int vpeer = vrank ^ mask;
436+
int peer = (vpeer < nprocs_rem) ? vpeer * 2 + 1 : vpeer + nprocs_rem;
437+
438+
/*
439+
* Calculate the recv_count and send_count because the
440+
* even-numbered processes who no longer participate will
441+
* have their result calculated by the process to their
442+
* right (rank + 1).
443+
*/
444+
int send_count = 0, recv_count = 0;
445+
if (vrank < vpeer) {
446+
/* Send the right half of the buffer, recv the left half */
447+
send_index = recv_index + mask;
448+
send_count = rcount * ompi_range_sum(send_index, last_index - 1, nprocs_rem - 1);
449+
recv_count = rcount * ompi_range_sum(recv_index, send_index - 1, nprocs_rem - 1);
450+
} else {
451+
/* Send the left half of the buffer, recv the right half */
452+
recv_index = send_index + mask;
453+
send_count = rcount * ompi_range_sum(send_index, recv_index - 1, nprocs_rem - 1);
454+
recv_count = rcount * ompi_range_sum(recv_index, last_index - 1, nprocs_rem - 1);
455+
}
456+
ptrdiff_t rdispl = rcount * ((recv_index <= nprocs_rem - 1) ?
457+
2 * recv_index : nprocs_rem + recv_index);
458+
ptrdiff_t sdispl = rcount * ((send_index <= nprocs_rem - 1) ?
459+
2 * send_index : nprocs_rem + send_index);
460+
struct ompi_request_t *request = NULL;
461+
462+
if (recv_count > 0) {
463+
err = MCA_PML_CALL(irecv(tmprecv + rdispl * extent, recv_count,
464+
dtype, peer, MCA_COLL_BASE_TAG_REDUCE_SCATTER_BLOCK,
465+
comm, &request));
466+
if (OMPI_SUCCESS != err) { goto cleanup_and_return; }
467+
}
468+
if (send_count > 0) {
469+
err = MCA_PML_CALL(send(tmpbuf + sdispl * extent, send_count,
470+
dtype, peer, MCA_COLL_BASE_TAG_REDUCE_SCATTER_BLOCK,
471+
MCA_PML_BASE_SEND_STANDARD,
472+
comm));
473+
if (OMPI_SUCCESS != err) { goto cleanup_and_return; }
474+
}
475+
if (recv_count > 0) {
476+
err = ompi_request_wait(&request, MPI_STATUS_IGNORE);
477+
if (OMPI_SUCCESS != err) { goto cleanup_and_return; }
478+
ompi_op_reduce(op, tmprecv + rdispl * extent,
479+
tmpbuf + rdispl * extent, recv_count, dtype);
480+
}
481+
send_index = recv_index;
482+
last_index = recv_index + mask;
483+
}
484+
err = ompi_datatype_copy_content_same_ddt(dtype, rcount, rbuf,
485+
tmpbuf + (ptrdiff_t)rank * rcount * extent);
486+
if (MPI_SUCCESS != err) { goto cleanup_and_return; }
487+
}
488+
489+
/* Step 3. Send the result to excluded even ranks */
490+
if (rank < 2 * nprocs_rem) {
491+
if ((rank % 2) == 0) {
492+
/* Even process */
493+
err = MCA_PML_CALL(recv(rbuf, rcount, dtype, rank + 1,
494+
MCA_COLL_BASE_TAG_REDUCE_SCATTER_BLOCK, comm,
495+
MPI_STATUS_IGNORE));
496+
if (OMPI_SUCCESS != err) { goto cleanup_and_return; }
497+
} else {
498+
/* Odd process */
499+
err = MCA_PML_CALL(send(tmpbuf + (ptrdiff_t)(rank - 1) * rcount * extent,
500+
rcount, dtype, rank - 1,
501+
MCA_COLL_BASE_TAG_REDUCE_SCATTER_BLOCK,
502+
MCA_PML_BASE_SEND_STANDARD, comm));
503+
if (MPI_SUCCESS != err) { goto cleanup_and_return; }
504+
}
505+
}
506+
507+
cleanup_and_return:
508+
if (tmpbuf_raw)
509+
free(tmpbuf_raw);
510+
if (tmprecv_raw)
511+
free(tmprecv_raw);
512+
return err;
513+
}

ompi/mca/coll/tuned/coll_tuned_reduce_scatter_block_decision.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ static mca_base_var_enum_value_t reduce_scatter_block_algorithms[] = {
3535
{0, "ignore"},
3636
{1, "basic"},
3737
{2, "recursive_doubling"},
38+
{3, "recursive_halving"},
3839
{0, NULL}
3940
};
4041

@@ -125,6 +126,8 @@ int ompi_coll_tuned_reduce_scatter_block_intra_do_this(const void *sbuf, void *r
125126
dtype, op, comm, module);
126127
case (2): return ompi_coll_base_reduce_scatter_block_intra_recursivedoubling(sbuf, rbuf, rcount,
127128
dtype, op, comm, module);
129+
case (3): return ompi_coll_base_reduce_scatter_block_intra_recursivehalving(sbuf, rbuf, rcount,
130+
dtype, op, comm, module);
128131
} /* switch */
129132
OPAL_OUTPUT((ompi_coll_tuned_stream, "coll:tuned:reduce_scatter_block_intra_do_this attempt to select algorithm %d when only 0-%d is valid?",
130133
algorithm, ompi_coll_tuned_forced_max_algorithms[REDUCESCATTERBLOCK]));

0 commit comments

Comments
 (0)