|
3 | 3 | from typing import Any, List, Optional, Tuple, Union, NewType |
4 | 4 |
|
5 | 5 | import numpy as np |
| 6 | +import os |
6 | 7 | from mpi4py import MPI |
7 | 8 | from pylops.utils import DTypeLike, NDArray |
8 | 9 | from pylops.utils import deps as pylops_deps # avoid namespace clashes with pylops_mpi.utils |
|
21 | 22 |
|
22 | 23 | NcclCommunicatorType = NewType("NcclCommunicator", NcclCommunicator) |
23 | 24 |
|
| 25 | +if int(os.environ.get("PYLOPS_MPI_CUDA_AWARE", 0)): |
| 26 | + is_cuda_aware_mpi = True |
| 27 | +else: |
| 28 | + is_cuda_aware_mpi = False |
24 | 29 |
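For context, a minimal sketch of how the new switch is meant to be driven. The flag is evaluated once at import time, so `PYLOPS_MPI_CUDA_AWARE` must be set before `pylops_mpi` is imported (the in-process assignment below is illustrative; in practice the variable would normally be exported in the shell before `mpiexec`):

```python
# Illustrative sketch: the module-level check added above runs at import time,
# so the environment variable has to be set beforehand.
import os

os.environ["PYLOPS_MPI_CUDA_AWARE"] = "1"  # any non-zero integer enables the flag

import pylops_mpi  # noqa: E402  (is_cuda_aware_mpi is now True inside the module)
```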
|
25 | 30 | class Partition(Enum): |
26 | 31 | r"""Enum class |
@@ -529,34 +534,52 @@ def _allgather_subcomm(self, send_buf, recv_buf=None): |
529 | 534 | return self.sub_comm.allgather(send_buf) |
530 | 535 | self.sub_comm.Allgather(send_buf, recv_buf) |
531 | 536 |
|
532 | | - def _send(self, send_buf, dest, count=None, tag=None): |
533 | | - """ Send operation |
| 537 | + def _send(self, send_buf, dest, count=None, tag=0): |
| 538 | + """Send operation |
534 | 539 | """ |
535 | 540 | if deps.nccl_enabled and self.base_comm_nccl: |
536 | 541 | if count is None: |
537 | | - # assuming sending the whole array |
538 | 542 | count = send_buf.size |
539 | 543 | nccl_send(self.base_comm_nccl, send_buf, dest, count) |
540 | 544 | else: |
541 | | - self.base_comm.send(send_buf, dest, tag) |
542 | | - |
543 | | - def _recv(self, recv_buf=None, source=0, count=None, tag=None): |
544 | | - """ Receive operation |
545 | | - """ |
546 | | - # NCCL must be called with recv_buf. Size cannot be inferred from |
547 | | - # other arguments and thus cannot be dynamically allocated |
548 | | - if deps.nccl_enabled and self.base_comm_nccl and recv_buf is not None: |
549 | | - if recv_buf is not None: |
| 545 | + if is_cuda_aware_mpi or self.engine == "numpy": |
| 546 | + # Determine MPI type based on array dtype |
| 547 | + mpi_type = MPI._typedict[send_buf.dtype.char] |
550 | 548 | if count is None: |
551 | | - # assuming data will take a space of the whole buffer |
552 | | - count = recv_buf.size |
553 | | - nccl_recv(self.base_comm_nccl, recv_buf, source, count) |
554 | | - return recv_buf |
| 549 | + count = send_buf.size |
| 550 | + self.base_comm.Send([send_buf, count, mpi_type], dest=dest, tag=tag) |
555 | 551 | else: |
556 | | - raise ValueError("Using recv with NCCL must also supply receiver buffer ") |
| 552 | + # CuPy without CUDA-aware MPI: fall back to pickle-based send |
| 553 | + self.base_comm.send(send_buf, dest, tag) |
| 554 | + |
| 555 | + |
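A standalone mpi4py sketch of the two `_send` paths above, assuming an `mpiexec -n 2` launch: the uppercase `Send` goes through the buffer protocol with an explicit MPI datatype (the fast path for NumPy arrays or CUDA-aware MPI), while the lowercase `send` pickles the object (the fallback for CuPy without CUDA-aware MPI):

```python
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

buf = np.arange(4, dtype=np.float64)
if rank == 0:
    # Fast path: buffered Send, MPI datatype derived from the NumPy dtype
    comm.Send([buf, buf.size, MPI._typedict[buf.dtype.char]], dest=1, tag=0)
    # Fallback path: pickle-based send of the same array
    comm.send(buf, dest=1, tag=1)
elif rank == 1:
    out = np.empty_like(buf)
    comm.Recv([out, out.size, MPI._typedict[out.dtype.char]], source=0, tag=0)
    obj = comm.recv(source=0, tag=1)  # returns a new object, no preallocation
```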
| 556 | + def _recv(self, recv_buf=None, source=0, count=None, tag=0): |
| 557 | + """Receive operation |
| 558 | + """ |
| 559 | + if deps.nccl_enabled and self.base_comm_nccl: |
| 560 | + if recv_buf is None: |
| 561 | + raise ValueError("recv_buf must be supplied when using NCCL") |
| 562 | + if count is None: |
| 563 | + count = recv_buf.size |
| 564 | + nccl_recv(self.base_comm_nccl, recv_buf, source, count) |
| 565 | + return recv_buf |
557 | 566 | else: |
558 | | - # MPI allows a receiver buffer to be optional and receives as a Python Object |
559 | | - return self.base_comm.recv(source=source, tag=tag) |
| 567 | + # NumPy + MPI always benefits from buffered communication, whether or not the MPI installation is CUDA-aware |
| 568 | + if is_cuda_aware_mpi or self.engine == "numpy": |
| 569 | + ncp = get_module(self.engine) |
| 570 | + if recv_buf is None: |
| 571 | + if count is None: |
| 572 | + raise ValueError("Must provide either recv_buf or count for MPI receive") |
| 573 | + # Defaulting to int32 works currently because add_ghost_cells() is |
| 574 | + # called with recv_buf, so it is not affected by this branch; int32 |
| 575 | + # covers the cases where dimension- or shape-related integers are sent/received |
| 576 | + recv_buf = ncp.zeros(count, dtype=ncp.int32) |
| 577 | + mpi_type = MPI._typedict[recv_buf.dtype.char] |
| 578 | + self.base_comm.Recv([recv_buf, recv_buf.size, mpi_type], source=source, tag=tag) |
| 579 | + else: |
| 580 | + # CuPy without CUDA-aware MPI: fall back to pickle-based receive |
| 581 | + recv_buf = self.base_comm.recv(source=source, tag=tag) |
| 582 | + return recv_buf |
560 | 583 |
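The int32 default in `_recv` only covers integer metadata such as dimensions and shapes. A sketch of the caller-side contract this implies, again assuming a two-rank launch: any payload of another dtype should arrive in an explicitly preallocated `recv_buf`:

```python
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
if comm.Get_rank() == 0:
    payload = np.linspace(0.0, 1.0, 5)  # float64 data, not shape metadata
    comm.Send([payload, payload.size, MPI.DOUBLE], dest=1, tag=0)
elif comm.Get_rank() == 1:
    # The caller preallocates the buffer; a count-only receive would allocate
    # int32 storage and fail to match the float64 payload.
    recv_buf = np.empty(5, dtype=np.float64)
    comm.Recv([recv_buf, recv_buf.size, MPI.DOUBLE], source=0, tag=0)
```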
|
561 | 584 | def _nccl_local_shapes(self, masked: bool): |
562 | 585 | """Get the the list of shapes of every GPU in the communicator |
|