|
32 | 32 | #include "ompi/datatype/ompi_datatype.h" |
33 | 33 | #include "ompi/communicator/communicator.h" |
34 | 34 | #include "ompi/mca/coll/coll.h" |
35 | | -#include "ompi/mca/coll/base/coll_tags.h" |
36 | | -#include "ompi/mca/coll/base/coll_base_functions.h" |
| 35 | +#include "ompi/mca/coll/basic/coll_basic.h" |
37 | 36 | #include "ompi/mca/pml/pml.h" |
38 | 37 | #include "ompi/op/op.h" |
39 | | -#include "ompi/mca/coll/base/coll_base_functions.h" |
| 38 | +#include "coll_tags.h" |
| 39 | +#include "coll_base_functions.h" |
40 | 40 | #include "coll_base_topo.h" |
41 | 41 | #include "coll_base_util.h" |
42 | 42 |
|
| 43 | + |
43 | 44 | /* |
44 | 45 | * ompi_reduce_scatter_block_basic |
45 | 46 | * |
@@ -303,3 +304,210 @@ ompi_coll_base_reduce_scatter_block_intra_recursivedoubling( |
303 | 304 | free(tmprecv_raw); |
304 | 305 | return err; |
305 | 306 | } |
| 307 | + |
| 308 | +/* |
| 309 | + * ompi_range_sum: Returns the sum of the values at indices a..b of this sequence:
| 310 | + * index: 0 1 2 3 4 ... r r+1 r+2 ... nprocs_pof2 - 1
| 311 | + * value: 2 2 2 2 2 ... 2  1   1  ...  1
| 312 | + */ |
| 313 | +static int ompi_range_sum(int a, int b, int r) |
| 314 | +{ |
| 315 | + if (r < a) |
| 316 | + return b - a + 1; |
| 317 | + else if (r > b) |
| 318 | + return 2 * (b - a + 1); |
| 319 | + return 2 * (r - a + 1) + b - r; |
| 320 | +} |
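To make the weight table concrete: the caller passes `nprocs_rem - 1` as `r`, so virtual ranks `0..nprocs_rem-1` (each standing in for an even/odd pair of real ranks) count twice, and the remaining virtual ranks count once. Below is a minimal standalone check of that arithmetic (a hypothetical driver, not part of the Open MPI tree) for p = 6, where p' = 4 and r = p - p' = 2:

```c
#include <assert.h>
#include <stdio.h>

/* Same arithmetic as ompi_range_sum above, reproduced for a standalone test */
static int range_sum(int a, int b, int r)
{
    if (r < a)
        return b - a + 1;               /* whole range has weight 1 */
    else if (r > b)
        return 2 * (b - a + 1);         /* whole range has weight 2 */
    return 2 * (r - a + 1) + b - r;     /* [a, r] weight 2, (r, b] weight 1 */
}

int main(void)
{
    int r = 2 - 1;                     /* nprocs_rem - 1, as the caller passes it */
    assert(range_sum(0, 3, r) == 6);   /* all blocks: 2 + 2 + 1 + 1 = comm_size */
    assert(range_sum(0, 1, r) == 4);   /* left half, owned by paired vranks */
    assert(range_sum(2, 3, r) == 2);   /* right half, owned by single vranks */
    puts("range_sum matches the weight table");
    return 0;
}
```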
| 321 | + |
| 322 | +/* |
| 323 | + * ompi_coll_base_reduce_scatter_block_intra_recursivehalving |
| 324 | + * |
| 325 | + * Function: Recursive halving algorithm for reduce_scatter_block |
| 326 | + * Accepts: Same as MPI_Reduce_scatter_block |
| 327 | + * Returns: MPI_SUCCESS or error code |
| 328 | + * |
| 329 | + * Description: Implements the recursive halving algorithm for MPI_Reduce_scatter_block.
| 330 | + * The algorithm can be used with commutative operations only.
| 331 | + * |
| 332 | + * Limitations: commutative operations only |
| 333 | + * Memory requirements (per process): 2 * rcount * comm_size * typesize |
| 334 | + */ |
| 335 | +int |
| 336 | +ompi_coll_base_reduce_scatter_block_intra_recursivehalving( |
| 337 | + const void *sbuf, void *rbuf, int rcount, struct ompi_datatype_t *dtype, |
| 338 | + struct ompi_op_t *op, struct ompi_communicator_t *comm, |
| 339 | + mca_coll_base_module_t *module) |
| 340 | +{ |
| 341 | + char *tmprecv_raw = NULL, *tmpbuf_raw = NULL, *tmprecv, *tmpbuf; |
| 342 | + ptrdiff_t span, gap, totalcount, extent; |
| 343 | + int err = MPI_SUCCESS; |
| 344 | + int comm_size = ompi_comm_size(comm); |
| 345 | + int rank = ompi_comm_rank(comm); |
| 346 | + |
| 347 | + OPAL_OUTPUT((ompi_coll_base_framework.framework_output, |
| 348 | + "coll:base:reduce_scatter_block_intra_recursivehalving: rank %d/%d", |
| 349 | + rank, comm_size)); |
| 350 | + if (rcount == 0 || comm_size < 2) |
| 351 | + return MPI_SUCCESS; |
| 352 | + |
| 353 | + if (!ompi_op_is_commute(op)) { |
| 354 | + OPAL_OUTPUT((ompi_coll_base_framework.framework_output, |
| 355 | + "coll:base:reduce_scatter_block_intra_recursivehalving: rank %d/%d " |
| 356 | + "switching to basic reduce_scatter_block", rank, comm_size)); |
| 357 | + return ompi_coll_base_reduce_scatter_block_basic(sbuf, rbuf, rcount, dtype, |
| 358 | + op, comm, module); |
| 359 | + } |
| 360 | + totalcount = comm_size * rcount; |
| 361 | + ompi_datatype_type_extent(dtype, &extent); |
| 362 | + span = opal_datatype_span(&dtype->super, totalcount, &gap); |
| 363 | + tmpbuf_raw = malloc(span); |
| 364 | + tmprecv_raw = malloc(span); |
| 365 | + if (NULL == tmpbuf_raw || NULL == tmprecv_raw) { |
| 366 | + err = OMPI_ERR_OUT_OF_RESOURCE; |
| 367 | + goto cleanup_and_return; |
| 368 | + } |
| 369 | + tmpbuf = tmpbuf_raw - gap; |
| 370 | + tmprecv = tmprecv_raw - gap; |
| 371 | + |
| 372 | + if (sbuf != MPI_IN_PLACE) { |
| 373 | + err = ompi_datatype_copy_content_same_ddt(dtype, totalcount, tmpbuf, (char *)sbuf); |
| 374 | + if (MPI_SUCCESS != err) { goto cleanup_and_return; } |
| 375 | + } else { |
| 376 | + err = ompi_datatype_copy_content_same_ddt(dtype, totalcount, tmpbuf, rbuf); |
| 377 | + if (MPI_SUCCESS != err) { goto cleanup_and_return; } |
| 378 | + } |
| 379 | + |
| 380 | + /* |
| 381 | + * Step 1. Reduce the number of processes to the nearest lower power of two |
| 382 | + * p' = 2^{\floor{\log_2 p}} by removing r = p - p' processes. |
| 383 | + * In the first 2r processes (ranks 0 to 2r - 1), all the even ranks send |
| 384 | + * the input vector to their neighbor (rank + 1) and all the odd ranks recv |
| 385 | + * the input vector and perform local reduction. |
| 386 | + * The odd ranks (1 to 2r - 1) then hold the reduction of their own
| 387 | + * input vector with that of their even-ranked neighbor. These r odd
| 388 | + * processes and the last p - 2r processes are renumbered from
| 389 | + * 0 to 2^{\floor{\log_2 p}} - 1. The even ranks do not participate
| 390 | + * in the rest of the algorithm.
| 391 | + */ |
| 392 | + |
| 393 | + /* Find nearest power-of-two less than or equal to comm_size */ |
| 394 | + int nprocs_pof2 = opal_next_poweroftwo(comm_size); |
| 395 | + nprocs_pof2 >>= 1; |
| 396 | + int nprocs_rem = comm_size - nprocs_pof2; |
| 397 | + |
| 398 | + int vrank = -1; |
| 399 | + if (rank < 2 * nprocs_rem) { |
| 400 | + if ((rank % 2) == 0) { |
| 401 | + /* Even process */ |
| 402 | + err = MCA_PML_CALL(send(tmpbuf, totalcount, dtype, rank + 1, |
| 403 | + MCA_COLL_BASE_TAG_REDUCE_SCATTER_BLOCK, |
| 404 | + MCA_PML_BASE_SEND_STANDARD, comm)); |
| 405 | + if (OMPI_SUCCESS != err) { goto cleanup_and_return; } |
| 406 | + /* This process does not participate in the rest of the algorithm */
| 407 | + vrank = -1; |
| 408 | + } else { |
| 409 | + /* Odd process */ |
| 410 | + err = MCA_PML_CALL(recv(tmprecv, totalcount, dtype, rank - 1, |
| 411 | + MCA_COLL_BASE_TAG_REDUCE_SCATTER_BLOCK, |
| 412 | + comm, MPI_STATUS_IGNORE)); |
| 413 | + if (OMPI_SUCCESS != err) { goto cleanup_and_return; } |
| 414 | + ompi_op_reduce(op, tmprecv, tmpbuf, totalcount, dtype); |
| 415 | + /* Odd ranks are renumbered into the bottom nprocs_rem virtual ranks */
| 416 | + vrank = rank / 2; |
| 417 | + } |
| 418 | + } else { |
| 419 | + /* Shift the rank down to account for the even ranks that dropped out */
| 420 | + vrank = rank - nprocs_rem; |
| 421 | + } |
| 422 | + |
| 423 | + if (vrank != -1) { |
| 424 | + /* |
| 425 | + * Step 2. Recursive vector halving. We now have a power-of-two number
| 426 | + * of processes, p' = 2^{\floor{\log_2 p}}, with new ranks (vrank) and
| 427 | + * the partial result in tmpbuf.
| 428 | + * In each of the \log_2(p') exchange steps, a process sends half of
| 429 | + * its current buffer to its peer, receives the other half, and reduces
| 430 | + * the received half into its local buffer. At the end, each of the p'
| 431 | + * processes holds 1 / p' of the total reduction result.
| 432 | + */ |
| 433 | + int send_index = 0, recv_index = 0, last_index = nprocs_pof2; |
| 434 | + for (int mask = nprocs_pof2 >> 1; mask > 0; mask >>= 1) { |
| 435 | + int vpeer = vrank ^ mask; |
| 436 | + int peer = (vpeer < nprocs_rem) ? vpeer * 2 + 1 : vpeer + nprocs_rem; |
| 437 | + |
| 438 | + /* |
| 439 | + * Calculate recv_count and send_count with ompi_range_sum: the
| 440 | + * even-numbered processes that no longer participate have their
| 441 | + * results calculated by the process to their right (rank + 1),
| 442 | + * so those virtual ranks carry two blocks instead of one.
| 443 | + */ |
| 444 | + int send_count = 0, recv_count = 0; |
| 445 | + if (vrank < vpeer) { |
| 446 | + /* Send the right half of the buffer, recv the left half */ |
| 447 | + send_index = recv_index + mask; |
| 448 | + send_count = rcount * ompi_range_sum(send_index, last_index - 1, nprocs_rem - 1); |
| 449 | + recv_count = rcount * ompi_range_sum(recv_index, send_index - 1, nprocs_rem - 1); |
| 450 | + } else { |
| 451 | + /* Send the left half of the buffer, recv the right half */ |
| 452 | + recv_index = send_index + mask; |
| 453 | + send_count = rcount * ompi_range_sum(send_index, recv_index - 1, nprocs_rem - 1); |
| 454 | + recv_count = rcount * ompi_range_sum(recv_index, last_index - 1, nprocs_rem - 1); |
| 455 | + } |
| 456 | + ptrdiff_t rdispl = rcount * ((recv_index <= nprocs_rem - 1) ? |
| 457 | + 2 * recv_index : nprocs_rem + recv_index); |
| 458 | + ptrdiff_t sdispl = rcount * ((send_index <= nprocs_rem - 1) ? |
| 459 | + 2 * send_index : nprocs_rem + send_index); |
| 460 | + struct ompi_request_t *request = NULL; |
| 461 | + |
| 462 | + if (recv_count > 0) { |
| 463 | + err = MCA_PML_CALL(irecv(tmprecv + rdispl * extent, recv_count, |
| 464 | + dtype, peer, MCA_COLL_BASE_TAG_REDUCE_SCATTER_BLOCK, |
| 465 | + comm, &request)); |
| 466 | + if (OMPI_SUCCESS != err) { goto cleanup_and_return; } |
| 467 | + } |
| 468 | + if (send_count > 0) { |
| 469 | + err = MCA_PML_CALL(send(tmpbuf + sdispl * extent, send_count, |
| 470 | + dtype, peer, MCA_COLL_BASE_TAG_REDUCE_SCATTER_BLOCK, |
| 471 | + MCA_PML_BASE_SEND_STANDARD, |
| 472 | + comm)); |
| 473 | + if (OMPI_SUCCESS != err) { goto cleanup_and_return; } |
| 474 | + } |
| 475 | + if (recv_count > 0) { |
| 476 | + err = ompi_request_wait(&request, MPI_STATUS_IGNORE); |
| 477 | + if (OMPI_SUCCESS != err) { goto cleanup_and_return; } |
| 478 | + ompi_op_reduce(op, tmprecv + rdispl * extent, |
| 479 | + tmpbuf + rdispl * extent, recv_count, dtype); |
| 480 | + } |
| 481 | + send_index = recv_index; |
| 482 | + last_index = recv_index + mask; |
| 483 | + } |
| 484 | + err = ompi_datatype_copy_content_same_ddt(dtype, rcount, rbuf, |
| 485 | + tmpbuf + (ptrdiff_t)rank * rcount * extent); |
| 486 | + if (MPI_SUCCESS != err) { goto cleanup_and_return; } |
| 487 | + } |
| 488 | + |
| 489 | + /* Step 3. Send the result to excluded even ranks */ |
| 490 | + if (rank < 2 * nprocs_rem) { |
| 491 | + if ((rank % 2) == 0) { |
| 492 | + /* Even process */ |
| 493 | + err = MCA_PML_CALL(recv(rbuf, rcount, dtype, rank + 1, |
| 494 | + MCA_COLL_BASE_TAG_REDUCE_SCATTER_BLOCK, comm, |
| 495 | + MPI_STATUS_IGNORE)); |
| 496 | + if (OMPI_SUCCESS != err) { goto cleanup_and_return; } |
| 497 | + } else { |
| 498 | + /* Odd process */ |
| 499 | + err = MCA_PML_CALL(send(tmpbuf + (ptrdiff_t)(rank - 1) * rcount * extent, |
| 500 | + rcount, dtype, rank - 1, |
| 501 | + MCA_COLL_BASE_TAG_REDUCE_SCATTER_BLOCK, |
| 502 | + MCA_PML_BASE_SEND_STANDARD, comm)); |
| 503 | + if (MPI_SUCCESS != err) { goto cleanup_and_return; } |
| 504 | + } |
| 505 | + } |
| 506 | + |
| 507 | +cleanup_and_return: |
| 508 | + if (tmpbuf_raw) |
| 509 | + free(tmpbuf_raw); |
| 510 | + if (tmprecv_raw) |
| 511 | + free(tmprecv_raw); |
| 512 | + return err; |
| 513 | +} |
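The renumbering in Step 1 and the peer selection in Step 2 are easiest to see on a concrete communicator. The sketch below (a hypothetical standalone driver, not Open MPI code) replays the vrank mapping and the real exchange peer per mask for comm_size = 6, mirroring the arithmetic above:

```c
#include <stdio.h>

int main(void)
{
    int comm_size = 6;
    int nprocs_pof2 = 4;                  /* nearest power of two <= comm_size */
    int nprocs_rem = comm_size - nprocs_pof2;

    for (int rank = 0; rank < comm_size; rank++) {
        /* Step 1: even ranks among the first 2r drop out; odd ranks take
         * vrank = rank / 2; the remaining ranks shift down by nprocs_rem */
        int vrank;
        if (rank < 2 * nprocs_rem)
            vrank = (rank % 2) ? rank / 2 : -1;
        else
            vrank = rank - nprocs_rem;

        printf("rank %d -> vrank %d", rank, vrank);
        if (vrank < 0) { putchar('\n'); continue; }

        /* Step 2: translate each virtual peer back to a real rank */
        for (int mask = nprocs_pof2 >> 1; mask > 0; mask >>= 1) {
            int vpeer = vrank ^ mask;
            int peer = (vpeer < nprocs_rem) ? vpeer * 2 + 1 : vpeer + nprocs_rem;
            printf(", mask %d: peer %d", mask, peer);
        }
        putchar('\n');
    }
    return 0;
}
```

For comm_size = 6 this reports, for example, that real rank 1 (vrank 0) exchanges with rank 4 (vrank 2) at mask 2 and with rank 3 (vrank 1) at mask 1, while ranks 0 and 2 sit out and only receive their block in Step 3.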