add NCCL support to add_ghost_cells and operators in /basicoperators #137
The first changed file extends the NCCL utility module's `__all__` list with the two new point-to-point primitives:

```diff
@@ -4,7 +4,9 @@
     "nccl_allgather",
     "nccl_allreduce",
     "nccl_bcast",
-    "nccl_asarray"
+    "nccl_asarray",
+    "nccl_send",
+    "nccl_recv"
 ]

 from enum import IntEnum
```

Review discussion on this hunk:

> **Reviewer:** I think it may be good to add all
>
> **Author:** Alright, I can do that. Maybe in the other PR?
>
> **Reviewer:** Sounds good. If it is very small like this we can go for the same PR; if it is something a bit more substantial, like the changes you made previously, it is good practice to have a separate documentation-only PR 😄
The same file then adds the two new functions after `nccl_asarray`:

```diff
@@ -286,3 +288,57 @@ def nccl_asarray(nccl_comm, local_array, local_shapes, axis) -> cp.ndarray:
         chunks[i] = chunks[i].reshape(send_shape)[slicing]
     # combine back to single global array
     return cp.concatenate(chunks, axis=axis)
+
+
+def nccl_send(nccl_comm, send_buf, dest, count):
+    """NCCL equivalent of MPI_Send. Sends a specified number of elements
+    from the buffer to a destination GPU device.
+
+    Parameters
+    ----------
+    nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator`
+        The NCCL communicator used for point-to-point communication.
+    send_buf : :obj:`cupy.ndarray`
+        The array containing the data to send.
+    dest : :obj:`int`
+        The rank of the destination GPU device.
+    count : :obj:`int`
+        Number of elements to send from `send_buf`.
+
+    Returns
+    -------
+    None
+    """
+    nccl_comm.send(send_buf.data.ptr,
+                   count,
+                   cupy_to_nccl_dtype[str(send_buf.dtype)],
+                   dest,
+                   cp.cuda.Stream.null.ptr
+                   )
+
+
+def nccl_recv(nccl_comm, recv_buf, source, count=None):
+    """NCCL equivalent of MPI_Recv. Receives data from a source GPU device
+    into the given buffer.
+
+    Parameters
+    ----------
+    nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator`
+        The NCCL communicator used for point-to-point communication.
+    recv_buf : :obj:`cupy.ndarray`
+        The array in which to store the received data.
+    source : :obj:`int`
+        The rank of the source GPU device.
+    count : :obj:`int`, optional
+        Number of elements to receive.
+
+    Returns
+    -------
+    None
+    """
+    nccl_comm.recv(recv_buf.data.ptr,
+                   count,
+                   cupy_to_nccl_dtype[str(recv_buf.dtype)],
+                   source,
+                   cp.cuda.Stream.null.ptr
+                   )
```
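To make the intended call pattern concrete, here is a minimal, hypothetical two-rank sketch of how `nccl_send` and `nccl_recv` could be used; it is not part of the PR. It assumes one process per GPU, with `rank` holding this process's rank and `nccl_comm` an already-initialized `cupy.cuda.nccl.NcclCommunicator` shared by both ranks, and it passes `count` explicitly on the receive side.

```python
# Hypothetical usage sketch (not from this PR): exchange a buffer between
# rank 0 and rank 1. `rank` and `nccl_comm` are assumed to be set up by
# the surrounding application, one process per GPU.
import cupy as cp

n = 1024
if rank == 0:
    # Rank 0 sends n float32 elements to rank 1
    send_buf = cp.arange(n, dtype=cp.float32)
    nccl_send(nccl_comm, send_buf, dest=1, count=n)
elif rank == 1:
    # Rank 1 receives into a preallocated buffer of matching size and dtype
    recv_buf = cp.zeros(n, dtype=cp.float32)
    nccl_recv(nccl_comm, recv_buf, source=0, count=n)
    # Both primitives enqueue work on the null stream, so synchronize it
    # before reading the received data
    cp.cuda.Stream.null.synchronize()
```

Note that sender and receiver must agree on `count` and dtype: NCCL transfers raw device buffers and performs no type or size negotiation.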
The second changed file updates an operator wrapper so that the NCCL communicator of the input `DistributedArray` is propagated to the output array it builds:

```diff
@@ -54,6 +54,7 @@ def wrapper(self, x: DistributedArray):
     local_shapes = None
     global_shape = getattr(self, "dims")
     arr = DistributedArray(global_shape=global_shape,
+                           base_comm_nccl=x.base_comm_nccl,
                            local_shapes=local_shapes, axis=0,
                            engine=x.engine, dtype=x.dtype)
     arr_local_shapes = np.asarray(arr.base_comm.allgather(np.prod(arr.local_shape)))
```

Review discussion on this hunk:

> **Reviewer:** Since we are changing this, I think it would be safe to also pass
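For illustration, here is a hedged sketch of what this one-line change accomplishes. Only the `DistributedArray` keyword arguments visible in the diff are taken from the PR; the communicator setup and the shapes are assumptions.

```python
# Illustrative sketch (assumptions: DistributedArray is importable from
# pylops_mpi and accepts the keywords shown in the diff; how the
# NcclCommunicator `nccl_comm` is created is application-specific).
import numpy as np
from pylops_mpi import DistributedArray

# An input array backed by an NCCL communicator for GPU-to-GPU transfers
x = DistributedArray(global_shape=(100,),
                     base_comm_nccl=nccl_comm,
                     engine="cupy", dtype=np.float32)

# Before this change the wrapper rebuilt its output without base_comm_nccl,
# so the result fell back to the default communicator even when the input
# was NCCL-backed; with the change, the output inherits it:
y = DistributedArray(global_shape=x.global_shape,
                     base_comm_nccl=x.base_comm_nccl,
                     axis=0,
                     engine=x.engine, dtype=x.dtype)
```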