diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py
index 0609d1bc9..6c6ec75a8 100644
--- a/build_tools/pytorch.py
+++ b/build_tools/pytorch.py
@@ -94,6 +94,14 @@ def setup_pytorch_extension(
         libraries.append("nvshmem_host")
         cxx_flags.append("-DNVTE_ENABLE_NVSHMEM")
 
+    if bool(int(os.getenv("NVTE_ENABLE_ROCSHMEM", 0))):
+        cxx_flags.append("-DNVTE_ENABLE_ROCSHMEM")
+        mpi_home = Path(os.getenv("MPI_HOME", "/usr/lib/x86_64-linux-gnu/openmpi"))
+        include_dirs.append(mpi_home / "include")
+        library_dirs.append(mpi_home / "lib")
+        libraries.append("mpi_cxx")
+
+
     # Construct PyTorch CUDA extension
     sources = [str(path) for path in sources]
     include_dirs = [str(path) for path in include_dirs]
diff --git a/setup.py b/setup.py
index 91817d56e..c72f201f4 100644
--- a/setup.py
+++ b/setup.py
@@ -73,6 +73,13 @@ def setup_common_extension() -> CMakeExtension:
             cmake_flags.append("-DUSE_FUSED_ATTN_AOTRITON=OFF")
         if int(os.getenv("NVTE_FUSED_ATTN_CK", "1"))==0 or int(os.getenv("NVTE_FUSED_ATTN", "1"))==0:
             cmake_flags.append("-DUSE_FUSED_ATTN_CK=OFF")
+        if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", "0"))) and os.getenv("NVTE_ENABLE_ROCSHMEM") is None:
+            os.environ["NVTE_ENABLE_ROCSHMEM"] = '1'
+            os.environ["NVTE_ENABLE_NVSHMEM"] = '0'
+            print("Turning NVTE_ENABLE_ROCSHMEM on, disabling NVTE_ENABLE_NVSHMEM")
+        if bool(int(os.getenv("NVTE_ENABLE_ROCSHMEM", "0"))):
+            cmake_flags.append("-DNVTE_ENABLE_ROCSHMEM=ON")
+
     else:
         cmake_flags.append("-DUSE_ROCM=OFF")
         cmake_flags = ["-DCMAKE_CUDA_ARCHITECTURES={}".format(archs)]
diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt
index 5b0f1981d..c3fbab436 100644
--- a/transformer_engine/common/CMakeLists.txt
+++ b/transformer_engine/common/CMakeLists.txt
@@ -238,6 +238,7 @@ else()
         IGNORES "*/amd_detail/*"
         IGNORES "*/aotriton/*"
         IGNORES "*/ck_fused_attn/*"
+        IGNORES "*/rocshmem_api/*"
         CUSTOM_MAP_FILE "${TE}/hipify_custom_map.json"
         NO_MATH_REPLACE
     )
@@ -382,6 +383,19 @@
 
   # Hack to enable dynamic loading in cuDNN frontend
   target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING)
+else()
+  option(NVTE_ENABLE_ROCSHMEM "Compile with ROCSHMEM library" OFF)
+  if (NVTE_ENABLE_ROCSHMEM)
+    add_subdirectory(rocshmem_api)
+    target_link_options(transformer_engine PRIVATE
+      -fgpu-rdc
+    )
+    target_link_libraries(transformer_engine PUBLIC
+      -Wl,--whole-archive
+      rocshmemapi
+      -Wl,--no-whole-archive
+    )
+  endif()
 endif()
 
 # Helper functions to make header files with C++ strings
diff --git a/transformer_engine/common/libtransformer_engine.version b/transformer_engine/common/libtransformer_engine.version
index 4412d0c5f..d395e1f3a 100644
--- a/transformer_engine/common/libtransformer_engine.version
+++ b/transformer_engine/common/libtransformer_engine.version
@@ -18,7 +18,8 @@
       *transformer_engine::CommOverlapP2PBase*;
       *transformer_engine::CommOverlapCore*;
      *nvshmem_wait_on_stream*;
-     *nvshmemi_init_thread*
+     *nvshmemi_init_thread*;
+     *te_rocshmem*
    };
  local: *;
 };
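
Usage note for the build changes above: the rocSHMEM path is driven entirely by environment variables (NVTE_ENABLE_ROCSHMEM, ROCSHMEM_HOME, MPI_HOME; on ROCm, NVTE_ENABLE_NVSHMEM=1 is remapped to NVTE_ENABLE_ROCSHMEM=1 by setup.py). The sketch below shows one way such a build might be driven from Python; the pip invocation and the default paths are illustrative assumptions, not part of this patch.

    # Minimal build-driver sketch (assumed workflow, not part of the patch).
    # NVTE_ENABLE_ROCSHMEM=1 makes setup.py pass -DNVTE_ENABLE_ROCSHMEM=ON and makes
    # build_tools/pytorch.py add the MPI include/lib paths and link mpi_cxx.
    import os
    import subprocess

    env = dict(os.environ)
    env["NVTE_ENABLE_ROCSHMEM"] = "1"
    env.setdefault("ROCSHMEM_HOME", "/opt/rocm")  # prefix that contains lib/librocshmem.a
    env.setdefault("MPI_HOME", "/usr/lib/x86_64-linux-gnu/openmpi")  # OpenMPI install prefix
    subprocess.check_call(["pip", "install", "-v", "."], env=env)  # run from the repo root
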
diff --git a/transformer_engine/common/rocshmem_api/CMakeLists.txt b/transformer_engine/common/rocshmem_api/CMakeLists.txt
new file mode 100644
index 000000000..fc15f1529
--- /dev/null
+++ b/transformer_engine/common/rocshmem_api/CMakeLists.txt
@@ -0,0 +1,54 @@
+# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+# License for AMD contributions = MIT. See LICENSE for more information
+
+cmake_minimum_required (VERSION 3.21)
+project(rocshmem LANGUAGES HIP)
+
+find_package(hip REQUIRED)
+find_package(MPI REQUIRED)
+
+add_library(rocshmemapi STATIC rocshmem_waitkernel.hip)
+
+target_compile_options(rocshmemapi PRIVATE
+    $<$<COMPILE_LANGUAGE:HIP>:-fgpu-rdc>
+)
+
+set_target_properties(rocshmemapi PROPERTIES
+    CXX_STANDARD 17
+    HIP_STANDARD 17
+    POSITION_INDEPENDENT_CODE ON
+)
+
+if(DEFINED ENV{ROCSHMEM_HOME})
+    set(ROCSHMEM_HOME "$ENV{ROCSHMEM_HOME}" CACHE STRING "Location of ROCSHMEM installation")
+else()
+    set(ROCSHMEM_HOME "/opt/rocm" CACHE STRING "Location of ROCSHMEM installation (default)")
+endif()
+set(ROCSHMEM_LIBRARY_PATH "${ROCSHMEM_HOME}/lib/librocshmem.a")
+if (EXISTS ${ROCSHMEM_LIBRARY_PATH})
+    add_library(rocshmem STATIC IMPORTED)
+    set_target_properties(rocshmem PROPERTIES
+        IMPORTED_LOCATION "${ROCSHMEM_LIBRARY_PATH}"
+        IMPORTED_LINK_INTERFACE_LANGUAGES "CXX"
+    )
+else()
+    message(FATAL_ERROR "ROCSHMEM library not found at ${ROCSHMEM_LIBRARY_PATH}. Please set ROCSHMEM_HOME.")
+endif()
+
+set(ROCSHMEM_INCLUDE_DIR "${ROCSHMEM_HOME}/include/rocshmem")
+if(NOT EXISTS "${ROCSHMEM_INCLUDE_DIR}")
+    set(ROCSHMEM_INCLUDE_DIR "${ROCSHMEM_HOME}/include")
+endif()
+
+target_include_directories(rocshmemapi PUBLIC
+    "${ROCSHMEM_INCLUDE_DIR}"
+    "${CMAKE_CURRENT_SOURCE_DIR}"
+    "${MPI_INCLUDE_PATH}"
+)
+
+target_link_libraries(rocshmemapi PUBLIC
+    rocshmem
+    MPI::MPI_CXX
+    hip::host
+    hip::device
+)
diff --git a/transformer_engine/common/rocshmem_api/rocshmem_waitkernel.hip b/transformer_engine/common/rocshmem_api/rocshmem_waitkernel.hip
new file mode 100644
index 000000000..9f7fe0e2e
--- /dev/null
+++ b/transformer_engine/common/rocshmem_api/rocshmem_waitkernel.hip
@@ -0,0 +1,115 @@
+/*************************************************************************
+ * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+ * License for AMD contributions = MIT. See LICENSE for more information
+*************************************************************************/
+
+#include <hip/hip_runtime.h>
+#include <rocshmem.hpp>
+
+#include "../util/logging_hip.h"
+#include "rocshmem_waitkernel.hpp"
+
+using namespace rocshmem;
+
+__global__ void wait_until_on_stream_and_reset(uint64_t *wait_flag,
+                                               uint64_t wait_value,
+                                               uint64_t signal_reset) {
+    rocshmem_ulonglong_wait_until((unsigned long long*)wait_flag,
+                                  ROCSHMEM_CMP_EQ,
+                                  (unsigned long long)wait_value);
+}
+
+__global__ void rocshmem_putmem_signal_kernel(void* dst_ptr, const void* src_ptr,
+                                              size_t nelement, uint64_t* sig_addr,
+                                              uint64_t sigval, int peer) {
+    if (threadIdx.x == 0 && blockIdx.x == 0) {
+        rocshmem_putmem(dst_ptr, src_ptr, nelement, peer);
+        rocshmem_fence();
+        rocshmem_ulonglong_p((unsigned long long*)sig_addr,
+                             (unsigned long long)sigval,
+                             peer);
+    }
+}
+
+void te_rocshmem_putmem_signal(void* dst_ptr, const void* src_ptr, size_t nelement,
+                               uint64_t* sig_addr, uint64_t sigval, int peer,
+                               hipStream_t cur_stream) {
+    hipLaunchKernelGGL(rocshmem_putmem_signal_kernel,
+                       dim3(1), dim3(1), 0, cur_stream,
+                       dst_ptr, src_ptr, nelement, sig_addr,
+                       sigval, peer);
+}
+
+void te_rocshmem_wait_on_stream(uint64_t* sig_addr,
+                                WaitKind wait_kind,
+                                hipStream_t cur_stream) {
+    uint64_t wait_value = 1;
+    uint64_t signal_reset = 0;
+
+    NVTE_CHECK(wait_kind >= WaitKind::KERNEL_WAIT &&
+               wait_kind <= WaitKind::STREAM_WAIT,
+               "Invalid wait kind");
+
+    switch (wait_kind) {
+// ### wait_until_on_stream not yet implemented for rocshmem ###
+// ### KernelWait is robust but slightly slower due to launch ###
+        case WaitKind::ROCSHMEM_WAIT:
+            printf("WARNING: rocshmem wait is not implemented yet, defaulting to kernel wait.\n");
+            // rocshmem__ulonglong_wait_until_on_stream(sig_addr,
+            //                                          ROCSHMEM_CMP_EQ,
+            //                                          wait_value,
+            //                                          cur_stream);
+            // hipStreamWriteValue64(cur_stream,
+            //                       reinterpret_cast<void*>(sig_addr),
+            //                       signal_reset, 0);
+            // break;
+        case WaitKind::KERNEL_WAIT:
+            hipLaunchKernelGGL(wait_until_on_stream_and_reset,
+                               dim3(1), dim3(1), 0, cur_stream,
+                               sig_addr, wait_value, signal_reset);
+            hipStreamWriteValue64(cur_stream,
+                                  reinterpret_cast<void*>(sig_addr),
+                                  signal_reset, 0);
+            break;
+        case WaitKind::STREAM_WAIT:
+            hipStreamWaitValue64(cur_stream,
+                                 reinterpret_cast<void*>(sig_addr),
+                                 wait_value, hipStreamWaitValueGte);
+            hipStreamWriteValue64(cur_stream,
+                                  reinterpret_cast<void*>(sig_addr),
+                                  signal_reset, 0);
+            break;
+    }
+}
+
+int te_rocshmem_init_thread(int required, int* provided) {
+    if (required == 0 && provided == nullptr) {
+        rocshmem_init();
+        return 0;
+    } else {
+        return rocshmem_init_thread(required, provided);
+    }
+}
+
+void te_rocshmem_finalize() {
+    rocshmem_finalize();
+}
+
+int te_rocshmem_my_pe() {
+    return rocshmem_my_pe();
+}
+
+int te_rocshmem_n_pes() {
+    return rocshmem_n_pes();
+}
+
+void* te_rocshmem_malloc(size_t size) {
+    return rocshmem_malloc(size);
+}
+
+void te_rocshmem_free(void* ptr) {
+    rocshmem_free(ptr);
+}
+
+void te_rocshmem_wait_until(uint64_t* signal_addr, uint64_t expected_value,
+                            hipStream_t stream);
\ No newline at end of file
diff --git a/transformer_engine/common/rocshmem_api/rocshmem_waitkernel.hpp b/transformer_engine/common/rocshmem_api/rocshmem_waitkernel.hpp
new file mode 100644
index 000000000..03008e863
--- /dev/null
+++ b/transformer_engine/common/rocshmem_api/rocshmem_waitkernel.hpp
@@ -0,0 +1,33 @@
+/*************************************************************************
+ * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+ * License for AMD contributions = MIT. See LICENSE for more information
+*************************************************************************/
+
+#pragma once
+
+#include <hip/hip_runtime.h>
+
+enum class WaitKind : uint8_t {
+    KERNEL_WAIT = 0,
+    ROCSHMEM_WAIT = 1,
+    STREAM_WAIT = 2
+};
+
+void te_rocshmem_wait_on_stream(uint64_t *sig_addr, WaitKind wait_kind, hipStream_t cur_stream);
+
+void te_rocshmem_putmem_signal(void* dst_ptr, const void* src_ptr, size_t nelement,
+                               uint64_t* sig_addr, uint64_t sigval, int peer, hipStream_t cur_stream);
+
+/*
+These are minimal wrappers around rocshmem functions. As pytorch is a cpp extension,
+rocshmem is a static library, and rocshmem does not have separate host / device libraries,
+we need to move these to common, which handles device code properly.
+*/
+int te_rocshmem_init_thread(int required, int* provided);
+void te_rocshmem_finalize();
+int te_rocshmem_my_pe();
+int te_rocshmem_n_pes();
+void* te_rocshmem_malloc(size_t size);
+void te_rocshmem_free(void* ptr);
+void te_rocshmem_wait_until(uint64_t* signal_addr, uint64_t expected_value,
+                            hipStream_t stream);
\ No newline at end of file
diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h
index 2ff64ae90..3958a3a3d 100644
--- a/transformer_engine/pytorch/csrc/extensions.h
+++ b/transformer_engine/pytorch/csrc/extensions.h
@@ -395,6 +395,20 @@
 void nvshmem_send_on_current_stream(at::Tensor src, at::Tensor dst, int peer, at::Tensor signal);
 void nvshmem_wait_on_current_stream(at::Tensor signal, const std::string &wait_kind);
 void nvshmem_finalize();
+#else
+/***************************************************************************************************
+ * ROCSHMEM APIs
+ **************************************************************************************************/
+
+void init_rocshmem_backend(c10d::ProcessGroup *process_group);
+
+at::Tensor create_rocshmem_tensor(const std::vector<int64_t> &shape, c10::ScalarType dtype);
+
+void rocshmem_send_on_current_stream(at::Tensor src, at::Tensor dst, int peer, at::Tensor signal);
+
+void rocshmem_wait_on_current_stream(at::Tensor signal, const std::string &wait_kind);
+
+void rocshmem_finalize();
 #endif
 
 }  // namespace transformer_engine::pytorch
diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
index 93a42bcc3..65805759a 100644
--- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -303,6 +303,25 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("nvshmem_finalize", &transformer_engine::pytorch::nvshmem_finalize,
         "Clean up and finalize the NVSHMEM communication backend and free associated resources",
         py::call_guard<py::gil_scoped_release>());
+#else
+  // nvshmem/rocshmem wrappers
+  m.def("init_nvshmem_backend", &transformer_engine::pytorch::init_rocshmem_backend,
+        "Initialize ROCSHMEM backend with Pytorch distributed process groups",
+        py::call_guard<py::gil_scoped_release>());
+  m.def("create_nvshmem_tensor", &transformer_engine::pytorch::create_rocshmem_tensor,
+        "Create a tensor in ROCSHMEM shared memory", py::call_guard<py::gil_scoped_release>());
+  m.def("nvshmem_send_on_current_stream",
+        &transformer_engine::pytorch::rocshmem_send_on_current_stream,
+        "Asynchronously send tensor data to a remote PE using ROCSHMEM on the current HIP stream",
+        py::call_guard<py::gil_scoped_release>());
+  m.def("nvshmem_wait_on_current_stream",
+        &transformer_engine::pytorch::rocshmem_wait_on_current_stream,
+        "Wait for a signal value to be updated by a remote PE using ROCSHMEM on the current HIP "
+        "stream",
+        py::call_guard<py::gil_scoped_release>());
+  m.def("nvshmem_finalize", &transformer_engine::pytorch::rocshmem_finalize,
+        "Clean up and finalize the ROCSHMEM communication backend and free associated resources",
+        py::call_guard<py::gil_scoped_release>());
 #endif
 
   // multi-tensor functions
+ ); + } + + rocshmem_init_success = true; + }); + + NVTE_CHECK(rocshmem_init_success, "ROCm SHMEM initialization failed"); + + int pe = te_rocshmem_my_pe(); + int npes = te_rocshmem_n_pes(); + + NVTE_CHECK( + pe == my_rank, + "ROCShmem PE rank mismatch: ProcessGroup rank (", my_rank, + ") != rocshmem_my_pe() (", pe, ")"); + + NVTE_CHECK( + npes == num_ranks, + "ROCShmem total PE count mismatch: ProcessGroup size (", num_ranks, + ") != rocshmem_n_pes() (", npes, ")"); +#else + NVTE_ERROR("Internal TE error: init_rocshmem_backend cannot be initialized with valid PyTorch ", + "distributed process groups when TE is compiled without NVTE_ENABLE_ROCSHMEM!"); +#endif +} + +void rocshmem_wait_on_current_stream(torch::Tensor signal, const std::string &wait_kind) { +#ifdef NVTE_ENABLE_ROCSHMEM + uint64_t* sig_addr = reinterpret_cast(signal.data_ptr()); + hipStream_t cur_stream = (hipStream_t)at::hip::getCurrentHIPStream(); + + WaitKind kind_enum; + if (wait_kind == "kernel") + kind_enum = WaitKind::KERNEL_WAIT; + else if (wait_kind == "rocshmem" || wait_kind == "nvshmem") + kind_enum = WaitKind::ROCSHMEM_WAIT; + else if (wait_kind == "stream") + kind_enum = WaitKind::STREAM_WAIT; + else + NVTE_CHECK(false, "Invalid wait kind: ", wait_kind); + + te_rocshmem_wait_on_stream(sig_addr, kind_enum, cur_stream); +#else + NVTE_ERROR( + "Internal TE error: rocshmem_wait_on_current_stream cannot be initialized with valid PyTorch ", + "distributed process groups when TE is compiled without NVTE_ENABLE_ROCSHMEM!"); +#endif +} + +torch::Tensor create_rocshmem_tensor(const std::vector &shape, c10::ScalarType dtype) { +#ifdef NVTE_ENABLE_ROCSHMEM + auto option_gpu = + at::TensorOptions().dtype(dtype).device(at::kHIP) + .device_index(c10::hip::current_device()); + + auto size = torch::elementSize(dtype) * + std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<>()); + + void* ptr = te_rocshmem_malloc(size); + NVTE_CHECK(ptr != nullptr, "rocshmem_malloc failed for ", size, " bytes"); + + return at::from_blob( + ptr, shape, + [](void* p) { te_rocshmem_free(p); }, + option_gpu); +#else + NVTE_ERROR("Internal TE error: create_rocshmem_tensor cannot be initialized with valid PyTorch ", + "distributed process groups when TE is compiled without NVTE_ENABLE_ROCSHMEM!"); +#endif +} + +void rocshmem_send_on_current_stream(torch::Tensor src, torch::Tensor dst, int peer, + torch::Tensor signal) { +#ifdef NVTE_ENABLE_ROCSHMEM + void* src_ptr = reinterpret_cast(src.data_ptr()); + void* dst_ptr = reinterpret_cast(dst.data_ptr()); + uint64_t* sig_addr = reinterpret_cast(signal.data_ptr()); + size_t nelement = src.numel() * src.element_size(); + uint64_t sigval = 1; + + at::hip::HIPStream cur_stream = at::hip::getCurrentHIPStream(); + + te_rocshmem_putmem_signal(dst_ptr, src_ptr, nelement, sig_addr, sigval, peer, (hipStream_t)cur_stream); +#else + NVTE_ERROR( + "Internal TE error: rocshmem_send_on_current_stream cannot be initialized with valid PyTorch ", + "distributed process groups when TE is compiled without NVTE_ENABLE_ROCSHMEM!"); +#endif +} + +void rocshmem_finalize() { +#ifdef NVTE_ENABLE_ROCSHMEM + te_rocshmem_finalize(); +#else + NVTE_ERROR("Internal TE error: nvshmem_finalize cannot be initialized with valid PyTorch ", + "distributed process groups when TE is compiled without NVTE_ENABLE_ROCSHMEM!"); +#endif +} + +} // namespace transformer_engine::pytorch