diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt
index 565ef80..600c781 100644
--- a/csrc/CMakeLists.txt
+++ b/csrc/CMakeLists.txt
@@ -58,7 +58,6 @@ add_subdirectory(core)
 add_library(pplx_kernels SHARED
   bindings/all_to_all_ops.cpp
   bindings/bindings.cpp
-  bindings/nvshmem.cpp
 )
 target_link_libraries(pplx_kernels PUBLIC
   all_to_all_internode_lib
@@ -68,8 +67,8 @@ target_link_libraries(pplx_kernels PUBLIC
   Python::Module
   CUDA::cuda_driver
   CUDA::cudart
-  nvshmem::nvshmem
-  nvshmem::nvshmem_bootstrap_uid
+  nvshmem::nvshmem_host
+  nvshmem::nvshmem_device
 )
 set_target_properties(pplx_kernels PROPERTIES
   LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../src/pplx_kernels
diff --git a/csrc/all_to_all/CMakeLists.txt b/csrc/all_to_all/CMakeLists.txt
index fddc1a9..9cd333f 100644
--- a/csrc/all_to_all/CMakeLists.txt
+++ b/csrc/all_to_all/CMakeLists.txt
@@ -18,7 +18,7 @@ target_link_libraries(all_to_all_intranode_lib PUBLIC
   CUDA::cudart
 )
 target_link_libraries(all_to_all_intranode_lib INTERFACE
-  nvshmem::nvshmem
+  nvshmem::nvshmem_host
 )
 target_include_directories(all_to_all_intranode_lib PRIVATE ${NVSHMEM_INCLUDE_DIR})
 set_cuda_compile_options(all_to_all_intranode_lib)
@@ -33,7 +33,7 @@ target_link_libraries(all_to_all_internode_lib PUBLIC
   CUDA::cudart
 )
 target_link_libraries(all_to_all_internode_lib INTERFACE
-  nvshmem::nvshmem
+  nvshmem::nvshmem_host
 )
 target_include_directories(all_to_all_internode_lib PRIVATE ${NVSHMEM_INCLUDE_DIR})
 set_cuda_compile_options(all_to_all_internode_lib)
@@ -50,7 +50,7 @@ if(WITH_TESTS)
     CUDA::cudart
     CUDA::cuda_driver
     MPI::MPI_CXX
-    nvshmem::nvshmem
+    nvshmem::nvshmem_host
   )
   set_cuda_compile_options(test_all_to_all)
   add_test(NAME AllToAllTest
@@ -69,6 +69,6 @@ if (WITH_BENCHMARKS)
     CUDA::cudart
     CUDA::cuda_driver
     MPI::MPI_CXX
-    nvshmem::nvshmem
+    nvshmem::nvshmem_host
   )
 endif()
diff --git a/csrc/bindings/all_to_all_ops.cpp b/csrc/bindings/all_to_all_ops.cpp
index a96ee97..34f6dd5 100644
--- a/csrc/bindings/all_to_all_ops.cpp
+++ b/csrc/bindings/all_to_all_ops.cpp
@@ -73,6 +73,10 @@ fptr_t create_internode(
       hiddenDimBytes,
       hiddenDimScaleBytes
   );
+
+  // Needed to use host-side initialization information in device APIs.
+  nvshmem_init();
+
   return (fptr_t)ptr;
 }

diff --git a/csrc/bindings/bindings.cpp b/csrc/bindings/bindings.cpp
index 311b55c..6d17534 100644
--- a/csrc/bindings/bindings.cpp
+++ b/csrc/bindings/bindings.cpp
@@ -1,14 +1,10 @@
 #include

 #include "bindings/all_to_all_ops.h"
-#include "bindings/nvshmem.h"
 #include "core/registration.h"

 using namespace pplx;

-TORCH_LIBRARY(pplx_kernels, m) {
-  register_nvshmem_ops(m);
-  register_all_to_all_ops(m);
-}
+TORCH_LIBRARY(pplx_kernels, m) { register_all_to_all_ops(m); }

-REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
+REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
\ No newline at end of file
diff --git a/csrc/bindings/nvshmem.cpp b/csrc/bindings/nvshmem.cpp
deleted file mode 100644
index 6b72694..0000000
--- a/csrc/bindings/nvshmem.cpp
+++ /dev/null
@@ -1,99 +0,0 @@
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-
-#include "bindings/nvshmem.h"
-#include "core/nvshmem_utils.h"
-
-namespace {
-
-at::Tensor get_unique_id() {
-  nvshmemx_uniqueid_t uid = NVSHMEMX_UNIQUEID_INITIALIZER;
-  nvshmemx_get_uniqueid(&uid);
-  return at::from_blob(&uid, sizeof(uid), at::kByte).clone();
-}
-
-int64_t unique_id_size() { return sizeof(nvshmemx_uniqueid_t); }
-
-int64_t init(at::Tensor uid, int64_t rank, int64_t world_size) {
-  TORCH_CHECK(uid.device().is_cpu(), "uid must be a CPU tensor");
-  TORCH_CHECK(uid.scalar_type() == at::kByte, "uid must be a byte tensor");
-  TORCH_CHECK(
-      uid.numel() == sizeof(nvshmemx_uniqueid_t),
-      "Invalid unique id size (expected ",
-      sizeof(nvshmemx_uniqueid_t),
-      ", got ",
-      uid.numel(),
-      ")"
-  );
-  nvshmemx_uniqueid_t id;
-  std::memcpy(&id, uid.data_ptr(), sizeof(id));
-  nvshmemx_init_attr_t attr = NVSHMEMX_INIT_ATTR_INITIALIZER;
-  nvshmemx_set_attr_uniqueid_args(rank, world_size, &id, &attr);
-  return nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr);
-}
-
-void finalize() { nvshmem_finalize(); }
-
-int64_t my_pe() { return nvshmem_my_pe(); }
-
-int64_t n_pes() { return nvshmem_n_pes(); }
-
-at::Tensor
-malloc_tensor(const std::vector &shape, c10::ScalarType dtype, const c10::Device &device) {
-  size_t size = c10::elementSize(dtype) * c10::multiply_integers(shape);
-  void *ptr = nvshmem_malloc(size);
-  if (ptr == nullptr) {
-    AT_ERROR("nvshmem_malloc failed. size: ", size);
-  }
-  return at::from_blob(
-      ptr,
-      shape,
-      [](void *ptr) { nvshmem_free(ptr); },
-      at::TensorOptions().dtype(dtype).device(device)
-  );
-}
-
-void barrier_all() { nvshmem_barrier_all(); }
-
-void barrier_all_on_current_stream() {
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  nvshmemx_barrier_all_on_stream(stream);
-}
-
-void alltoall(at::Tensor dest, at::Tensor source) {
-  TORCH_CHECK(dest.is_contiguous(), "dest must be contiguous");
-  TORCH_CHECK(source.is_contiguous(), "source must be contiguous");
-
-  size_t nbytes = dest.numel() * dest.itemsize() / dest.size(0);
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  NVSHMEMCHECK(nvshmemx_alltoallmem_on_stream(
-      NVSHMEM_TEAM_WORLD, (uint8_t *)dest.data_ptr(), (uint8_t *)source.data_ptr(), nbytes, stream
-  ));
-}
-
-void fake_alltoall(at::Tensor dest, at::Tensor source) {}
-
-} // namespace
-
-void pplx::register_nvshmem_ops(torch::Library &m) {
-  m.def("nvshmem_get_unique_id", &get_unique_id);
-  m.def("nvshmem_unique_id_size", &unique_id_size);
-  m.def("nvshmem_init", &init);
-  m.def("nvshmem_finalize", &finalize);
-  m.def("nvshmem_my_pe", &my_pe);
-  m.def("nvshmem_n_pes", &n_pes);
-  m.def("nvshmem_malloc", &malloc_tensor);
-  m.def("nvshmem_barrier_all", &barrier_all);
-  m.def("nvshmem_barrier_all_on_current_stream", &barrier_all_on_current_stream);
-  m.def("nvshmem_alltoall(Tensor! dest, Tensor src) -> ()");
-  m.impl("nvshmem_alltoall", c10::kCUDA, &alltoall);
-  m.impl("nvshmem_alltoall", c10::kMeta, &fake_alltoall);
-}
diff --git a/csrc/bindings/nvshmem.h b/csrc/bindings/nvshmem.h
deleted file mode 100644
index d14a988..0000000
--- a/csrc/bindings/nvshmem.h
+++ /dev/null
@@ -1,7 +0,0 @@
-#pragma once
-
-#include
-
-namespace pplx {
-void register_nvshmem_ops(torch::Library &m);
-} // namespace pplx
diff --git a/csrc/core/CMakeLists.txt b/csrc/core/CMakeLists.txt
index 07a762a..821035d 100644
--- a/csrc/core/CMakeLists.txt
+++ b/csrc/core/CMakeLists.txt
@@ -8,7 +8,7 @@ target_link_libraries(core_lib PUBLIC
   CUDA::cudart
 )
 target_link_libraries(core_lib INTERFACE
-  nvshmem::nvshmem
+  nvshmem::nvshmem_host
 )
 target_include_directories(core_lib PRIVATE ${NVSHMEM_INCLUDE_DIR})
 set_cuda_compile_options(core_lib)
diff --git a/src/pplx_kernels/__init__.py b/src/pplx_kernels/__init__.py
index 8636240..e230c26 100644
--- a/src/pplx_kernels/__init__.py
+++ b/src/pplx_kernels/__init__.py
@@ -1,16 +1,6 @@
 from . import ops as ops
-from .all_to_all import (
-    AllToAll as AllToAll,
-)
+from .all_to_all import AllToAll as AllToAll
 from .nvshmem import (
-    nvshmem_alloc_empty_unique_id as nvshmem_alloc_empty_unique_id,
-    nvshmem_alltoall as nvshmem_alltoall,
-    nvshmem_barrier_all as nvshmem_barrier_all,
-    nvshmem_barrier_all_on_current_stream as nvshmem_barrier_all_on_current_stream,
-    nvshmem_finalize as nvshmem_finalize,
-    nvshmem_get_unique_id as nvshmem_get_unique_id,
+    PyTorchStreamWrapper as PyTorchStreamWrapper,
     nvshmem_init as nvshmem_init,
-    nvshmem_my_pe as nvshmem_my_pe,
-    nvshmem_n_pes as nvshmem_n_pes,
-    nvshmem_unique_id_size as nvshmem_unique_id_size,
 )
diff --git a/src/pplx_kernels/nvshmem.py b/src/pplx_kernels/nvshmem.py
index c1143a0..5077607 100644
--- a/src/pplx_kernels/nvshmem.py
+++ b/src/pplx_kernels/nvshmem.py
@@ -1,60 +1,45 @@
 # pyright: reportCallIssue=false

-from collections.abc import Sequence
+from typing import Any, Optional

-import torch
+import nvshmem.core as nvshmem  # type: ignore[import]
+import torch.distributed as dist

-from .ops import _ops

 ###### NVSHMEM ######
-
-
-def nvshmem_get_unique_id() -> torch.Tensor:
-    return _ops.nvshmem_get_unique_id()
-
-
-def nvshmem_unique_id_size() -> int:
-    return _ops.nvshmem_unique_id_size()
-
-
-def nvshmem_alloc_empty_unique_id() -> torch.Tensor:
-    return torch.zeros(nvshmem_unique_id_size(), dtype=torch.uint8, device="cpu")
-
-
-def nvshmem_init(uid: torch.Tensor, rank: int, world_size: int) -> int:
-    status = _ops.nvshmem_init(uid, rank, world_size)
-    torch.cuda.synchronize()
-    return status
-
-
-def nvshmem_alltoall(dest: torch.Tensor, source: torch.Tensor) -> None:
-    return _ops.nvshmem_alltoall(dest, source)
-
-
-def nvshmem_finalize() -> None:
-    torch.cuda.synchronize()
-    _ops.nvshmem_finalize()
-
-
-def nvshmem_my_pe() -> int:
-    return _ops.nvshmem_my_pe()
-
-
-def nvshmem_n_pes() -> int:
-    return _ops.nvshmem_n_pes()
-
-
-def nvshmem_malloc(
-    shape: Sequence[int],
-    dtype: torch.dtype,
-    device: torch.device,
-) -> torch.Tensor:
-    return _ops.nvshmem_malloc(shape, dtype, device)
-
-
-def nvshmem_barrier_all() -> None:
-    _ops.nvshmem_barrier_all()
-
-
-def nvshmem_barrier_all_on_current_stream() -> None:
-    _ops.nvshmem_barrier_all_on_current_stream()
+def nvshmem_init(
+    global_rank: int,
+    local_rank: int,
+    world_size: int,
+    device: Any,
+    uid: Optional[Any] = None,
+) -> None:
+    uniqueid = nvshmem.get_unique_id(empty=True)
+    if local_rank == 0:
+        uniqueid = nvshmem.get_unique_id()
+        broadcast_objects = [uniqueid]
+    else:
+        broadcast_objects = [None]
+
+    dist.broadcast_object_list(broadcast_objects, src=0)
+    dist.barrier()
+
+    nvshmem.init(
+        device=device,
+        uid=broadcast_objects[0],
+        rank=global_rank,
+        nranks=world_size,
+        initializer_method="uid",
+    )
+
+
+# This stream wrapper returns the format required by CUDA Python. This workaround will be removed when nvshmem4py supports Torch stream interoperability.
+# For more information see: https://nvidia.github.io/cuda-python/cuda-core/latest/interoperability.html#cuda-stream-protocol
+class PyTorchStreamWrapper:
+    def __init__(self, pt_stream: Any) -> None:
+        self.pt_stream = pt_stream
+        self.handle = pt_stream.cuda_stream
+
+    def __cuda_stream__(self) -> tuple[int, int]:
+        stream_id = self.pt_stream.cuda_stream
+        return (0, stream_id)
diff --git a/src/pplx_kernels/ops.py b/src/pplx_kernels/ops.py
index 57c6a33..8a01f1a 100644
--- a/src/pplx_kernels/ops.py
+++ b/src/pplx_kernels/ops.py
@@ -5,6 +5,8 @@

 import torch

+logger = logging.getLogger(__name__)
+
 try:
     _lib_path = os.path.join(os.path.dirname(__file__), "libpplx_kernels.so")
     torch.ops.load_library(_lib_path)
@@ -13,4 +15,4 @@
     from types import SimpleNamespace

     _ops = SimpleNamespace()
-    logging.exception("Error loading pplx-kernels")
+    logger.exception("Error loading pplx-kernels")
diff --git a/tests/bench_all_to_all.py b/tests/bench_all_to_all.py
index b715761..970a4d9 100644
--- a/tests/bench_all_to_all.py
+++ b/tests/bench_all_to_all.py
@@ -6,18 +6,13 @@
 from datetime import datetime
 from pathlib import Path

+import nvshmem.core as nvshmem  # type: ignore[import]
 import torch
+from cuda.core.experimental import Device  # type: ignore[import]
+from nvshmem.core import Teams  # type: ignore[import]

+from pplx_kernels import PyTorchStreamWrapper, nvshmem_init
 from pplx_kernels.all_to_all import AllToAll
-from pplx_kernels.nvshmem import (
-    nvshmem_alloc_empty_unique_id,
-    nvshmem_alltoall,
-    nvshmem_barrier_all_on_current_stream,
-    nvshmem_finalize,
-    nvshmem_get_unique_id,
-    nvshmem_init,
-    nvshmem_malloc,
-)

 from .all_to_all_utils import MoEConfig, RankTestData
 from .distributed_utils import (
@@ -124,8 +119,8 @@ def bench_all_to_all(
     )
     a2a_out_tensor = torch.empty_like(a2a_tensor)

-    nvshmem_in = nvshmem_malloc(a2a_shape, torch.uint8, device)
-    nvshmem_out = nvshmem_malloc(a2a_shape, torch.uint8, device)
+    nvshmem_in = nvshmem.tensor(a2a_shape, dtype=torch.uint8)
+    nvshmem_out = nvshmem.tensor(a2a_shape, dtype=torch.uint8)

     # Compute stats
     dispatch_bytes = (
@@ -147,11 +142,15 @@ def run() -> tuple[float, ...]:
             [torch.cuda.Event(enable_timing=True) for _ in range(5)]
             for _ in range(num_samples)
         ]
-        stream = torch.cuda.current_stream()
+
+        torch_stream_wrapped = PyTorchStreamWrapper(torch.cuda.current_stream())
+        torch_stream_ = torch.cuda.current_stream()
         for e0, e1, e2, e3, e4 in events:
-            nvshmem_barrier_all_on_current_stream()
-            e0.record(stream)
+            team = Teams.TEAM_WORLD
+            nvshmem.collective.barrier(team, torch_stream_wrapped)
+
+            e0.record(torch_stream_)

             ata.dispatch(
                 out_expert_num_tokens=expert_num_tokens,
@@ -162,7 +161,7 @@
                 indices=indices,
                 bound_m=bound_m,
             )
-            e1.record(stream)
+            e1.record(torch_stream_)

             ata.combine(
                 out_tokens=y,
@@ -171,16 +170,20 @@
                 expert_y=expert_y,
                 bound_m=bound_m,
             )
-            e2.record(stream)
+            e2.record(torch_stream_)

             torch.distributed.all_to_all_single(a2a_out_tensor, a2a_tensor)
-            e3.record(stream)
-            nvshmem_alltoall(nvshmem_out, nvshmem_in)
-            e4.record(stream)
+            e3.record(torch_stream_)
+
+            nvshmem.collective.alltoall(
+                team, nvshmem_out, nvshmem_in, stream=torch_stream_wrapped
+            )
+
+            e4.record(torch_stream_)

         # Get latency
-        stream.synchronize()
+        torch_stream_.synchronize()
         sum_dispatch_us = 0.0
         sum_combine_us = 0.0
         sum_a2a_us = 0.0
@@ -224,6 +227,9 @@
     # Cleanup
     ata.destroy()
+    nvshmem.free_tensor(nvshmem_in)
+    nvshmem.free_tensor(nvshmem_out)
+
     return (
         (dispatch_bytes, combine_bytes, a2a_bytes, nvshmem_bytes),
         result,
@@ -236,9 +242,16 @@ def _worker_bench_all_to_all(
     in_dtype_str: str,
     out_dtype_str: str,
 ) -> None:
-    uid = nvshmem_get_unique_id() if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
-    torch.distributed.broadcast(uid, src=0)
-    nvshmem_init(uid, pgi.rank, pgi.world_size)
+    num_ranks = pgi.world_size
+    global_rank = pgi.rank
+    local_rank = pgi.local_rank
+
+    dev = Device(local_rank)
+    dev.set_current()
+
+    nvshmem_init(
+        global_rank=global_rank, local_rank=local_rank, world_size=num_ranks, device=dev
+    )

     in_dtype = getattr(torch, in_dtype_str)
     out_dtype = getattr(torch, out_dtype_str)
@@ -334,7 +347,7 @@
         f_out.close()
         print("Saved to", outpath)

-    nvshmem_finalize()
+    nvshmem.finalize()


 def main() -> None:
diff --git a/tests/test_all_to_all.py b/tests/test_all_to_all.py
index 3ddbf2c..eed68bf 100644
--- a/tests/test_all_to_all.py
+++ b/tests/test_all_to_all.py
@@ -1,16 +1,14 @@
 import dataclasses
 import logging

+import nvshmem.core as nvshmem  # type: ignore[import]
 import pytest
 import torch
+import torch.distributed as dist
+from cuda.core.experimental import Device  # type: ignore[import]

+from pplx_kernels import nvshmem_init
 from pplx_kernels.all_to_all import AllToAll
-from pplx_kernels.nvshmem import (
-    nvshmem_alloc_empty_unique_id,
-    nvshmem_finalize,
-    nvshmem_get_unique_id,
-    nvshmem_init,
-)

 from .all_to_all_utils import MoEConfig, RankTestData
 from .distributed_utils import (
@@ -295,9 +293,17 @@ def _worker_test_all_to_all(
     internode: bool,
     use_compile: bool = False,
 ) -> None:
-    uid = nvshmem_get_unique_id() if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
-    torch.distributed.broadcast(uid, src=0)
-    nvshmem_init(uid, pgi.rank, pgi.world_size)
+    num_ranks = dist.get_world_size()
+
+    global_rank = pgi.rank
+    local_rank = pgi.local_rank
+
+    dev = Device(local_rank)
+    dev.set_current()
+
+    nvshmem_init(
+        global_rank=global_rank, local_rank=local_rank, world_size=num_ranks, device=dev
+    )

     moe_config = dataclasses.replace(
         moe_config,
@@ -305,9 +311,18 @@
         out_dtype=getattr(torch, out_dtype),
     )

+    test_script_init_status = nvshmem.direct.init_status()
+    if test_script_init_status < 2 and local_rank == 0:
+        logger.warning(
+            "NVSHMEM hostlib initialization incomplete - status: %d (rank: %d, local_rank: %d)",
+            test_script_init_status,
+            global_rank,
+            local_rank,
+        )
+
     _do_test_all_to_all(pgi, dp_size, moe_config, internode, use_compile)

-    nvshmem_finalize()
+    nvshmem.finalize()


 @pytest.mark.skipif(torch.cuda.device_count() < 4, reason="Requires at least 4 GPUs")
diff --git a/tests/test_nvshmem.py b/tests/test_nvshmem.py
index 2933d77..a3408a4 100644
--- a/tests/test_nvshmem.py
+++ b/tests/test_nvshmem.py
@@ -1,17 +1,13 @@
+import logging
+
+import nvshmem.core as nvshmem  # type: ignore[import]
 import pytest
 import torch
+import torch.distributed as dist
+from cuda.core.experimental import Device  # type: ignore[import]
+from nvshmem.core import Teams  # type: ignore[import]

-from pplx_kernels.nvshmem import (
-    nvshmem_alloc_empty_unique_id,
-    nvshmem_alltoall,
-    nvshmem_barrier_all_on_current_stream,
-    nvshmem_finalize,
-    nvshmem_get_unique_id,
-    nvshmem_init,
-    nvshmem_malloc,
-    nvshmem_my_pe,
-    nvshmem_n_pes,
-)
+from pplx_kernels import nvshmem_init

 from .distributed_utils import (
     ProcessGroupInfo,
@@ -20,22 +16,63 @@
     require_multi_node,
 )

+logger = logging.getLogger(__name__)
+

 def test_nvshmem_1_gpu() -> None:
-    uid = nvshmem_get_unique_id()
-    nvshmem_init(uid, 0, 1)
-    assert nvshmem_my_pe() == 0
-    assert nvshmem_n_pes() == 1
-    nvshmem_finalize()
+    local_rank = 0
+    rank_id = 0  # Define rank_id for single GPU test
+
+    torch.cuda.set_device(local_rank)
+    dev = Device(local_rank)
+    dev.set_current()
+
+    uniqueid = nvshmem.get_unique_id()
+    nvshmem.init(device=dev, uid=uniqueid, rank=0, nranks=1, initializer_method="uid")
+
+    # Check host initialization status
+    test_script_init_status = nvshmem.direct.init_status()
+    if test_script_init_status < 2 and local_rank == 0:
+        logger.warning(
+            "NVSHMEM hostlib initialization incomplete - status: %d (rank: %d, local_rank: %d)",
+            test_script_init_status,
+            rank_id,
+            local_rank,
+        )
+
+    assert nvshmem.my_pe() == 0
+    assert nvshmem.n_pes() == 1
+
+    nvshmem.finalize()


 def _worker_test_nvshmem_4_gpu(pgi: ProcessGroupInfo) -> None:
-    uid = nvshmem_get_unique_id() if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
-    torch.distributed.broadcast(uid, src=0)
-    nvshmem_init(uid, pgi.rank, pgi.world_size)
-    assert nvshmem_my_pe() == pgi.rank
-    assert nvshmem_n_pes() == pgi.world_size
-    nvshmem_finalize()
+    local_rank = dist.get_rank()
+
+    dev = Device(local_rank)
+    dev.set_current()
+
+    nvshmem_init(
+        global_rank=pgi.rank,
+        local_rank=local_rank,
+        world_size=pgi.world_size,
+        device=dev,
+    )
+
+    # Check host initialization status
+    test_script_init_status = nvshmem.direct.init_status()
+    if test_script_init_status < 2 and local_rank == 0:
+        logger.warning(
+            "NVSHMEM hostlib initialization incomplete - status: %d (rank: %d, local_rank: %d)",
+            test_script_init_status,
+            pgi.rank,
+            local_rank,
+        )
+
+    assert nvshmem.my_pe() == pgi.rank
+    assert nvshmem.n_pes() == pgi.world_size
+
+    nvshmem.finalize()


 @pytest.mark.skipif(torch.cuda.device_count() < 4, reason="Requires at least 4 GPUs")
@@ -44,26 +81,45 @@ def test_nvshmem_4_gpu() -> None:


 def _worker_test_all_to_all(pgi: ProcessGroupInfo) -> None:
-    uid = nvshmem_get_unique_id() if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
-    torch.distributed.broadcast(uid, src=0)
-    nvshmem_init(uid, pgi.rank, pgi.world_size)
-    try:
-        t_in = nvshmem_malloc([pgi.world_size], dtype=torch.int32, device=pgi.device)
-        t_in.copy_(
-            torch.full([pgi.world_size], pgi.rank, dtype=torch.int32, device=pgi.device)
+    local_rank = dist.get_rank()
+
+    dev = Device(local_rank)
+    dev.set_current()
+
+    num_ranks = dist.get_world_size()
+    rank_id = dist.get_rank()
+
+    nvshmem_init(
+        global_rank=rank_id, local_rank=local_rank, world_size=num_ranks, device=dev
+    )
+
+    # Check NVSHMEM host initialization status
+    test_script_init_status = nvshmem.direct.init_status()
+    if test_script_init_status < 2 and local_rank == 0:
+        logger.warning(
+            "NVSHMEM hostlib initialization incomplete - status: %d (rank: %d, local_rank: %d)",
+            test_script_init_status,
+            rank_id,
+            local_rank,
         )
-        t_out = nvshmem_malloc([pgi.world_size], dtype=torch.int32, device=pgi.device)
+    # all-to-all test
+    try:
+        # Allocate a PyTorch tensor backed by NVSHMEM symmetric memory
+        t_in = nvshmem.tensor((pgi.world_size,), dtype=torch.int32).fill_(pgi.rank)
+        t_out = nvshmem.tensor((pgi.world_size,), dtype=torch.int32)
+
+        team = Teams.TEAM_WORLD
+        nvshmem.collective.alltoall(team, t_out, t_in)

-        nvshmem_alltoall(t_out, t_in)
-        nvshmem_barrier_all_on_current_stream()
+        nvshmem.collective.barrier(team)
         torch.cuda.synchronize()
         assert t_out.tolist() == list(range(pgi.world_size))
     finally:
-        del t_in
-        del t_out
-        nvshmem_finalize()
+        nvshmem.free_tensor(t_in)
+        nvshmem.free_tensor(t_out)
+        nvshmem.finalize()


 @pytest.mark.skipif(torch.cuda.device_count() < 4, reason="Requires at least 4 GPUs")
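
Usage note (not part of the patch): below is a minimal sketch of the nvshmem4py-based flow that replaces the deleted torch-op bindings. It assumes torch.distributed is already initialized (for example via torchrun) with one rank per GPU on a single node, and `run_worker` is an illustrative name rather than anything in the repository; only calls that appear in the diff above are used (`nvshmem_init`, `PyTorchStreamWrapper`, `nvshmem.tensor`, `nvshmem.collective.alltoall`, `nvshmem.collective.barrier`, `nvshmem.free_tensor`, `nvshmem.finalize`).

import nvshmem.core as nvshmem  # type: ignore[import]
import torch
import torch.distributed as dist
from cuda.core.experimental import Device  # type: ignore[import]
from nvshmem.core import Teams  # type: ignore[import]

from pplx_kernels import PyTorchStreamWrapper, nvshmem_init


def run_worker() -> None:
    # Assumes a single-node launch where the global rank doubles as the local rank.
    local_rank = dist.get_rank()
    torch.cuda.set_device(local_rank)
    dev = Device(local_rank)
    dev.set_current()

    # Rank 0 creates the unique id and broadcasts it; every rank then calls nvshmem.init.
    nvshmem_init(
        global_rank=dist.get_rank(),
        local_rank=local_rank,
        world_size=dist.get_world_size(),
        device=dev,
    )

    # Symmetric-memory tensors replace the old nvshmem_malloc binding.
    n = dist.get_world_size()
    t_in = nvshmem.tensor((n,), dtype=torch.int32).fill_(dist.get_rank())
    t_out = nvshmem.tensor((n,), dtype=torch.int32)

    # Collectives take a team and, optionally, a stream in the CUDA Python format
    # produced by PyTorchStreamWrapper.
    team = Teams.TEAM_WORLD
    wrapped = PyTorchStreamWrapper(torch.cuda.current_stream())
    nvshmem.collective.alltoall(team, t_out, t_in, stream=wrapped)
    nvshmem.collective.barrier(team, wrapped)
    torch.cuda.synchronize()

    nvshmem.free_tensor(t_in)
    nvshmem.free_tensor(t_out)
    nvshmem.finalize()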