|  | 
| 3 | 3 | from typing import Any, List, Optional, Tuple, Union, NewType | 
| 4 | 4 | 
 | 
| 5 | 5 | import numpy as np | 
|  | 6 | +import os | 
| 6 | 7 | from mpi4py import MPI | 
| 7 | 8 | from pylops.utils import DTypeLike, NDArray | 
| 8 | 9 | from pylops.utils import deps as pylops_deps  # avoid namespace crashes with pylops_mpi.utils | 
|  | 
| 21 | 22 | 
 | 
| 22 | 23 | NcclCommunicatorType = NewType("NcclCommunicator", NcclCommunicator) | 
| 23 | 24 | 
 | 
|  | 25 | +# Flag whether the MPI installation is CUDA-aware, controlled by the | 
|  | 26 | +# PYLOPS_MPI_CUDA_AWARE environment variable (defaults to 0, i.e. not CUDA-aware) | 
|  | 27 | +is_cuda_aware_mpi = bool(int(os.environ.get("PYLOPS_MPI_CUDA_AWARE", 0))) | 
|  | 28 | + | 
| 24 | 29 | 
 | 
| 25 | 30 | class Partition(Enum): | 
| 26 | 31 |     r"""Enum class | 
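
For context, the new `is_cuda_aware_mpi` flag is controlled entirely by the `PYLOPS_MPI_CUDA_AWARE` environment variable and is evaluated once when the module is imported. A minimal sketch of how a user might opt in (the variable name and default of 0 are taken from the diff above; reading the flag at first import of `pylops_mpi` is an assumption):

```python
import os

# Declare that the MPI installation is CUDA-aware (e.g. built with UCX/CUDA
# support) *before* pylops_mpi is imported; the flag defaults to 0 (not CUDA-aware)
os.environ["PYLOPS_MPI_CUDA_AWARE"] = "1"

import pylops_mpi  # noqa: E402  (imported after setting the variable on purpose)
```

The same can be done from the shell, e.g. `PYLOPS_MPI_CUDA_AWARE=1 mpirun -n 2 python script.py`.
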
| @@ -529,34 +534,52 @@ def _allgather_subcomm(self, send_buf, recv_buf=None): | 
| 529 | 534 |                 return self.sub_comm.allgather(send_buf) | 
| 530 | 535 |             self.sub_comm.Allgather(send_buf, recv_buf) | 
| 531 | 536 | 
 | 
| 532 |  | -    def _send(self, send_buf, dest, count=None, tag=None): | 
| 533 |  | -        """ Send operation | 
|  | 537 | +    def _send(self, send_buf, dest, count=None, tag=0): | 
|  | 538 | +        """Send operation | 
| 534 | 539 |         """ | 
| 535 | 540 |         if deps.nccl_enabled and self.base_comm_nccl: | 
| 536 | 541 |             if count is None: | 
| 537 |  | -                # assuming sending the whole array | 
| 538 | 542 |                 count = send_buf.size | 
| 539 | 543 |             nccl_send(self.base_comm_nccl, send_buf, dest, count) | 
| 540 | 544 |         else: | 
| 541 |  | -            self.base_comm.send(send_buf, dest, tag) | 
| 542 |  | - | 
| 543 |  | -    def _recv(self, recv_buf=None, source=0, count=None, tag=None): | 
| 544 |  | -        """ Receive operation | 
| 545 |  | -        """ | 
| 546 |  | -        # NCCL must be called with recv_buf. Size cannot be inferred from | 
| 547 |  | -        # other arguments and thus cannot be dynamically allocated | 
| 548 |  | -        if deps.nccl_enabled and self.base_comm_nccl and recv_buf is not None: | 
| 549 |  | -            if recv_buf is not None: | 
|  | 545 | +            if is_cuda_aware_mpi or self.engine == "numpy": | 
|  | 546 | +                # Determine MPI type based on array dtype | 
|  | 547 | +                mpi_type = MPI._typedict[send_buf.dtype.char] | 
| 550 | 548 |                 if count is None: | 
| 551 |  | -                    # assuming data will take a space of the whole buffer | 
| 552 |  | -                    count = recv_buf.size | 
| 553 |  | -                nccl_recv(self.base_comm_nccl, recv_buf, source, count) | 
| 554 |  | -                return recv_buf | 
|  | 549 | +                    count = send_buf.size | 
|  | 550 | +                self.base_comm.Send([send_buf, count, mpi_type], dest=dest, tag=tag) | 
| 555 | 551 |             else: | 
| 556 |  | -                raise ValueError("Using recv with NCCL must also supply receiver buffer ") | 
|  | 552 | +                # CuPy array without CUDA-aware MPI: fall back to pickle-based send | 
|  | 553 | +                self.base_comm.send(send_buf, dest, tag) | 
|  | 554 | + | 
|  | 555 | + | 
|  | 556 | +    def _recv(self, recv_buf=None, source=0, count=None, tag=0): | 
|  | 557 | +        """Receive operation | 
|  | 558 | +        """ | 
|  | 559 | +        if deps.nccl_enabled and self.base_comm_nccl: | 
|  | 560 | +            if recv_buf is None: | 
|  | 561 | +                raise ValueError("recv_buf must be supplied when using NCCL") | 
|  | 562 | +            if count is None: | 
|  | 563 | +                count = recv_buf.size | 
|  | 564 | +            nccl_recv(self.base_comm_nccl, recv_buf, source, count) | 
|  | 565 | +            return recv_buf | 
| 557 | 566 |         else: | 
| 558 |  | -            # MPI allows a receiver buffer to be optional and receives as a Python Object | 
| 559 |  | -            return self.base_comm.recv(source=source, tag=tag) | 
|  | 567 | +            # NumPy arrays benefit from buffered MPI communication regardless of whether MPI is CUDA-aware | 
|  | 568 | +            if is_cuda_aware_mpi or self.engine == "numpy": | 
|  | 569 | +                ncp = get_module(self.engine) | 
|  | 570 | +                if recv_buf is None: | 
|  | 571 | +                    if count is None: | 
|  | 572 | +                        raise ValueError("Must provide either recv_buf or count for MPI receive") | 
|  | 573 | +                    # Defaulting to int32 currently works because add_ghost_cells() is called | 
|  | 574 | +                    # with recv_buf and so never reaches this branch; the int32 default covers | 
|  | 575 | +                    # the case where dimension- or shape-related integers are sent/received | 
|  | 576 | +                    recv_buf = ncp.zeros(count, dtype=ncp.int32) | 
|  | 577 | +                mpi_type = MPI._typedict[recv_buf.dtype.char] | 
|  | 578 | +                self.base_comm.Recv([recv_buf, recv_buf.size, mpi_type], source=source, tag=tag) | 
|  | 579 | +            else: | 
|  | 580 | +                # CuPy array without CUDA-aware MPI: fall back to pickle-based recv | 
|  | 581 | +                recv_buf = self.base_comm.recv(source=source, tag=tag) | 
|  | 582 | +            return recv_buf | 
| 560 | 583 | 
 | 
| 561 | 584 |     def _nccl_local_shapes(self, masked: bool): | 
| 562 | 585 |         """Get the list of shapes of every GPU in the communicator | 
|  | 
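
For reviewers less familiar with the two mpi4py code paths used in `_send`/`_recv`: the uppercase `Send`/`Recv` calls use the buffer protocol with an explicit `[buffer, count, MPI datatype]` triple (the datatype looked up in `MPI._typedict` by dtype character) and avoid pickling, while the lowercase `send`/`recv` calls serialize the object and remain the fallback for CuPy arrays on a non-CUDA-aware MPI. A standalone sketch of the distinction (illustrative only, not the library's API; run with at least two ranks, e.g. `mpirun -n 2 python demo.py`):

```python
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()

if size >= 2:
    data = np.arange(4, dtype=np.float64)
    if rank == 0:
        # Buffered protocol: [buffer, count, MPI datatype]; no pickling, and with
        # a CUDA-aware MPI the buffer could equally be a GPU (CuPy) array
        comm.Send([data, data.size, MPI._typedict[data.dtype.char]], dest=1, tag=0)
        # Object protocol: lowercase send pickles the payload; this is the
        # fallback used for CuPy arrays when MPI is not CUDA-aware
        comm.send(data, dest=1, tag=1)
    elif rank == 1:
        buf = np.empty(4, dtype=np.float64)
        comm.Recv([buf, buf.size, MPI._typedict[buf.dtype.char]], source=0, tag=0)
        obj = comm.recv(source=0, tag=1)
```

The buffered path skips serialization and, with a CUDA-aware build, lets MPI read device memory directly, which is why the NumPy and CUDA-aware branches above prefer it.
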