 nccl_message = deps.nccl_import("the DistributedArray module")

 if nccl_message is None and cupy_message is None:
-    from pylops_mpi.utils._nccl import nccl_allgather, nccl_allreduce, nccl_asarray, nccl_bcast, nccl_split
+    from pylops_mpi.utils._nccl import nccl_allgather, nccl_allreduce, nccl_asarray, nccl_bcast, nccl_split, nccl_send, nccl_recv
     from cupy.cuda.nccl import NcclCommunicator
 else:
     NcclCommunicator = Any
@@ -503,6 +503,35 @@ def _allgather(self, send_buf, recv_buf=None):
         self.base_comm.Allgather(send_buf, recv_buf)
         return recv_buf

+    def _send(self, send_buf, dest, count=None, tag=None):
+        """ Send operation
+        """
+        if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
+            if count is None:
+                # assume the whole array is sent
+                count = send_buf.size
+            nccl_send(self.base_comm_nccl, send_buf, dest, count)
+        else:
+            self.base_comm.Send(send_buf, dest, tag)
+
+    def _recv(self, recv_buf=None, source=0, count=None, tag=None):
+        """ Receive operation
+        """
+        # NCCL must be called with recv_buf. Its size cannot be inferred from
+        # the other arguments, so the buffer cannot be allocated dynamically.
+        if deps.nccl_enabled and getattr(self, "base_comm_nccl") and recv_buf is not None:
+            if count is None:
+                # assume the data fills the whole buffer
+                count = recv_buf.size
+            nccl_recv(self.base_comm_nccl, recv_buf, source, count)
+            return recv_buf
+        else:
+            # MPI allows the receive buffer to be optional
+            if recv_buf is None:
+                return self.base_comm.recv(source=source, tag=tag)
+            self.base_comm.Recv(buf=recv_buf, source=source, tag=tag)
+            return recv_buf
+
     def __neg__(self):
         arr = DistributedArray(global_shape=self.global_shape,
                                base_comm=self.base_comm,
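
For context, a minimal sketch of how the new `_send`/`_recv` helpers pair up across neighbouring ranks, the same pattern `add_ghost_cells` uses in the hunk below. The array `x`, the ghost width `nghost`, the rank count, and the call to `get_module` are illustrative assumptions, not part of this change; the helpers dispatch to `nccl_send`/`nccl_recv` when `base_comm_nccl` is set and fall back to `base_comm.Send`/`Recv` otherwise.

```python
# Illustrative sketch only; run with e.g. `mpiexec -n 4 python sketch.py`.
# `x` is assumed to be a DistributedArray scattered along axis=0 with default
# engine/dtype; `_send`/`_recv` are the private helpers added above, called
# here outside the class purely for clarity.
from pylops.utils.backend import get_module
from pylops_mpi import DistributedArray

x = DistributedArray(global_shape=(16, 5), axis=0)
x.local_array[:] = x.rank                  # mark each local block with its rank id
ncp = get_module(x.engine)                 # numpy or cupy, matching x.engine
nghost = 2                                 # hypothetical ghost width

if x.rank != x.size - 1:
    # last `nghost` rows of this rank are sent forward to rank + 1
    send_buf = ncp.take(x.local_array, ncp.arange(-nghost, 0), axis=0)
    x._send(send_buf, dest=x.rank + 1, tag=1)
if x.rank != 0:
    # NCCL needs a preallocated receive buffer of the right shape;
    # plain MPI would also accept recv_buf=None and allocate it internally
    recv_shape = list(x.local_shapes[x.rank - 1])
    recv_shape[0] = nghost
    recv_buf = ncp.zeros(recv_shape, dtype=x.local_array.dtype)
    recv_buf = x._recv(recv_buf, source=x.rank - 1, tag=1)
```
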
@@ -747,50 +776,55 @@ def add_ghost_cells(self, cells_front: Optional[int] = None,

         """
         ghosted_array = self.local_array.copy()
+        ncp = get_module(getattr(self, "engine"))
         if cells_front is not None:
-            # TODO: these are metadata (small size). Under current API, it will
-            # call nccl allgather, should we force it to always use MPI?
-            cells_fronts = self._allgather(cells_front)
-            if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
-                total_cells_front = cells_fronts.tolist() + [0]
-            else:
-                total_cells_front = cells_fronts + [0]
+            # cells_front is a small piece of metadata; explicitly use MPI here
+            total_cells_front = self.base_comm.allgather(cells_front) + [0]
             # Read cells_front which needs to be sent to rank + 1 (cells_front for rank + 1)
             cells_front = total_cells_front[self.rank + 1]
+            send_buf = ncp.take(self.local_array, ncp.arange(-cells_front, 0), axis=self.axis)
+            recv_shapes = self.local_shapes
             if self.rank != 0:
-                ghosted_array = np.concatenate([self.base_comm.recv(source=self.rank - 1, tag=1), ghosted_array],
-                                               axis=self.axis)
-            if self.rank != self.size - 1:
+                # From the receiver's perspective (rank), the recv buffer has the same shape as the
+                # sender's array (rank - 1) in every dimension except along axis=self.axis
+                recv_shape = list(recv_shapes[self.rank - 1])
+                recv_shape[self.axis] = total_cells_front[self.rank]
+                recv_buf = ncp.zeros(recv_shape, dtype=ghosted_array.dtype)
+                # The communication can be skipped when len(recv_buf) == 0; additionally,
+                # NCCL hangs on zero-sized buffers, so this optimization is effectively mandatory
+                if len(recv_buf) != 0:
+                    ghosted_array = ncp.concatenate([self._recv(recv_buf, source=self.rank - 1, tag=1), ghosted_array], axis=self.axis)
+            # The skip on the sender side mirrors the skip described for the receiver
+            if self.rank != self.size - 1 and len(send_buf) != 0:
                 if cells_front > self.local_shape[self.axis]:
                     raise ValueError(f"Local Shape at rank={self.rank} along axis={self.axis} "
                                      f"should be > {cells_front}: dim({self.axis}) "
                                      f"{self.local_shape[self.axis]} < {cells_front}; "
                                      f"to achieve this use NUM_PROCESSES <= "
                                      f"{max(1, self.global_shape[self.axis] // cells_front)}")
-                # TODO: this array maybe large. Currently it will always use MPI.
-                # Should we enable NCCL point-point here ?
-                self.base_comm.send(np.take(self.local_array, np.arange(-cells_front, 0), axis=self.axis),
-                                    dest=self.rank + 1, tag=1)
+                self._send(send_buf, dest=self.rank + 1, tag=1)
         if cells_back is not None:
-            cells_backs = self._allgather(cells_back)
-            if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
-                total_cells_back = cells_backs.tolist() + [0]
-            else:
-                total_cells_back = cells_backs + [0]
+            total_cells_back = self.base_comm.allgather(cells_back) + [0]
             # Read cells_back which needs to be sent to rank - 1 (cells_back for rank - 1)
             cells_back = total_cells_back[self.rank - 1]
-            if self.rank != 0:
+            send_buf = ncp.take(self.local_array, ncp.arange(cells_back), axis=self.axis)
+            # The same reasoning as for sending cells_front applies here
+            recv_shapes = self.local_shapes
+            if self.rank != 0 and len(send_buf) != 0:
                 if cells_back > self.local_shape[self.axis]:
                     raise ValueError(f"Local Shape at rank={self.rank} along axis={self.axis} "
                                      f"should be > {cells_back}: dim({self.axis}) "
                                      f"{self.local_shape[self.axis]} < {cells_back}; "
                                      f"to achieve this use NUM_PROCESSES <= "
                                      f"{max(1, self.global_shape[self.axis] // cells_back)}")
-                self.base_comm.send(np.take(self.local_array, np.arange(cells_back), axis=self.axis),
-                                    dest=self.rank - 1, tag=0)
+                self._send(send_buf, dest=self.rank - 1, tag=0)
             if self.rank != self.size - 1:
-                ghosted_array = np.append(ghosted_array, self.base_comm.recv(source=self.rank + 1, tag=0),
-                                          axis=self.axis)
+                recv_shape = list(recv_shapes[self.rank + 1])
+                recv_shape[self.axis] = total_cells_back[self.rank]
+                recv_buf = ncp.zeros(recv_shape, dtype=ghosted_array.dtype)
+                if len(recv_buf) != 0:
+                    ghosted_array = ncp.append(ghosted_array, self._recv(recv_buf, source=self.rank + 1, tag=0),
+                                               axis=self.axis)
         return ghosted_array

     def __repr__(self):
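
A minimal usage sketch of `add_ghost_cells` after this change, assuming default constructor settings (`engine="numpy"`), four processes, and a hypothetical script name: each rank requests one ghost row from each neighbour, and the exchange now goes through `_send`/`_recv`, hence through NCCL whenever `base_comm_nccl` is configured.

```python
# ghost_demo.py (hypothetical) -- run with e.g. `mpiexec -n 4 python ghost_demo.py`
from pylops_mpi import DistributedArray

# 16 rows scattered along axis=0 -> 4 rows per rank with 4 processes
x = DistributedArray(global_shape=(16, 5), axis=0)
x.local_array[:] = x.rank                  # mark each local block with its rank id

# one ghost row from the previous rank and one from the next rank
ghosted = x.add_ghost_cells(cells_front=1, cells_back=1)

# interior ranks grow from (4, 5) to (6, 5); edge ranks gain a single extra row
print(x.rank, x.local_shape, ghosted.shape)
```
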