NVSHMEM4Py Integration #33

Open · wants to merge 2 commits into master
5 changes: 2 additions & 3 deletions csrc/CMakeLists.txt
@@ -58,7 +58,6 @@ add_subdirectory(core)
 add_library(pplx_kernels SHARED
   bindings/all_to_all_ops.cpp
   bindings/bindings.cpp
-  bindings/nvshmem.cpp
 )
 target_link_libraries(pplx_kernels PUBLIC
   all_to_all_internode_lib
@@ -68,8 +67,8 @@ target_link_libraries(pplx_kernels PUBLIC
   Python::Module
   CUDA::cuda_driver
   CUDA::cudart
-  nvshmem::nvshmem
-  nvshmem::nvshmem_bootstrap_uid
+  nvshmem::nvshmem_host
+  nvshmem::nvshmem_device
 )
 set_target_properties(pplx_kernels PROPERTIES
   LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../src/pplx_kernels
8 changes: 4 additions & 4 deletions csrc/all_to_all/CMakeLists.txt
@@ -18,7 +18,7 @@ target_link_libraries(all_to_all_intranode_lib PUBLIC
   CUDA::cudart
 )
 target_link_libraries(all_to_all_intranode_lib INTERFACE
-  nvshmem::nvshmem
+  nvshmem::nvshmem_host
 )
 target_include_directories(all_to_all_intranode_lib PRIVATE ${NVSHMEM_INCLUDE_DIR})
 set_cuda_compile_options(all_to_all_intranode_lib)
@@ -33,7 +33,7 @@ target_link_libraries(all_to_all_internode_lib PUBLIC
   CUDA::cudart
 )
 target_link_libraries(all_to_all_internode_lib INTERFACE
-  nvshmem::nvshmem
+  nvshmem::nvshmem_host
 )
 target_include_directories(all_to_all_internode_lib PRIVATE ${NVSHMEM_INCLUDE_DIR})
 set_cuda_compile_options(all_to_all_internode_lib)
@@ -50,7 +50,7 @@ if(WITH_TESTS)
   CUDA::cudart
   CUDA::cuda_driver
   MPI::MPI_CXX
-  nvshmem::nvshmem
+  nvshmem::nvshmem_host
 )
 set_cuda_compile_options(test_all_to_all)
 add_test(NAME AllToAllTest
@@ -69,6 +69,6 @@ if (WITH_BENCHMARKS)
   CUDA::cudart
   CUDA::cuda_driver
   MPI::MPI_CXX
-  nvshmem::nvshmem
+  nvshmem::nvshmem_host
 )
 endif()
4 changes: 4 additions & 0 deletions csrc/bindings/all_to_all_ops.cpp
@@ -73,6 +73,10 @@ fptr_t create_internode(
       hiddenDimBytes,
       hiddenDimScaleBytes
   );
+
+  // Needed to use host-side initialization information in device APIs.
+  nvshmem_init();
+
   return (fptr_t)ptr;
 }
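The `nvshmem_init()` call above only does useful work if NVSHMEM has already been bootstrapped from Python via nvshmem4py by the time `create_internode` runs; the in-process call then attaches the kernel side to that existing host state rather than starting a fresh run. A minimal sketch of the resulting ordering contract for callers, assuming a torchrun-style launch; the `cuda.core` Device type and the `AllToAll.internode` constructor (arguments elided) are taken from nvshmem4py and pplx-kernels respectively, not from this diff:

# Illustrative ordering sketch, not part of this PR.
import os

import torch
import torch.distributed as dist
from cuda.core.experimental import Device  # device type nvshmem4py's init expects

from pplx_kernels import AllToAll, nvshmem_init

rank = int(os.environ["RANK"])
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])

torch.cuda.set_device(local_rank)
dist.init_process_group("nccl", rank=rank, world_size=world_size)

# 1) Host-side bootstrap through nvshmem4py (the wrapper added in nvshmem.py below).
nvshmem_init(rank, local_rank, world_size, Device(local_rank))

# 2) Only now construct the all-to-all: the nvshmem_init() inside
#    create_internode attaches to the state established in step 1.
ata = AllToAll.internode(...)  # constructor arguments elided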
8 changes: 2 additions & 6 deletions csrc/bindings/bindings.cpp
@@ -1,14 +1,10 @@
 #include <torch/library.h>

 #include "bindings/all_to_all_ops.h"
-#include "bindings/nvshmem.h"
 #include "core/registration.h"

 using namespace pplx;

-TORCH_LIBRARY(pplx_kernels, m) {
-  register_nvshmem_ops(m);
-  register_all_to_all_ops(m);
-}
+TORCH_LIBRARY(pplx_kernels, m) { register_all_to_all_ops(m); }

-REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
+REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
99 changes: 0 additions & 99 deletions csrc/bindings/nvshmem.cpp

This file was deleted.

7 changes: 0 additions & 7 deletions csrc/bindings/nvshmem.h

This file was deleted.

2 changes: 1 addition & 1 deletion csrc/core/CMakeLists.txt
@@ -8,7 +8,7 @@ target_link_libraries(core_lib PUBLIC
   CUDA::cudart
 )
 target_link_libraries(core_lib INTERFACE
-  nvshmem::nvshmem
+  nvshmem::nvshmem_host
 )
 target_include_directories(core_lib PRIVATE ${NVSHMEM_INCLUDE_DIR})
 set_cuda_compile_options(core_lib)
14 changes: 2 additions & 12 deletions src/pplx_kernels/__init__.py
@@ -1,16 +1,6 @@
 from . import ops as ops
-from .all_to_all import (
-    AllToAll as AllToAll,
-)
+from .all_to_all import AllToAll as AllToAll
 from .nvshmem import (
-    nvshmem_alloc_empty_unique_id as nvshmem_alloc_empty_unique_id,
-    nvshmem_alltoall as nvshmem_alltoall,
-    nvshmem_barrier_all as nvshmem_barrier_all,
-    nvshmem_barrier_all_on_current_stream as nvshmem_barrier_all_on_current_stream,
-    nvshmem_finalize as nvshmem_finalize,
-    nvshmem_get_unique_id as nvshmem_get_unique_id,
+    PyTorchStreamWrapper as PyTorchStreamWrapper,
     nvshmem_init as nvshmem_init,
-    nvshmem_my_pe as nvshmem_my_pe,
-    nvshmem_n_pes as nvshmem_n_pes,
-    nvshmem_unique_id_size as nvshmem_unique_id_size,
 )
93 changes: 39 additions & 54 deletions src/pplx_kernels/nvshmem.py
@@ -1,60 +1,45 @@
 # pyright: reportCallIssue=false

-from collections.abc import Sequence
+from typing import Any, Optional

-import torch
+import nvshmem.core as nvshmem  # type: ignore[import]
+import torch.distributed as dist

-from .ops import _ops
-
-###### NVSHMEM ######
-
-
-def nvshmem_get_unique_id() -> torch.Tensor:
-    return _ops.nvshmem_get_unique_id()
-
-
-def nvshmem_unique_id_size() -> int:
-    return _ops.nvshmem_unique_id_size()
-
-
-def nvshmem_alloc_empty_unique_id() -> torch.Tensor:
-    return torch.zeros(nvshmem_unique_id_size(), dtype=torch.uint8, device="cpu")
-
-
-def nvshmem_init(uid: torch.Tensor, rank: int, world_size: int) -> int:
-    status = _ops.nvshmem_init(uid, rank, world_size)
-    torch.cuda.synchronize()
-    return status
-
-
-def nvshmem_alltoall(dest: torch.Tensor, source: torch.Tensor) -> None:
-    return _ops.nvshmem_alltoall(dest, source)
-
-
-def nvshmem_finalize() -> None:
-    torch.cuda.synchronize()
-    _ops.nvshmem_finalize()
-
-
-def nvshmem_my_pe() -> int:
-    return _ops.nvshmem_my_pe()
-
-
-def nvshmem_n_pes() -> int:
-    return _ops.nvshmem_n_pes()
-
-
-def nvshmem_malloc(
-    shape: Sequence[int],
-    dtype: torch.dtype,
-    device: torch.device,
-) -> torch.Tensor:
-    return _ops.nvshmem_malloc(shape, dtype, device)
-
-
-def nvshmem_barrier_all() -> None:
-    _ops.nvshmem_barrier_all()
-
-
-def nvshmem_barrier_all_on_current_stream() -> None:
-    _ops.nvshmem_barrier_all_on_current_stream()
+def nvshmem_init(
+    global_rank: int,
+    local_rank: int,
+    world_size: int,
+    device: Any,
+    uid: Optional[Any] = None,
+) -> None:
+    uniqueid = nvshmem.get_unique_id(empty=True)
+    if local_rank == 0:
+        uniqueid = nvshmem.get_unique_id()
+        broadcast_objects = [uniqueid]
+    else:
+        broadcast_objects = [None]
+
+    dist.broadcast_object_list(broadcast_objects, src=0)
+    dist.barrier()
+
+    nvshmem.init(
+        device=device,
+        uid=broadcast_objects[0],
+        rank=global_rank,
+        nranks=world_size,
+        initializer_method="uid",
+    )
+
+
+# This stream wrapper returns the format required by CUDA Python. This workaround
+# will be removed when nvshmem4py supports Torch stream interoperability.
+# For more information see:
+# https://nvidia.github.io/cuda-python/cuda-core/latest/interoperability.html#cuda-stream-protocol
+class PyTorchStreamWrapper:
+    def __init__(self, pt_stream: Any) -> None:
+        self.pt_stream = pt_stream
+        self.handle = pt_stream.cuda_stream
+
+    def __cuda_stream__(self) -> tuple[int, int]:
+        stream_id = self.pt_stream.cuda_stream
+        return (0, stream_id)
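As a usage note, a small self-contained sketch of the wrapper: `torch.cuda.current_stream()` and `.cuda_stream` are standard PyTorch; that nvshmem4py accepts the wrapped object wherever it takes a stream is the assumption stated in the comment above:

import torch

from pplx_kernels import PyTorchStreamWrapper

# Torch streams do not implement __cuda_stream__ themselves, so they are
# wrapped before being handed to nvshmem4py APIs that accept a stream.
pt_stream = torch.cuda.current_stream()
wrapped = PyTorchStreamWrapper(pt_stream)
assert wrapped.__cuda_stream__() == (0, pt_stream.cuda_stream)
# `wrapped` can now be passed wherever nvshmem4py expects a stream argument.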
4 changes: 3 additions & 1 deletion src/pplx_kernels/ops.py
@@ -5,6 +5,8 @@

 import torch

+logger = logging.getLogger(__name__)
+
 try:
     _lib_path = os.path.join(os.path.dirname(__file__), "libpplx_kernels.so")
     torch.ops.load_library(_lib_path)
@@ -13,4 +15,4 @@
     from types import SimpleNamespace

     _ops = SimpleNamespace()
-    logging.exception("Error loading pplx-kernels")
+    logger.exception("Error loading pplx-kernels")