
Commit d258021

Swap out the existing NVSHMEM Python bindings for the official NVIDIA variant.

Changes include dynamic linking with host-side initialization, deletion of the existing bindings, addition of nvshmem4py, and addition of helper functions.

1 parent c336faf · commit d258021

File tree: 13 files changed (+174 −249 lines)
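
At a glance: the commit replaces the custom Torch-op NVSHMEM bindings with nvshmem4py's host-side, UID-based initialization. For orientation, here is a minimal launcher sketch of how a consumer might call the new helper. Only nvshmem_init comes from this commit; the torchrun environment variables, the NCCL process group, and the cuda.core Device handoff are illustrative assumptions, not part of the diff.

    # Hypothetical sketch: run under torchrun with one GPU per process.
    import os

    import torch
    import torch.distributed as dist
    from cuda.core.experimental import Device  # assumed device type for nvshmem4py

    from pplx_kernels import nvshmem_init

    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl")

    # Host-side init: rank 0 creates the NVSHMEM unique ID and nvshmem_init()
    # broadcasts it over the torch.distributed process group (see
    # src/pplx_kernels/nvshmem.py below).
    nvshmem_init(
        global_rank=dist.get_rank(),
        local_rank=local_rank,
        world_size=dist.get_world_size(),
        device=Device(local_rank),
    )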

csrc/CMakeLists.txt

Lines changed: 2 additions & 2 deletions

@@ -58,7 +58,6 @@ add_subdirectory(core)
 add_library(pplx_kernels SHARED
   bindings/all_to_all_ops.cpp
   bindings/bindings.cpp
-  bindings/nvshmem.cpp
 )
 target_link_libraries(pplx_kernels PUBLIC
   all_to_all_internode_lib
@@ -68,8 +67,9 @@ target_link_libraries(pplx_kernels PUBLIC
   Python::Module
   CUDA::cuda_driver
   CUDA::cudart
-  nvshmem::nvshmem
+  nvshmem::nvshmem_host
   nvshmem::nvshmem_bootstrap_uid
+  nvshmem::nvshmem_device
 )
 set_target_properties(pplx_kernels PROPERTIES
   LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../src/pplx_kernels

csrc/all_to_all/CMakeLists.txt

Lines changed: 4 additions & 4 deletions

@@ -18,7 +18,7 @@ target_link_libraries(all_to_all_intranode_lib PUBLIC
   CUDA::cudart
 )
 target_link_libraries(all_to_all_intranode_lib INTERFACE
-  nvshmem::nvshmem
+  nvshmem::nvshmem_host
 )
 target_include_directories(all_to_all_intranode_lib PRIVATE ${NVSHMEM_INCLUDE_DIR})
 set_cuda_compile_options(all_to_all_intranode_lib)
@@ -33,7 +33,7 @@ target_link_libraries(all_to_all_internode_lib PUBLIC
   CUDA::cudart
 )
 target_link_libraries(all_to_all_internode_lib INTERFACE
-  nvshmem::nvshmem
+  nvshmem::nvshmem_host
 )
 target_include_directories(all_to_all_internode_lib PRIVATE ${NVSHMEM_INCLUDE_DIR})
 set_cuda_compile_options(all_to_all_internode_lib)
@@ -50,7 +50,7 @@ if(WITH_TESTS)
   CUDA::cudart
   CUDA::cuda_driver
   MPI::MPI_CXX
-  nvshmem::nvshmem
+  nvshmem::nvshmem_host
 )
 set_cuda_compile_options(test_all_to_all)
 add_test(NAME AllToAllTest
@@ -69,6 +69,6 @@ if (WITH_BENCHMARKS)
   CUDA::cudart
   CUDA::cuda_driver
   MPI::MPI_CXX
-  nvshmem::nvshmem
+  nvshmem::nvshmem_host
 )
 endif()

csrc/bindings/all_to_all_ops.cpp

Lines changed: 4 additions & 0 deletions

@@ -73,6 +73,10 @@ fptr_t create_internode(
     hiddenDimBytes,
     hiddenDimScaleBytes
   );
+
+  // Needed to use host-side initialization information in device APIs.
+  nvshmem_init();
+
   return (fptr_t)ptr;
 }
 

csrc/bindings/bindings.cpp

Lines changed: 2 additions & 4 deletions

@@ -1,14 +1,12 @@
 #include <torch/library.h>
 
 #include "bindings/all_to_all_ops.h"
-#include "bindings/nvshmem.h"
 #include "core/registration.h"
 
 using namespace pplx;
 
-TORCH_LIBRARY(pplx_kernels, m) {
-  register_nvshmem_ops(m);
-  register_all_to_all_ops(m);
+TORCH_LIBRARY(pplx_kernels, m) {
+  register_all_to_all_ops(m);
 }
 
 REGISTER_EXTENSION(TORCH_EXTENSION_NAME)

csrc/bindings/nvshmem.cpp

Lines changed: 0 additions & 99 deletions
This file was deleted.

csrc/bindings/nvshmem.h

Lines changed: 0 additions & 7 deletions
This file was deleted.

csrc/core/CMakeLists.txt

Lines changed: 1 addition & 1 deletion

@@ -8,7 +8,7 @@ target_link_libraries(core_lib PUBLIC
   CUDA::cudart
 )
 target_link_libraries(core_lib INTERFACE
-  nvshmem::nvshmem
+  nvshmem::nvshmem_host
 )
 target_include_directories(core_lib PRIVATE ${NVSHMEM_INCLUDE_DIR})
 set_cuda_compile_options(core_lib)

src/pplx_kernels/__init__.py

Lines changed: 2 additions & 12 deletions

@@ -1,16 +1,6 @@
 from . import ops as ops
-from .all_to_all import (
-    AllToAll as AllToAll,
-)
+from .all_to_all import AllToAll as AllToAll
 from .nvshmem import (
-    nvshmem_alloc_empty_unique_id as nvshmem_alloc_empty_unique_id,
-    nvshmem_alltoall as nvshmem_alltoall,
-    nvshmem_barrier_all as nvshmem_barrier_all,
-    nvshmem_barrier_all_on_current_stream as nvshmem_barrier_all_on_current_stream,
-    nvshmem_finalize as nvshmem_finalize,
-    nvshmem_get_unique_id as nvshmem_get_unique_id,
+    PyTorchStreamWrapper as PyTorchStreamWrapper,
     nvshmem_init as nvshmem_init,
-    nvshmem_my_pe as nvshmem_my_pe,
-    nvshmem_n_pes as nvshmem_n_pes,
-    nvshmem_unique_id_size as nvshmem_unique_id_size,
 )
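
After this change, the package's NVSHMEM-related surface shrinks to two names alongside AllToAll; a quick import check (illustrative only):

    from pplx_kernels import AllToAll, PyTorchStreamWrapper, nvshmem_init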

src/pplx_kernels/nvshmem.py

Lines changed: 22 additions & 47 deletions

@@ -1,60 +1,35 @@
 # pyright: reportCallIssue=false
 
-from collections.abc import Sequence
+from typing import Any, Optional
 
-import torch
+import nvshmem.core as nvshmem  # type: ignore[import]
+import torch.distributed as dist
 
-from .ops import _ops
 
 ###### NVSHMEM ######
+def nvshmem_init(global_rank: int, local_rank: int, world_size: int, device: Any, uid: Optional[Any] = None) -> None:
+    uniqueid = nvshmem.get_unique_id(empty=True)
+    if local_rank == 0:
+        uniqueid = nvshmem.get_unique_id()
+        broadcast_objects = [uniqueid]
+    else:
+        broadcast_objects = [None]
 
+    dist.broadcast_object_list(broadcast_objects, src=0)
+    dist.barrier()
 
-def nvshmem_get_unique_id() -> torch.Tensor:
-    return _ops.nvshmem_get_unique_id()
+    nvshmem.init(device=device, uid=broadcast_objects[0], rank=global_rank, nranks=world_size, initializer_method="uid")
 
 
-def nvshmem_unique_id_size() -> int:
-    return _ops.nvshmem_unique_id_size()
+# This stream wrapper returns the format required by CUDA Python. This workaround will be removed when nvshmem4py supports Torch stream interoperability.
+# For more information see: https://nvidia.github.io/cuda-python/cuda-core/latest/interoperability.html#cuda-stream-protocol
+class PyTorchStreamWrapper:
+    def __init__(self, pt_stream: Any) -> None:
+        self.pt_stream = pt_stream
+        self.handle = pt_stream.cuda_stream
 
+    def __cuda_stream__(self) -> tuple[int, int]:
+        stream_id = self.pt_stream.cuda_stream
+        return (0, stream_id)
 
-def nvshmem_alloc_empty_unique_id() -> torch.Tensor:
-    return torch.zeros(nvshmem_unique_id_size(), dtype=torch.uint8, device="cpu")
-
-
-def nvshmem_init(uid: torch.Tensor, rank: int, world_size: int) -> int:
-    status = _ops.nvshmem_init(uid, rank, world_size)
-    torch.cuda.synchronize()
-    return status
-
-
-def nvshmem_alltoall(dest: torch.Tensor, source: torch.Tensor) -> None:
-    return _ops.nvshmem_alltoall(dest, source)
-
-
-def nvshmem_finalize() -> None:
-    torch.cuda.synchronize()
-    _ops.nvshmem_finalize()
-
-
-def nvshmem_my_pe() -> int:
-    return _ops.nvshmem_my_pe()
-
-
-def nvshmem_n_pes() -> int:
-    return _ops.nvshmem_n_pes()
-
-
-def nvshmem_malloc(
-    shape: Sequence[int],
-    dtype: torch.dtype,
-    device: torch.device,
-) -> torch.Tensor:
-    return _ops.nvshmem_malloc(shape, dtype, device)
-
-
-def nvshmem_barrier_all() -> None:
-    _ops.nvshmem_barrier_all()
-
-
-def nvshmem_barrier_all_on_current_stream() -> None:
-    _ops.nvshmem_barrier_all_on_current_stream()
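
nvshmem4py consumes streams through the CUDA stream protocol rather than torch.cuda.Stream objects, which is the gap PyTorchStreamWrapper bridges. Below is a minimal sketch of the handoff; the call site that would consume the wrapped stream is omitted because it depends on nvshmem4py APIs not shown in this diff.

    import torch

    from pplx_kernels import PyTorchStreamWrapper

    torch_stream = torch.cuda.current_stream()
    stream = PyTorchStreamWrapper(torch_stream)

    # The wrapper satisfies the protocol via __cuda_stream__(), which returns
    # (protocol_version, raw_stream_handle) as cuda.core expects.
    version, handle = stream.__cuda_stream__()
    assert handle == torch_stream.cuda_stream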

src/pplx_kernels/ops.py

Lines changed: 3 additions & 1 deletion

@@ -5,6 +5,8 @@
 
 import torch
 
+logger = logging.getLogger(__name__)
+
 try:
     _lib_path = os.path.join(os.path.dirname(__file__), "libpplx_kernels.so")
     torch.ops.load_library(_lib_path)
@@ -13,4 +15,4 @@
     from types import SimpleNamespace
 
     _ops = SimpleNamespace()
-    logging.exception("Error loading pplx-kernels")
+    logger.exception("Error loading pplx-kernels")
