From e823e51b3a62ad3c6f12ed0e8c823bee6c76a1e0 Mon Sep 17 00:00:00 2001 From: Joseph Schuchart Date: Mon, 10 Jun 2024 11:33:59 -0400 Subject: [PATCH] Add accelerator-awareness to most allreduce implementations Adding accelerator-awareness requires allocation of temporary memory on a device selected based on the input buffer (possibly cached by the allocator) and reintroduces the use of 3buff reductions to combine copy and operator application. This change also improves performance on CPU for larger operations where possible. Signed-off-by: Joseph Schuchart --- ompi/mca/coll/base/coll_base_allreduce.c | 364 ++++++++++++++++++----- ompi/mca/coll/base/coll_base_frame.c | 15 + ompi/mca/coll/base/coll_base_functions.h | 6 + ompi/mca/coll/base/coll_base_util.c | 56 ++++ ompi/mca/coll/base/coll_base_util.h | 44 +++ 5 files changed, 407 insertions(+), 78 deletions(-) diff --git a/ompi/mca/coll/base/coll_base_allreduce.c b/ompi/mca/coll/base/coll_base_allreduce.c index 05a2ca0d561..c6380b23866 100644 --- a/ompi/mca/coll/base/coll_base_allreduce.c +++ b/ompi/mca/coll/base/coll_base_allreduce.c @@ -141,6 +141,7 @@ ompi_coll_base_allreduce_intra_recursivedoubling(const void *sbuf, void *rbuf, int ret, line, rank, size, adjsize, remote, distance; int newrank, newremote, extra_ranks; char *tmpsend = NULL, *tmprecv = NULL, *tmpswap = NULL, *inplacebuf_free = NULL, *inplacebuf; + char *recvbuf = NULL; ptrdiff_t span, gap = 0; size = ompi_comm_size(comm); @@ -158,22 +159,64 @@ ompi_coll_base_allreduce_intra_recursivedoubling(const void *sbuf, void *rbuf, return MPI_SUCCESS; } - /* Allocate and initialize temporary send buffer */ + /* get the device for sbuf and rbuf and where the op would like to execute */ + int sendbuf_dev, recvbuf_dev, op_dev; + uint64_t sendbuf_flags, recvbuf_flags; + ompi_coll_base_select_device(op, sbuf, rbuf, count, dtype, &sendbuf_dev, &recvbuf_dev, + &sendbuf_flags, &recvbuf_flags, &op_dev); span = opal_datatype_span(&dtype->super, count, &gap); - 
inplacebuf_free = (char*) malloc(span); + inplacebuf_free = ompi_coll_base_allocate_on_device(op_dev, span, module); if (NULL == inplacebuf_free) { ret = -1; line = __LINE__; goto error_hndl; } inplacebuf = inplacebuf_free - gap; + //printf("allreduce ring count %d sbuf_dev %d rbuf_dev %d op_dev %d\n", count, sendbuf_dev, recvbuf_dev, op_dev); - if (MPI_IN_PLACE == sbuf) { - ret = ompi_datatype_copy_content_same_ddt(dtype, count, inplacebuf, (char*)rbuf); - if (ret < 0) { line = __LINE__; goto error_hndl; } - } else { - ret = ompi_datatype_copy_content_same_ddt(dtype, count, inplacebuf, (char*)sbuf); + opal_accelerator_stream_t *stream = NULL; + if (op_dev >= 0) { + stream = MCA_ACCELERATOR_STREAM_DEFAULT; + } + + tmpsend = (char*) sbuf; + if (op_dev != recvbuf_dev) { + /* copy data to where the op wants it to be */ + if (MPI_IN_PLACE == sbuf) { + ret = ompi_datatype_copy_content_same_ddt_stream(dtype, count, inplacebuf, (char*)rbuf, stream); + if (ret < 0) { line = __LINE__; goto error_hndl; } + } + /* only copy if op is on the device or we cannot access the sendbuf on the host */ + else if (op_dev != MCA_ACCELERATOR_NO_DEVICE_ID || + 0 == (sendbuf_flags & MCA_ACCELERATOR_FLAGS_UNIFIED_MEMORY)) { + ret = ompi_datatype_copy_content_same_ddt_stream(dtype, count, inplacebuf, (char*)sbuf, stream); + if (ret < 0) { line = __LINE__; goto error_hndl; } + } + tmpsend = (char*) inplacebuf; + } else if (MPI_IN_PLACE == sbuf) { + ret = ompi_datatype_copy_content_same_ddt_stream(dtype, count, inplacebuf, (char*)rbuf, stream); if (ret < 0) { line = __LINE__; goto error_hndl; } + tmpsend = (char*) inplacebuf; } - tmpsend = (char*) inplacebuf; - tmprecv = (char*) rbuf; + /* Handle MPI_IN_PLACE */ + bool use_sbuf = (MPI_IN_PLACE != sbuf); + /* allocate temporary recv buffer if the tmpbuf above is on a different device than the rbuf + * and the op is on the device or we cannot access the recv buffer on the host */ + recvbuf = rbuf; + bool free_recvbuf = false; + if (op_dev != 
recvbuf_dev && + (op_dev != MCA_ACCELERATOR_NO_DEVICE_ID || + 0 == (recvbuf_flags & MCA_ACCELERATOR_FLAGS_UNIFIED_MEMORY))) { + recvbuf = ompi_coll_base_allocate_on_device(op_dev, span, module); + free_recvbuf = true; + if (use_sbuf) { + /* copy from sbuf */ + ompi_datatype_copy_content_same_ddt_stream(dtype, count, (char*)recvbuf, (char*)sbuf, stream); + } else { + /* copy from rbuf */ + ompi_datatype_copy_content_same_ddt_stream(dtype, count, (char*)recvbuf, (char*)rbuf, stream); + } + use_sbuf = false; + } + + tmprecv = (char*) recvbuf; /* Determine nearest power of two less than or equal to size */ adjsize = opal_next_poweroftwo (size); @@ -189,6 +232,11 @@ ompi_coll_base_allreduce_intra_recursivedoubling(const void *sbuf, void *rbuf, extra_ranks = size - adjsize; if (rank < (2 * extra_ranks)) { if (0 == (rank % 2)) { + /* wait for above copies to complete */ + if (NULL != stream) { + opal_accelerator.wait_stream(stream); + } + /* wait for tmpsend to be copied */ ret = MCA_PML_CALL(send(tmpsend, count, dtype, (rank + 1), MCA_COLL_BASE_TAG_ALLREDUCE, MCA_PML_BASE_SEND_STANDARD, comm)); @@ -199,8 +247,14 @@ MCA_COLL_BASE_TAG_ALLREDUCE, comm, MPI_STATUS_IGNORE)); if (MPI_SUCCESS != ret) { line = __LINE__; goto error_hndl; } - /* tmpsend = tmprecv (op) tmpsend */ - ompi_op_reduce(op, tmprecv, tmpsend, count, dtype); + if (tmpsend == sbuf) { + tmpsend = inplacebuf; + /* tmpsend = tmprecv (op) sbuf */ + ompi_3buff_op_reduce_stream(op, sbuf, tmprecv, tmpsend, count, dtype, op_dev, stream); + } else { + /* tmpsend = tmprecv (op) tmpsend */ + ompi_op_reduce_stream(op, tmprecv, tmpsend, count, dtype, op_dev, stream); + } newrank = rank >> 1; } } else { @@ -219,6 +273,12 @@ ompi_coll_base_allreduce_intra_recursivedoubling(const void *sbuf, void *rbuf, remote = (newremote < extra_ranks)?
(newremote * 2 + 1):(newremote + extra_ranks); + bool have_next_iter = ((distance << 1) < adjsize); + + /* wait for previous ops to complete */ + if (NULL != stream) { + opal_accelerator.wait_stream(stream); + } /* Exchange the data */ ret = ompi_coll_base_sendrecv_actual(tmpsend, count, dtype, remote, MCA_COLL_BASE_TAG_ALLREDUCE, @@ -229,14 +289,47 @@ /* Apply operation */ if (rank < remote) { - /* tmprecv = tmpsend (op) tmprecv */ - ompi_op_reduce(op, tmpsend, tmprecv, count, dtype); - tmpswap = tmprecv; - tmprecv = tmpsend; - tmpsend = tmpswap; + if (tmpsend == sbuf) { + /* special case: 1st iteration takes one input from the sbuf */ + /* tmprecv = sbuf (op) tmprecv */ + ompi_op_reduce_stream(op, sbuf, tmprecv, count, dtype, op_dev, stream); + /* send the current recv buffer, and use the tmp buffer to receive */ + tmpsend = tmprecv; + tmprecv = inplacebuf; + } else if (have_next_iter || tmprecv == recvbuf) { + /* All iterations, and the last if tmprecv is the recv buffer */ + /* tmprecv = tmpsend (op) tmprecv */ + ompi_op_reduce_stream(op, tmpsend, tmprecv, count, dtype, op_dev, stream); + /* swap send and receive buffers */ + tmpswap = tmprecv; + tmprecv = tmpsend; + tmpsend = tmpswap; + } else { + /* Last iteration if tmprecv is not the recv buffer, then tmpsend is */ + /* Make sure we reduce into the receive buffer + * tmpsend = tmprecv (op) tmpsend */ + ompi_op_reduce_stream(op, tmprecv, tmpsend, count, dtype, op_dev, stream); + } } else { - /* tmpsend = tmprecv (op) tmpsend */ - ompi_op_reduce(op, tmprecv, tmpsend, count, dtype); + if (tmpsend == sbuf) { + /* First iteration: use input from sbuf */ + /* tmpsend = tmprecv (op) sbuf */ + tmpsend = inplacebuf; + if (have_next_iter || tmpsend == recvbuf) { + ompi_3buff_op_reduce_stream(op, tmprecv, sbuf, tmpsend, count, dtype, op_dev, stream); + } else { + ompi_op_reduce_stream(op, sbuf, tmprecv, count, dtype, op_dev, 
stream); + tmpsend = tmprecv; + } + } else if (have_next_iter || tmpsend == recvbuf) { + /* All other iterations: reduce into tmpsend for next iteration */ + /* tmpsend = tmprecv (op) tmpsend */ + ompi_op_reduce_stream(op, tmprecv, tmpsend, count, dtype, op_dev, stream); + } else { + /* Last iteration: reduce into rbuf and set tmpsend to rbuf (needed at the end) */ + ompi_op_reduce_stream(op, tmpsend, tmprecv, count, dtype, op_dev, stream); + tmpsend = tmprecv; + } } } @@ -253,6 +346,10 @@ ompi_coll_base_allreduce_intra_recursivedoubling(const void *sbuf, void *rbuf, if (MPI_SUCCESS != ret) { line = __LINE__; goto error_hndl; } tmpsend = (char*)rbuf; } else { + /* wait for previous ops to complete */ + if (NULL != stream) { + opal_accelerator.wait_stream(stream); + } ret = MCA_PML_CALL(send(tmpsend, count, dtype, (rank - 1), MCA_COLL_BASE_TAG_ALLREDUCE, MCA_PML_BASE_SEND_STANDARD, comm)); @@ -262,18 +359,31 @@ ompi_coll_base_allreduce_intra_recursivedoubling(const void *sbuf, void *rbuf, /* Ensure that the final result is in rbuf */ if (tmpsend != rbuf) { - ret = ompi_datatype_copy_content_same_ddt(dtype, count, (char*)rbuf, tmpsend); + /* TODO: catch this case in the 3buf selection above. Maybe already caught? 
*/ + ret = ompi_datatype_copy_content_same_ddt_stream(dtype, count, (char*)rbuf, tmpsend, stream); if (ret < 0) { line = __LINE__; goto error_hndl; } } - if (NULL != inplacebuf_free) free(inplacebuf_free); + /* wait for previous ops to complete */ + if (NULL != stream) { + opal_accelerator.wait_stream(stream); + } + ompi_coll_base_free_tmpbuf(inplacebuf_free, op_dev, module); + + if (free_recvbuf) { + ompi_coll_base_free_tmpbuf(recvbuf, op_dev, module); + } return MPI_SUCCESS; error_hndl: OPAL_OUTPUT((ompi_coll_base_framework.framework_output, "%s:%4d\tRank %d Error occurred %d\n", __FILE__, line, rank, ret)); (void)line; // silence compiler warning - if (NULL != inplacebuf_free) free(inplacebuf_free); + ompi_coll_base_free_tmpbuf(inplacebuf_free, op_dev, module); + + if (op_dev != recvbuf_dev) { + ompi_coll_base_free_tmpbuf(recvbuf, op_dev, module); + } return ret; } @@ -352,6 +462,7 @@ ompi_coll_base_allreduce_intra_ring(const void *sbuf, void *rbuf, size_t count, int early_segcount, late_segcount, split_rank, max_segcount; size_t typelng; char *tmpsend = NULL, *tmprecv = NULL, *inbuf[2] = {NULL, NULL}; + void *recvbuf = NULL; ptrdiff_t true_lb, true_extent, lb, extent; ptrdiff_t block_offset, max_real_segsize; ompi_request_t *reqs[2] = {MPI_REQUEST_NULL, MPI_REQUEST_NULL}; @@ -400,18 +511,36 @@ ompi_coll_base_allreduce_intra_ring(const void *sbuf, void *rbuf, size_t count, max_segcount = early_segcount; max_real_segsize = true_extent + (max_segcount - 1) * extent; - - inbuf[0] = (char*)malloc(max_real_segsize); - if (NULL == inbuf[0]) { ret = -1; line = __LINE__; goto error_hndl; } + /* get the device for sbuf and rbuf and where the op would like to execute */ + int sendbuf_dev, recvbuf_dev, op_dev; + uint64_t sendbuf_flags, recvbuf_flags; + ompi_coll_base_select_device(op, sbuf, rbuf, count, dtype, &sendbuf_dev, &recvbuf_dev, + &sendbuf_flags, &recvbuf_flags, &op_dev); if (size > 2) { - inbuf[1] = (char*)malloc(max_real_segsize); - if (NULL == inbuf[1]) { ret = 
-1; line = __LINE__; goto error_hndl; } + inbuf[0] = ompi_coll_base_allocate_on_device(op_dev, 2*max_real_segsize, module); + if (NULL == inbuf[0]) { ret = -1; line = __LINE__; goto error_hndl; } + inbuf[1] = inbuf[0] + max_real_segsize; + } else { + inbuf[0] = ompi_coll_base_allocate_on_device(op_dev, max_real_segsize, module); + if (NULL == inbuf[0]) { ret = -1; line = __LINE__; goto error_hndl; } } /* Handle MPI_IN_PLACE */ - if (MPI_IN_PLACE != sbuf) { - ret = ompi_datatype_copy_content_same_ddt(dtype, count, (char*)rbuf, (char*)sbuf); - if (ret < 0) { line = __LINE__; goto error_hndl; } + bool use_sbuf = (MPI_IN_PLACE != sbuf); + /* allocate temporary recv buffer if the tmpbuf above is on a different device than the rbuf */ + recvbuf = rbuf; + if (op_dev != recvbuf_dev && + /* only copy if op is on the device or the recvbuffer cannot be accessed on the host */ + (op_dev != MCA_ACCELERATOR_NO_DEVICE_ID || 0 == (MCA_ACCELERATOR_FLAGS_UNIFIED_MEMORY & recvbuf_flags))) { + recvbuf = ompi_coll_base_allocate_on_device(op_dev, typelng*count, module); + if (use_sbuf) { + /* copy from sbuf */ + ompi_datatype_copy_content_same_ddt(dtype, count, (char*)recvbuf, (char*)sbuf); + } else { + /* copy from rbuf */ + ompi_datatype_copy_content_same_ddt(dtype, count, (char*)recvbuf, (char*)rbuf); + } + use_sbuf = false; } /* Computation loop */ @@ -444,7 +573,7 @@ ompi_coll_base_allreduce_intra_ring(const void *sbuf, void *rbuf, size_t count, ((ptrdiff_t)rank * (ptrdiff_t)early_segcount) : ((ptrdiff_t)rank * (ptrdiff_t)late_segcount + split_rank)); block_count = ((rank < split_rank)? early_segcount : late_segcount); - tmpsend = ((char*)rbuf) + block_offset * extent; + tmpsend = ((use_sbuf) ?
((char*)sbuf) : ((char*)recvbuf)) + block_offset * extent; ret = MCA_PML_CALL(send(tmpsend, block_count, dtype, send_to, MCA_COLL_BASE_TAG_ALLREDUCE, MCA_PML_BASE_SEND_STANDARD, comm)); @@ -471,8 +600,17 @@ ompi_coll_base_allreduce_intra_ring(const void *sbuf, void *rbuf, size_t count, ((ptrdiff_t)prevblock * early_segcount) : ((ptrdiff_t)prevblock * late_segcount + split_rank)); block_count = ((prevblock < split_rank)? early_segcount : late_segcount); - tmprecv = ((char*)rbuf) + (ptrdiff_t)block_offset * extent; - ompi_op_reduce(op, inbuf[inbi ^ 0x1], tmprecv, block_count, dtype); + tmprecv = ((char*)recvbuf) + (ptrdiff_t)block_offset * extent; + if (use_sbuf) { + void *tmpsbuf = ((char*)sbuf) + (ptrdiff_t)block_offset * extent; + /* tmprecv = inbuf[inbi ^ 0x1] (op) sbuf */ + ompi_3buff_op_reduce_stream(op, inbuf[inbi ^ 0x1], tmpsbuf, tmprecv, block_count, + dtype, op_dev, NULL); + } else { + /* tmprecv = inbuf[inbi ^ 0x1] (op) tmprecv */ + ompi_op_reduce_stream(op, inbuf[inbi ^ 0x1], tmprecv, block_count, + dtype, op_dev, NULL); + } /* send previous block to send_to */ ret = MCA_PML_CALL(send(tmprecv, block_count, dtype, send_to, @@ -492,8 +630,8 @@ ompi_coll_base_allreduce_intra_ring(const void *sbuf, void *rbuf, size_t count, ((ptrdiff_t)recv_from * early_segcount) : ((ptrdiff_t)recv_from * late_segcount + split_rank)); block_count = ((recv_from < split_rank)? early_segcount : late_segcount); - tmprecv = ((char*)rbuf) + (ptrdiff_t)block_offset * extent; - ompi_op_reduce(op, inbuf[inbi], tmprecv, block_count, dtype); + tmprecv = ((char*)recvbuf) + (ptrdiff_t)block_offset * extent; + ompi_op_reduce_stream(op, inbuf[inbi], tmprecv, block_count, dtype, op_dev, NULL); /* Distribution loop - variation of ring allgather */ send_to = (rank + 1) % size; @@ -512,8 +650,8 @@ ompi_coll_base_allreduce_intra_ring(const void *sbuf, void *rbuf, size_t count, block_count = ((send_data_from < split_rank)? 
early_segcount : late_segcount); - tmprecv = (char*)rbuf + (ptrdiff_t)recv_block_offset * extent; - tmpsend = (char*)rbuf + (ptrdiff_t)send_block_offset * extent; + tmprecv = (char*)recvbuf + (ptrdiff_t)recv_block_offset * extent; + tmpsend = (char*)recvbuf + (ptrdiff_t)send_block_offset * extent; ret = ompi_coll_base_sendrecv(tmpsend, block_count, dtype, send_to, MCA_COLL_BASE_TAG_ALLREDUCE, @@ -521,11 +659,14 @@ ompi_coll_base_allreduce_intra_ring(const void *sbuf, void *rbuf, size_t count, MCA_COLL_BASE_TAG_ALLREDUCE, comm, MPI_STATUS_IGNORE, rank); if (MPI_SUCCESS != ret) { line = __LINE__; goto error_hndl;} - } - if (NULL != inbuf[0]) free(inbuf[0]); - if (NULL != inbuf[1]) free(inbuf[1]); + ompi_coll_base_free_tmpbuf(inbuf[0], op_dev, module); + if (recvbuf != rbuf) { + /* copy to final rbuf and release temporary recvbuf */ + ompi_datatype_copy_content_same_ddt(dtype, count, (char*)rbuf, (char*)recvbuf); + ompi_coll_base_free_tmpbuf(recvbuf, op_dev, module); + } return MPI_SUCCESS; @@ -534,8 +675,12 @@ ompi_coll_base_allreduce_intra_ring(const void *sbuf, void *rbuf, size_t count, __FILE__, line, rank, ret)); ompi_coll_base_free_reqs(reqs, 2); (void)line; // silence compiler warning - if (NULL != inbuf[0]) free(inbuf[0]); - if (NULL != inbuf[1]) free(inbuf[1]); + ompi_coll_base_free_tmpbuf(inbuf[0], op_dev, module); + if (NULL != recvbuf && recvbuf != rbuf) { + /* copy to final rbuf and release temporary recvbuf */ + ompi_datatype_copy_content_same_ddt(dtype, count, (char*)rbuf, (char*)recvbuf); + ompi_coll_base_free_tmpbuf(recvbuf, op_dev, module); + } return ret; } @@ -688,16 +833,21 @@ ompi_coll_base_allreduce_intra_ring_segmented(const void *sbuf, void *rbuf, size if (MPI_SUCCESS != ret) { line = __LINE__; goto error_hndl; } max_real_segsize = opal_datatype_span(&dtype->super, max_segcount, &gap); + int sendbuf_dev, recvbuf_dev, op_dev; + uint64_t sendbuf_flags, recvbuf_flags; + ompi_coll_base_select_device(op, sbuf, rbuf, count, dtype, &sendbuf_dev, 
&recvbuf_dev, + &sendbuf_flags, &recvbuf_flags, &op_dev); /* Allocate and initialize temporary buffers */ - inbuf[0] = (char*)malloc(max_real_segsize); + inbuf[0] = ompi_coll_base_allocate_on_device(op_dev, max_real_segsize, module); if (NULL == inbuf[0]) { ret = -1; line = __LINE__; goto error_hndl; } if (size > 2) { - inbuf[1] = (char*)malloc(max_real_segsize); + inbuf[1] = ompi_coll_base_allocate_on_device(op_dev, max_real_segsize, module); if (NULL == inbuf[1]) { ret = -1; line = __LINE__; goto error_hndl; } } /* Handle MPI_IN_PLACE */ if (MPI_IN_PLACE != sbuf) { + /* TODO: can we avoid this copy? */ ret = ompi_datatype_copy_content_same_ddt(dtype, count, (char*)rbuf, (char*)sbuf); if (ret < 0) { line = __LINE__; goto error_hndl; } } @@ -783,7 +933,8 @@ ompi_coll_base_allreduce_intra_ring_segmented(const void *sbuf, void *rbuf, size ((ptrdiff_t)phase * (ptrdiff_t)early_phase_segcount) : ((ptrdiff_t)phase * (ptrdiff_t)late_phase_segcount + split_phase)); tmprecv = ((char*)rbuf) + (ptrdiff_t)(block_offset + phase_offset) * extent; - ompi_op_reduce(op, inbuf[inbi ^ 0x1], tmprecv, phase_count, dtype); + ompi_op_reduce_stream(op, inbuf[inbi ^ 0x1], tmprecv, phase_count, + dtype, op_dev, NULL); /* send previous block to send_to */ ret = MCA_PML_CALL(send(tmprecv, phase_count, dtype, send_to, @@ -812,7 +963,8 @@ ompi_coll_base_allreduce_intra_ring_segmented(const void *sbuf, void *rbuf, size ((ptrdiff_t)phase * (ptrdiff_t)early_phase_segcount) : ((ptrdiff_t)phase * (ptrdiff_t)late_phase_segcount + split_phase)); tmprecv = ((char*)rbuf) + (ptrdiff_t)(block_offset + phase_offset) * extent; - ompi_op_reduce(op, inbuf[inbi], tmprecv, phase_count, dtype); + ompi_op_reduce_stream(op, inbuf[inbi], tmprecv, phase_count, + dtype, op_dev, NULL); } /* Distribution loop - variation of ring allgather */ @@ -844,8 +996,8 @@ ompi_coll_base_allreduce_intra_ring_segmented(const void *sbuf, void *rbuf, size } - if (NULL != inbuf[0]) free(inbuf[0]); - if (NULL != inbuf[1]) 
free(inbuf[1]); + ompi_coll_base_free_tmpbuf(inbuf[0], op_dev, module); + ompi_coll_base_free_tmpbuf(inbuf[1], op_dev, module); return MPI_SUCCESS; @@ -854,8 +1006,8 @@ ompi_coll_base_allreduce_intra_ring_segmented(const void *sbuf, void *rbuf, size __FILE__, line, rank, ret)); ompi_coll_base_free_reqs(reqs, 2); (void)line; // silence compiler warning - if (NULL != inbuf[0]) free(inbuf[0]); - if (NULL != inbuf[1]) free(inbuf[1]); + ompi_coll_base_free_tmpbuf(inbuf[0], op_dev, module); + ompi_coll_base_free_tmpbuf(inbuf[1], op_dev, module); return ret; } @@ -984,7 +1136,14 @@ int ompi_coll_base_allreduce_intra_redscat_allgather( "coll:base:allreduce_intra_redscat_allgather: rank %d/%d", rank, comm_size)); - if (!ompi_op_is_commute(op)) { + /* Find nearest power-of-two less than or equal to comm_size */ + int nsteps = opal_hibit(comm_size, comm->c_cube_dim + 1); /* ilog2(comm_size) */ + if (-1 == nsteps) { + return MPI_ERR_ARG; + } + int nprocs_pof2 = 1 << nsteps; /* flp2(comm_size) */ + + if (count < nprocs_pof2 || !ompi_op_is_commute(op)) { OPAL_OUTPUT((ompi_coll_base_framework.framework_output, "coll:base:allreduce_intra_redscat_allgather: rank %d/%d " "count %zu switching to basic linear allreduce", @@ -993,28 +1152,32 @@ int ompi_coll_base_allreduce_intra_redscat_allgather( op, comm, module); } - /* Find nearest power-of-two less than or equal to comm_size */ - int nsteps = opal_hibit(comm_size, comm->c_cube_dim + 1); /* ilog2(comm_size) */ - if (-1 == nsteps) { - return MPI_ERR_ARG; - } - int nprocs_pof2 = 1 << nsteps; /* flp2(comm_size) */ int err = MPI_SUCCESS; ptrdiff_t lb, extent, dsize, gap = 0; ompi_datatype_get_extent(dtype, &lb, &extent); dsize = opal_datatype_span(&dtype->super, count, &gap); + /* get the device for sbuf and rbuf and where the op would like to execute */ + int sendbuf_dev, recvbuf_dev, op_dev; + uint64_t sendbuf_flags, recvbuf_flags; + ompi_coll_base_select_device(op, sbuf, rbuf, count, dtype, &sendbuf_dev, &recvbuf_dev, + 
&sendbuf_flags, &recvbuf_flags, &op_dev); + /* Temporary buffer for receiving messages */ char *tmp_buf = NULL; - char *tmp_buf_raw = (char *)malloc(dsize); + char *tmp_buf_raw = ompi_coll_base_allocate_on_device(op_dev, dsize, module); if (NULL == tmp_buf_raw) return OMPI_ERR_OUT_OF_RESOURCE; tmp_buf = tmp_buf_raw - gap; - if (sbuf != MPI_IN_PLACE) { - err = ompi_datatype_copy_content_same_ddt(dtype, count, (char *)rbuf, - (char *)sbuf); - if (MPI_SUCCESS != err) { goto cleanup_and_return; } + char *recvbuf = rbuf; + if (op_dev != recvbuf_dev && 0 == (MCA_ACCELERATOR_FLAGS_UNIFIED_MEMORY & recvbuf_flags)) { + recvbuf = ompi_coll_base_allocate_on_device(op_dev, dsize, module); + } + if (op_dev != sendbuf_dev && 0 == (MCA_ACCELERATOR_FLAGS_UNIFIED_MEMORY & sendbuf_flags) && sbuf != MPI_IN_PLACE) { + /* move the data into the recvbuf and set sbuf to MPI_IN_PLACE */ + ompi_datatype_copy_content_same_ddt(dtype, count, (char*)recvbuf, (char*)sbuf); + sbuf = MPI_IN_PLACE; } /* @@ -1037,9 +1200,18 @@ int ompi_coll_base_allreduce_intra_redscat_allgather( int vrank, step, wsize; int nprocs_rem = comm_size - nprocs_pof2; + opal_accelerator_stream_t *stream = NULL; + if (op_dev >= 0) { + stream = MCA_ACCELERATOR_STREAM_DEFAULT; + } + if (rank < 2 * nprocs_rem) { int count_lhalf = count / 2; int count_rhalf = count - count_lhalf; + const void *send_buf = sbuf; + if (MPI_IN_PLACE == sbuf) { + send_buf = recvbuf; + } if (rank % 2 != 0) { /* @@ -1047,7 +1219,7 @@ int ompi_coll_base_allreduce_intra_redscat_allgather( * Send the left half of the input vector to the left neighbor, * Recv the right half of the input vector from the left neighbor */ - err = ompi_coll_base_sendrecv(rbuf, count_lhalf, dtype, rank - 1, + err = ompi_coll_base_sendrecv((void*)send_buf, count_lhalf, dtype, rank - 1, MCA_COLL_BASE_TAG_ALLREDUCE, (char *)tmp_buf + (ptrdiff_t)count_lhalf * extent, count_rhalf, dtype, rank - 1, @@ -1055,12 +1227,24 @@ int ompi_coll_base_allreduce_intra_redscat_allgather( 
MPI_STATUS_IGNORE, rank); if (MPI_SUCCESS != err) { goto cleanup_and_return; } - /* Reduce on the right half of the buffers (result in rbuf) */ - ompi_op_reduce(op, (char *)tmp_buf + (ptrdiff_t)count_lhalf * extent, - (char *)rbuf + count_lhalf * extent, count_rhalf, dtype); + /* Reduce on the right half of the buffers (result in rbuf) + * We're not using a stream here, the reduction will make sure that the result is available upon return */ + if (MPI_IN_PLACE != sbuf) { + /* rbuf = sbuf (op) tmp_buf */ + ompi_3buff_op_reduce_stream(op, + (char *)tmp_buf + (ptrdiff_t)count_lhalf * extent, + (char *)sbuf + (ptrdiff_t)count_lhalf * extent, + (char *)recvbuf + count_lhalf * extent, + count_rhalf, dtype, op_dev, NULL); + } else { + /* rbuf = rbuf (op) tmp_buf */ + ompi_op_reduce_stream(op, (char *)tmp_buf + (ptrdiff_t)count_lhalf * extent, + (char *)recvbuf + count_lhalf * extent, count_rhalf, + dtype, op_dev, NULL); + } /* Send the right half to the left neighbor */ - err = MCA_PML_CALL(send((char *)rbuf + (ptrdiff_t)count_lhalf * extent, + err = MCA_PML_CALL(send((char *)recvbuf + (ptrdiff_t)count_lhalf * extent, count_rhalf, dtype, rank - 1, MCA_COLL_BASE_TAG_ALLREDUCE, MCA_PML_BASE_SEND_STANDARD, comm)); @@ -1075,7 +1259,7 @@ int ompi_coll_base_allreduce_intra_redscat_allgather( * Send the right half of the input vector to the right neighbor, * Recv the left half of the input vector from the right neighbor */ - err = ompi_coll_base_sendrecv((char *)rbuf + (ptrdiff_t)count_lhalf * extent, + err = ompi_coll_base_sendrecv((char *)send_buf + (ptrdiff_t)count_lhalf * extent, count_rhalf, dtype, rank + 1, MCA_COLL_BASE_TAG_ALLREDUCE, tmp_buf, count_lhalf, dtype, rank + 1, @@ -1084,21 +1268,35 @@ int ompi_coll_base_allreduce_intra_redscat_allgather( if (MPI_SUCCESS != err) { goto cleanup_and_return; } /* Reduce on the right half of the buffers (result in rbuf) */ - ompi_op_reduce(op, tmp_buf, rbuf, count_lhalf, dtype); + if (MPI_IN_PLACE != sbuf) { + /* rbuf = sbuf (op) 
tmp_buf */ + ompi_3buff_op_reduce_stream(op, sbuf, tmp_buf, recvbuf, count_lhalf, dtype, op_dev, stream); + } else { + /* rbuf = rbuf (op) tmp_buf */ + ompi_op_reduce_stream(op, tmp_buf, recvbuf, count_lhalf, dtype, op_dev, stream); + } + /* Recv the right half from the right neighbor */ - err = MCA_PML_CALL(recv((char *)rbuf + (ptrdiff_t)count_lhalf * extent, + err = MCA_PML_CALL(recv((char *)recvbuf + (ptrdiff_t)count_lhalf * extent, count_rhalf, dtype, rank + 1, MCA_COLL_BASE_TAG_ALLREDUCE, comm, MPI_STATUS_IGNORE)); if (MPI_SUCCESS != err) { goto cleanup_and_return; } + /* wait for the op to complete */ + if (NULL != stream) { + opal_accelerator.wait_stream(stream); + } + vrank = rank / 2; } } else { /* rank >= 2 * nprocs_rem */ vrank = rank - nprocs_rem; } + /* At this point the input data has been accumulated into the rbuf */ + /* * Step 2. Reduce-scatter implemented with recursive vector halving and * recursive distance doubling. We have p' = 2^{\floor{\log_2 p}} @@ -1155,7 +1353,7 @@ int ompi_coll_base_allreduce_intra_redscat_allgather( } /* Send part of data from the rbuf, recv into the tmp_buf */ - err = ompi_coll_base_sendrecv((char *)rbuf + (ptrdiff_t)sindex[step] * extent, + err = ompi_coll_base_sendrecv((char *)recvbuf + (ptrdiff_t)sindex[step] * extent, scount[step], dtype, dest, MCA_COLL_BASE_TAG_ALLREDUCE, (char *)tmp_buf + (ptrdiff_t)rindex[step] * extent, @@ -1165,9 +1363,9 @@ int ompi_coll_base_allreduce_intra_redscat_allgather( if (MPI_SUCCESS != err) { goto cleanup_and_return; } /* Local reduce: rbuf[] = tmp_buf[] rbuf[] */ - ompi_op_reduce(op, (char *)tmp_buf + (ptrdiff_t)rindex[step] * extent, - (char *)rbuf + (ptrdiff_t)rindex[step] * extent, - rcount[step], dtype); + ompi_op_reduce_stream(op, (char *)tmp_buf + (ptrdiff_t)rindex[step] * extent, + (char *)recvbuf + (ptrdiff_t)rindex[step] * extent, + rcount[step], dtype, op_dev, NULL); /* Move the current window to the received message */ if (step + 1 < nsteps) { @@ -1201,10 +1399,10 @@ int 
ompi_coll_base_allreduce_intra_redscat_allgather( * Send rcount[step] elements from rbuf[rindex[step]...] * Recv scount[step] elements to rbuf[sindex[step]...] */ - err = ompi_coll_base_sendrecv((char *)rbuf + (ptrdiff_t)rindex[step] * extent, + err = ompi_coll_base_sendrecv((char *)recvbuf + (ptrdiff_t)rindex[step] * extent, rcount[step], dtype, dest, MCA_COLL_BASE_TAG_ALLREDUCE, - (char *)rbuf + (ptrdiff_t)sindex[step] * extent, + (char *)recvbuf + (ptrdiff_t)sindex[step] * extent, scount[step], dtype, dest, MCA_COLL_BASE_TAG_ALLREDUCE, comm, MPI_STATUS_IGNORE, rank); @@ -1216,6 +1414,7 @@ int ompi_coll_base_allreduce_intra_redscat_allgather( /* * Step 4. Send total result to excluded odd ranks. */ + bool recvbuf_need_copy = true; if (rank < 2 * nprocs_rem) { if (rank % 2 != 0) { /* Odd process -- recv result from rank - 1 */ @@ -1223,19 +1422,28 @@ int ompi_coll_base_allreduce_intra_redscat_allgather( MCA_COLL_BASE_TAG_ALLREDUCE, comm, MPI_STATUS_IGNORE)); if (OMPI_SUCCESS != err) { goto cleanup_and_return; } + recvbuf_need_copy = false; } else { /* Even process -- send result to rank + 1 */ - err = MCA_PML_CALL(send(rbuf, count, dtype, rank + 1, + err = MCA_PML_CALL(send(recvbuf, count, dtype, rank + 1, MCA_COLL_BASE_TAG_ALLREDUCE, MCA_PML_BASE_SEND_STANDARD, comm)); if (MPI_SUCCESS != err) { goto cleanup_and_return; } } } + if (recvbuf != rbuf) { + /* copy into final rbuf */ + if (recvbuf_need_copy) { + ompi_datatype_copy_content_same_ddt(dtype, count, (char*)rbuf, (char*)recvbuf); + } + ompi_coll_base_free_tmpbuf(recvbuf, op_dev, module); + } + cleanup_and_return: - if (NULL != tmp_buf_raw) - free(tmp_buf_raw); + + ompi_coll_base_free_tmpbuf(tmp_buf_raw, op_dev, module); if (NULL != rindex) free(rindex); if (NULL != sindex) diff --git a/ompi/mca/coll/base/coll_base_frame.c b/ompi/mca/coll/base/coll_base_frame.c index 07b7f85cf92..36e1c428ab3 100644 --- a/ompi/mca/coll/base/coll_base_frame.c +++ b/ompi/mca/coll/base/coll_base_frame.c @@ -32,6 +32,9 @@ #include 
"opal/util/output.h" #include "opal/mca/base/base.h" #include "opal/mca/base/mca_base_alias.h" +#include "opal/mca/accelerator/accelerator.h" + + #include "ompi/mca/coll/coll.h" #include "ompi/mca/coll/base/base.h" #include "ompi/mca/coll/base/coll_base_functions.h" @@ -70,6 +73,8 @@ static void coll_base_comm_construct(mca_coll_base_comm_t *data) { memset ((char *) data + sizeof (data->super), 0, sizeof (*data) - sizeof (data->super)); + data->device_allocators = NULL; + data->num_device_allocators = 0; } static void @@ -108,6 +113,16 @@ coll_base_comm_destruct(mca_coll_base_comm_t *data) if (data->cached_in_order_bintree) { /* destroy in order bintree if defined */ ompi_coll_base_topo_destroy_tree (&data->cached_in_order_bintree); } + + if (NULL != data->device_allocators) { + for (int i = 0; i < data->num_device_allocators; ++i) { + if (NULL != data->device_allocators[i]) { + data->device_allocators[i]->alc_finalize(data->device_allocators[i]); + } + } + free(data->device_allocators); + data->device_allocators = NULL; + } } OBJ_CLASS_INSTANCE(mca_coll_base_comm_t, opal_object_t, diff --git a/ompi/mca/coll/base/coll_base_functions.h b/ompi/mca/coll/base/coll_base_functions.h index ae924de5d31..2303aa8665d 100644 --- a/ompi/mca/coll/base/coll_base_functions.h +++ b/ompi/mca/coll/base/coll_base_functions.h @@ -40,6 +40,8 @@ /* need to include our own topo prototypes so we can malloc data on the comm correctly */ #include "coll_base_topo.h" +#include "opal/mca/allocator/allocator.h" + /* some fixed value index vars to simplify certain operations */ typedef enum COLLTYPE { ALLGATHER = 0, /* 0 */ @@ -516,6 +518,10 @@ struct mca_coll_base_comm_t { /* in-order binary tree (root of the in-order binary tree is rank 0) */ ompi_coll_tree_t *cached_in_order_bintree; + + /* pointer to per-device memory cache */ + mca_allocator_base_module_t **device_allocators; + int num_device_allocators; }; typedef struct mca_coll_base_comm_t mca_coll_base_comm_t; OMPI_DECLSPEC 
OBJ_CLASS_DECLARATION(mca_coll_base_comm_t); diff --git a/ompi/mca/coll/base/coll_base_util.c b/ompi/mca/coll/base/coll_base_util.c index ae9010497d7..9dbad4e9f0f 100644 --- a/ompi/mca/coll/base/coll_base_util.c +++ b/ompi/mca/coll/base/coll_base_util.c @@ -31,6 +31,7 @@ #include "ompi/mca/pml/pml.h" #include "coll_base_util.h" #include "coll_base_functions.h" +#include "opal/mca/allocator/base/base.h" #include int ompi_coll_base_sendrecv_actual( const void* sendbuf, size_t scount, @@ -603,3 +604,58 @@ const char* mca_coll_base_colltype_to_str(int collid) } return colltype_translation_table[collid]; } + +static void* ompi_coll_base_device_allocate_cb(void *ctx, size_t *size) { + int dev_id = (intptr_t)ctx; + void *ptr = NULL; + opal_accelerator.mem_alloc(dev_id, &ptr, *size); + return ptr; +} + +static void ompi_coll_base_device_release_cb(void *ctx, void* ptr) { + int dev_id = (intptr_t)ctx; + opal_accelerator.mem_release(dev_id, ptr); +} + +void *ompi_coll_base_allocate_on_device(int device, size_t size, + mca_coll_base_module_t *module) +{ + mca_allocator_base_module_t *allocator_module; + if (device < 0) { + return malloc(size); + } + + if (module->base_data->num_device_allocators <= device) { + int num_dev; + opal_accelerator.num_devices(&num_dev); + if (num_dev < device+1) num_dev = device+1; + module->base_data->device_allocators = realloc(module->base_data->device_allocators, num_dev * sizeof(mca_allocator_base_module_t *)); + for (int i = module->base_data->num_device_allocators; i < num_dev; ++i) { + module->base_data->device_allocators[i] = NULL; + } + module->base_data->num_device_allocators = num_dev; + } + if (NULL == (allocator_module = module->base_data->device_allocators[device])) { + mca_allocator_base_component_t *allocator_component; + allocator_component = mca_allocator_component_lookup("devicebucket"); + assert(allocator_component != NULL); + allocator_module = allocator_component->allocator_init(false, ompi_coll_base_device_allocate_cb, + 
ompi_coll_base_device_release_cb, + (void*)(intptr_t)device); + assert(allocator_module != NULL); + module->base_data->device_allocators[device] = allocator_module; + } + return allocator_module->alc_alloc(allocator_module, size, 0); +} + +void ompi_coll_base_free_on_device(int device, void *ptr, mca_coll_base_module_t *module) +{ + mca_allocator_base_module_t *allocator_module; + if (device < 0) { + free(ptr); + } else { + assert(NULL != module->base_data->device_allocators); + allocator_module = module->base_data->device_allocators[device]; + allocator_module->alc_free(allocator_module, ptr); + } +} diff --git a/ompi/mca/coll/base/coll_base_util.h b/ompi/mca/coll/base/coll_base_util.h index 852abcedefa..dd2ecdee1c7 100644 --- a/ompi/mca/coll/base/coll_base_util.h +++ b/ompi/mca/coll/base/coll_base_util.h @@ -31,6 +31,7 @@ #include "ompi/mca/coll/base/coll_tags.h" #include "ompi/op/op.h" #include "ompi/mca/pml/pml.h" +#include "opal/mca/accelerator/accelerator.h" BEGIN_C_DECLS @@ -200,5 +201,48 @@ int ompi_coll_base_file_peek_next_char_is(FILE *fptr, int *fileline, int expecte const char* mca_coll_base_colltype_to_str(int collid); int mca_coll_base_name_to_colltype(const char* name); +/* device/host memory allocation functions */ + + +void *ompi_coll_base_allocate_on_device(int device, size_t size, + mca_coll_base_module_t *module); + +void ompi_coll_base_free_on_device(int device, void *ptr, mca_coll_base_module_t *module); + + +static inline +void ompi_coll_base_select_device( + struct ompi_op_t *op, + const void *sendbuf, + const void *recvbuf, + size_t count, + struct ompi_datatype_t *dtype, + int *sendbuf_device, + int *recvbuf_device, + uint64_t *sendbuf_flags, + uint64_t *recvbuf_flags, + int *op_device) +{ + *recvbuf_device = MCA_ACCELERATOR_NO_DEVICE_ID; + *sendbuf_device = MCA_ACCELERATOR_NO_DEVICE_ID; + if (sendbuf != NULL && sendbuf != MPI_IN_PLACE) opal_accelerator.check_addr(sendbuf, sendbuf_device, sendbuf_flags); + if (recvbuf != NULL) 
opal_accelerator.check_addr(recvbuf, recvbuf_device, recvbuf_flags); + ompi_op_preferred_device(op, *recvbuf_device, *sendbuf_device, count, dtype, op_device); +} + +/** + * Frees memory allocated through ompi_coll_base_allocate_op_tmpbuf + * or ompi_coll_base_allocate_tmpbuf. + */ +static inline +void ompi_coll_base_free_tmpbuf(void *tmpbuf, int device, mca_coll_base_module_t *module) { + if (-1 == device) { + free(tmpbuf); + } else if (NULL != tmpbuf) { + ompi_coll_base_free_on_device(device, tmpbuf, module); + } +} + + END_C_DECLS #endif /* MCA_COLL_BASE_UTIL_EXPORT_H */