From 35971429c332299f09e4ce2636fca598b712c7e0 Mon Sep 17 00:00:00 2001 From: alextmagro Date: Tue, 11 Nov 2025 13:32:03 -0600 Subject: [PATCH 1/7] Experimental rocSHMEM support --- build_tools/pytorch.py | 8 + setup.py | 7 + transformer_engine/common/CMakeLists.txt | 20 +++ .../common/libtransformer_engine.version | 3 +- .../common/rocshmem_api/CMakeLists.txt | 57 ++++++ .../rocshmem_api/rocshmem_waitkernel.hip | 114 ++++++++++++ .../rocshmem_api/rocshmem_waitkernel.hpp | 33 ++++ transformer_engine/pytorch/csrc/extensions.h | 14 ++ .../pytorch/csrc/extensions/pybind.cpp | 38 ++++ .../pytorch/csrc/extensions/rocshmem_comm.cpp | 167 ++++++++++++++++++ 10 files changed, 460 insertions(+), 1 deletion(-) create mode 100644 transformer_engine/common/rocshmem_api/CMakeLists.txt create mode 100644 transformer_engine/common/rocshmem_api/rocshmem_waitkernel.hip create mode 100644 transformer_engine/common/rocshmem_api/rocshmem_waitkernel.hpp create mode 100644 transformer_engine/pytorch/csrc/extensions/rocshmem_comm.cpp diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index 0609d1bc9..6c6ec75a8 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -94,6 +94,14 @@ def setup_pytorch_extension( libraries.append("nvshmem_host") cxx_flags.append("-DNVTE_ENABLE_NVSHMEM") + if bool(int(os.getenv("NVTE_ENABLE_ROCSHMEM", 0))): + cxx_flags.append("-DNVTE_ENABLE_ROCSHMEM") + mpi_home = Path(os.getenv("MPI_HOME", "/usr/lib/x86_64-linux-gnu/openmpi")) + include_dirs.append(mpi_home / "include") + library_dirs.append(mpi_home / "lib") + libraries.append("mpi_cxx") + + # Construct PyTorch CUDA extension sources = [str(path) for path in sources] include_dirs = [str(path) for path in include_dirs] diff --git a/setup.py b/setup.py index 91817d56e..c72f201f4 100644 --- a/setup.py +++ b/setup.py @@ -73,6 +73,13 @@ def setup_common_extension() -> CMakeExtension: cmake_flags.append("-DUSE_FUSED_ATTN_AOTRITON=OFF") if int(os.getenv("NVTE_FUSED_ATTN_CK", "1"))==0 or int(os.getenv("NVTE_FUSED_ATTN", "1"))==0: cmake_flags.append("-DUSE_FUSED_ATTN_CK=OFF") + if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", "0"))) and os.getenv("NVTE_ENABLE_ROCSHMEM") is None: + os.environ["NVTE_ENABLE_ROCSHMEM"] = '1' + os.environ["NVTE_ENABLE_NVSHMEM"] = '0' + print("Turning NVTE_ENABLE_ROCSHMEM on, disabling NVTE_ENABLE_NVSHMEM") + if bool(int(os.getenv("NVTE_ENABLE_ROCSHMEM", "0"))): + cmake_flags.append("-DNVTE_ENABLE_ROCSHMEM=ON") + else: cmake_flags.append("-DUSE_ROCM=OFF") cmake_flags = ["-DCMAKE_CUDA_ARCHITECTURES={}".format(archs)] diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 5b0f1981d..5e806446e 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -382,6 +382,26 @@ if(USE_CUDA) # Hack to enable dynamic loading in cuDNN frontend target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING) +else() + option(NVTE_ENABLE_ROCSHMEM "Compile with ROCSHMEM library" OFF) + if (NVTE_ENABLE_ROCSHMEM) + add_subdirectory(rocshmem_api) + if(DEFINED ENV{ROCSHMEM_HOME}) + set(ROCSHMEM_HOME "$ENV{ROCSHMEM_HOME}" CACHE STRING "Location of ROCSHMEM installation") + else() + set(ROCSHMEM_HOME "/opt/rocm" CACHE STRING "Location of ROCSHMEM installation (default)") + endif() + target_link_options(transformer_engine PRIVATE + -fgpu-rdc + ) + target_link_libraries(transformer_engine PUBLIC + -Wl,--whole-archive + rocshmemapi + "${ROCSHMEM_HOME}/lib/librocshmem.a" + -Wl,--no-whole-archive + ) + target_include_directories(transformer_engine PUBLIC ${ROCSHMEMAPI_INCLUDE_DIR}) + endif() endif() # Helper functions to make header files with C++ strings diff --git a/transformer_engine/common/libtransformer_engine.version b/transformer_engine/common/libtransformer_engine.version index 4412d0c5f..db84c29e6 100644 --- a/transformer_engine/common/libtransformer_engine.version +++ b/transformer_engine/common/libtransformer_engine.version @@ -18,7 +18,8 @@ *transformer_engine::CommOverlapP2PBase*; *transformer_engine::CommOverlapCore*; *nvshmem_wait_on_stream*; - *nvshmemi_init_thread* + *nvshmemi_init_thread*; + *rocshmem* }; local: *; }; diff --git a/transformer_engine/common/rocshmem_api/CMakeLists.txt b/transformer_engine/common/rocshmem_api/CMakeLists.txt new file mode 100644 index 000000000..cbcbe204f --- /dev/null +++ b/transformer_engine/common/rocshmem_api/CMakeLists.txt @@ -0,0 +1,57 @@ +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# License for AMD contributions = MIT. See LICENSE for more information +cmake_minimum_required (VERSION 3.21) +project(rocshmem LANGUAGES HIP) + +find_package(hipblaslt REQUIRED) +find_package(hiprtc REQUIRED) +find_package(hip REQUIRED) +find_package(MPI REQUIRED) + +if(NOT DEFINED ENV{NVTE_ROCM_ARCH}) + set(CMAKE_HIP_ARCHITECTURES gfx942 gfx950) +else() + set(CMAKE_HIP_ARCHITECTURES $ENV{NVTE_ROCM_ARCH}) +endif() + +if(DEFINED ENV{ROCSHMEM_HOME}) + set(ROCSHMEM_HOME "$ENV{ROCSHMEM_HOME}" CACHE STRING "Location of ROCSHMEM installation") +else() + set(ROCSHMEM_HOME "/opt/rocm" CACHE STRING "Location of ROCSHMEM installation (default)") +endif() + +set(ROCSHMEM_INCLUDE_DIR "${ROCSHMEM_HOME}/include/rocshmem") +if(NOT EXISTS "${ROCSHMEM_INCLUDE_DIR}") + set(ROCSHMEM_INCLUDE_DIR "${ROCSHMEM_HOME}/include") +endif() + +add_library(rocshmemapi OBJECT rocshmem_waitkernel.hip) + +target_compile_options(rocshmemapi PRIVATE + $<$:-fgpu-rdc> +) + +target_include_directories(rocshmemapi PUBLIC + "${ROCSHMEM_INCLUDE_DIR}" + "${CMAKE_CURRENT_SOURCE_DIR}" + "${MPI_INCLUDE_PATH}" +) + +target_link_libraries(rocshmemapi PUBLIC + "${ROCSHMEM_HOME}/lib/librocshmem.a" + MPI::MPI_CXX + hip::host + hip::device + roctx64 + hiprtc + roc::hipblaslt +) + +set_target_properties(rocshmemapi PROPERTIES + CXX_STANDARD 17 + HIP_STANDARD 17 + POSITION_INDEPENDENT_CODE ON + HIP_SEPARABLE_COMPILATION ON +) + +set(ROCSHMEMAPI_INCLUDE_DIR "${ROCSHMEM_INCLUDE_DIR}" PARENT_SCOPE) diff --git a/transformer_engine/common/rocshmem_api/rocshmem_waitkernel.hip b/transformer_engine/common/rocshmem_api/rocshmem_waitkernel.hip new file mode 100644 index 000000000..88eead1b0 --- /dev/null +++ b/transformer_engine/common/rocshmem_api/rocshmem_waitkernel.hip @@ -0,0 +1,114 @@ +/************************************************************************* + * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * License for AMD contributions = MIT. See LICENSE for more information +*************************************************************************/ + +#include +#include + +#include "../util/logging_hip.h" +#include "rocshmem_waitkernel.hpp" + +using namespace rocshmem; + +__global__ void wait_until_on_stream_and_reset(uint64_t *wait_flag, + uint64_t wait_value, + uint64_t signal_reset) { + rocshmem_ulonglong_wait_until((unsigned long long*)wait_flag, + ROCSHMEM_CMP_EQ, + (unsigned long long)wait_value); +} + +__global__ void rocshmem_putmem_signal_kernel(void* dst_ptr, const void* src_ptr, + size_t nelement, uint64_t* sig_addr, + uint64_t sigval, int peer) { + if (threadIdx.x == 0 && blockIdx.x == 0) { + rocshmem_putmem(dst_ptr, src_ptr, nelement, peer); + rocshmem_fence(); + rocshmem_ulonglong_p((unsigned long long*)sig_addr, + (unsigned long long)sigval, + peer); + } +} + +void te_rocshmem_putmem_signal(void* dst_ptr, const void* src_ptr, size_t nelement, + uint64_t* sig_addr, uint64_t sigval, int peer, + hipStream_t cur_stream) { + hipLaunchKernelGGL(rocshmem_putmem_signal_kernel, + dim3(1), dim3(1), 0, cur_stream, + dst_ptr, src_ptr, nelement, sig_addr, + sigval, peer); +} + +void te_rocshmem_wait_on_stream(uint64_t* sig_addr, + WaitKind wait_kind, + hipStream_t cur_stream) { + uint64_t wait_value = 1; + uint64_t signal_reset = 0; + + NVTE_CHECK(wait_kind >= WaitKind::KERNEL_WAIT && + wait_kind <= WaitKind::STREAM_WAIT, + "Invalid wait kind"); + + switch (wait_kind) { +// ### wait_until_on_stream not yet implemented for rocshmem ### +// ### KernelWait is robust but slightly slower due to launch ### + case WaitKind::ROCSHMEM_WAIT: + // rocshmem__ulonglong_wait_until_on_stream(sig_addr, + // ROCSHMEM_CMP_EQ, + // wait_value, + // cur_stream); + // hipStreamWriteValue64(cur_stream, + // reinterpret_cast(sig_addr), + // signal_reset, 0); + // break; + case WaitKind::KERNEL_WAIT: + hipLaunchKernelGGL(wait_until_on_stream_and_reset, + dim3(1), dim3(1), 0, cur_stream, + sig_addr, wait_value, signal_reset); + hipStreamWriteValue64(cur_stream, + reinterpret_cast(sig_addr), + signal_reset, 0); + break; + case WaitKind::STREAM_WAIT: + hipStreamWaitValue64(cur_stream, + reinterpret_cast(sig_addr), + wait_value, hipStreamWaitValueGte); + hipStreamWriteValue64(cur_stream, + reinterpret_cast(sig_addr), + signal_reset, 0); + break; + } +} + +int te_rocshmem_init_thread(int required, int* provided) { + if (required == 0 && provided == nullptr) { + rocshmem_init(); + return 0; + } else { + return rocshmem_init_thread(required, provided); + } +} + +void te_rocshmem_finalize() { + rocshmem_finalize(); +} + +int te_rocshmem_my_pe() { + return rocshmem_my_pe(); +} + +int te_rocshmem_n_pes() { + return rocshmem_n_pes(); +} + +void* te_rocshmem_malloc(size_t size) { + return rocshmem_malloc(size); +} + +void te_rocshmem_free(void* ptr) { + rocshmem_free(ptr); +} + +void te_rocshmem_wait_until(uint64_t* signal_addr, uint64_t expected_value, + hipStream_t stream); \ No newline at end of file diff --git a/transformer_engine/common/rocshmem_api/rocshmem_waitkernel.hpp b/transformer_engine/common/rocshmem_api/rocshmem_waitkernel.hpp new file mode 100644 index 000000000..03008e863 --- /dev/null +++ b/transformer_engine/common/rocshmem_api/rocshmem_waitkernel.hpp @@ -0,0 +1,33 @@ +/************************************************************************* + * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * License for AMD contributions = MIT. See LICENSE for more information +*************************************************************************/ + +#pragma once + +#include + +enum class WaitKind : uint8_t { + KERNEL_WAIT = 0, + ROCSHMEM_WAIT = 1, + STREAM_WAIT = 2 +}; + +void te_rocshmem_wait_on_stream(uint64_t *sig_addr, WaitKind wait_kind, hipStream_t cur_stream); + +void te_rocshmem_putmem_signal(void* dst_ptr, const void* src_ptr, size_t nelement, + uint64_t* sig_addr, uint64_t sigval, int peer, hipStream_t cur_stream); + +/* +These are minimal wrappers around rocshmem functions. As pytorch is a cpp extension, +rocshmem is a static library, and rocshmem does not have separate host / device libraries +we need to move these to common, which handles device code properly. +*/ +int te_rocshmem_init_thread(int required, int* provided); +void te_rocshmem_finalize(); +int te_rocshmem_my_pe(); +int te_rocshmem_n_pes(); +void* te_rocshmem_malloc(size_t size); +void te_rocshmem_free(void* ptr); +void te_rocshmem_wait_until(uint64_t* signal_addr, uint64_t expected_value, + hipStream_t stream); \ No newline at end of file diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 2ff64ae90..3958a3a3d 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -395,6 +395,20 @@ void nvshmem_send_on_current_stream(at::Tensor src, at::Tensor dst, int peer, at void nvshmem_wait_on_current_stream(at::Tensor signal, const std::string &wait_kind); void nvshmem_finalize(); +#else +/*************************************************************************************************** + * ROCSHMEM APIs + **************************************************************************************************/ + +void init_rocshmem_backend(c10d::ProcessGroup *process_group); + +at::Tensor create_rocshmem_tensor(const std::vector &shape, c10::ScalarType dtype); + +void rocshmem_send_on_current_stream(at::Tensor src, at::Tensor dst, int peer, at::Tensor signal); + +void rocshmem_wait_on_current_stream(at::Tensor signal, const std::string &wait_kind); + +void rocshmem_finalize(); #endif } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 93a42bcc3..730dc5cba 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -303,6 +303,44 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("nvshmem_finalize", &transformer_engine::pytorch::nvshmem_finalize, "Clean up and finalize the NVSHMEM communication backend and free associated resources", py::call_guard()); +#else + // rocshmem functions + m.def("init_rocshmem_backend", &transformer_engine::pytorch::init_rocshmem_backend, + "Initialize ROCSHMEM backend with Pytorch distributed process groups", + py::call_guard()); + m.def("create_rocshmem_tensor", &transformer_engine::pytorch::create_rocshmem_tensor, + "Create a tensor in ROCSHMEM shared memory", py::call_guard()); + m.def("rocshmem_send_on_current_stream", + &transformer_engine::pytorch::rocshmem_send_on_current_stream, + "Asynchronously send tensor data to a remote PE using ROCSHMEM on the current HIP stream", + py::call_guard()); + m.def("rocshmem_wait_on_current_stream", + &transformer_engine::pytorch::rocshmem_wait_on_current_stream, + "Wait for a signal value to be updated by a remote PE using ROCSHMEM on the current HIP " + "stream", + py::call_guard()); + m.def("rocshmem_finalize", &transformer_engine::pytorch::rocshmem_finalize, + "Clean up and finalize the ROCSHMEM communication backend and free associated resources", + py::call_guard()); + + // nvshmem wrappers + m.def("init_nvshmem_backend", &transformer_engine::pytorch::init_rocshmem_backend, + "Initialize ROCSHMEM backend with Pytorch distributed process groups", + py::call_guard()); + m.def("create_nvshmem_tensor", &transformer_engine::pytorch::create_rocshmem_tensor, + "Create a tensor in ROCSHMEM shared memory", py::call_guard()); + m.def("nvshmem_send_on_current_stream", + &transformer_engine::pytorch::rocshmem_send_on_current_stream, + "Asynchronously send tensor data to a remote PE using ROCSHMEM on the current HIP stream", + py::call_guard()); + m.def("nvshmem_wait_on_current_stream", + &transformer_engine::pytorch::rocshmem_wait_on_current_stream, + "Wait for a signal value to be updated by a remote PE using ROCSHMEM on the current HIP " + "stream", + py::call_guard()); + m.def("nvshmem_finalize", &transformer_engine::pytorch::rocshmem_finalize, + "Clean up and finalize the ROCSHMEM communication backend and free associated resources", + py::call_guard()); #endif // multi-tensor functions diff --git a/transformer_engine/pytorch/csrc/extensions/rocshmem_comm.cpp b/transformer_engine/pytorch/csrc/extensions/rocshmem_comm.cpp new file mode 100644 index 000000000..96a1bb4c5 --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/rocshmem_comm.cpp @@ -0,0 +1,167 @@ +/************************************************************************* + * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * License for AMD contributions = MIT. See LICENSE for more information +*************************************************************************/ + +#include "../extensions.h" + +#ifdef NVTE_ENABLE_ROCSHMEM +#include +#include +#endif + +#include +#include +#include +#include + +namespace transformer_engine::pytorch { + +void init_rocshmem_backend(c10d::ProcessGroup *process_group) { +#ifdef NVTE_ENABLE_ROCSHMEM + auto backend_type = process_group->getBackendType(); + NVTE_CHECK( + backend_type == c10d::ProcessGroup::BackendType::NCCL || + backend_type == c10d::ProcessGroup::BackendType::MPI, + "Currently only support NCCL or MPI bootstrap for ROCSHMEM. Found: ", + c10d::ProcessGroup::backendTypeToString(backend_type) + ); + + int my_rank = process_group->getRank(); + int num_ranks = process_group->getSize(); + + static std::once_flag rocshmem_init_flag; + static bool rocshmem_init_success = false; + + std::call_once(rocshmem_init_flag, [backend_type, my_rank, num_ranks]() { + if (backend_type == c10d::ProcessGroup::BackendType::MPI) { + int mpi_is_initialized = 0; + MPI_Initialized(&mpi_is_initialized); + if (!mpi_is_initialized) { + int provided; + MPI_Init_thread(nullptr, nullptr, MPI_THREAD_MULTIPLE, &provided); + NVTE_CHECK( + provided >= MPI_THREAD_SINGLE, + "MPI initialization failed to provide required thread level." + ); + } + + int provided; + int ret = te_rocshmem_init_thread(MPI_THREAD_MULTIPLE, &provided); + NVTE_CHECK( + ret == 0, + "rocshmem_init_thread() failed with return code: ", ret + ); + NVTE_CHECK( + provided >= MPI_THREAD_MULTIPLE, + "ROCm SHMEM initialization failed to provide MPI_THREAD_MULTIPLE support. Got: ", provided + ); + } else { + setenv("ROCSHMEM_RANK", std::to_string(my_rank).c_str(), 0); + setenv("ROCSHMEM_SIZE", std::to_string(num_ranks).c_str(), 0); + + int ret = te_rocshmem_init_thread(0, nullptr); + NVTE_CHECK( + ret == 0, + "rocshmem_init_thread() failed with return code: ", ret, + ". Make sure PMI or compatible launcher environment is available." + ); + } + + rocshmem_init_success = true; + }); + + NVTE_CHECK(rocshmem_init_success, "ROCm SHMEM initialization failed"); + + int pe = te_rocshmem_my_pe(); + int npes = te_rocshmem_n_pes(); + + NVTE_CHECK( + pe == my_rank, + "ROCShmem PE rank mismatch: ProcessGroup rank (", my_rank, + ") != rocshmem_my_pe() (", pe, ")"); + + NVTE_CHECK( + npes == num_ranks, + "ROCShmem total PE count mismatch: ProcessGroup size (", num_ranks, + ") != rocshmem_n_pes() (", npes, ")"); +#else + NVTE_ERROR("Internal TE error: init_rocshmem_backend cannot be initialized with valid PyTorch ", + "distributed process groups when TE is compiled without NVTE_ENABLE_ROCSHMEM!"); +#endif +} + +void rocshmem_wait_on_current_stream(torch::Tensor signal, const std::string &wait_kind) { +#ifdef NVTE_ENABLE_ROCSHMEM + uint64_t* sig_addr = reinterpret_cast(signal.data_ptr()); + hipStream_t cur_stream = (hipStream_t)at::hip::getCurrentHIPStream(); + + WaitKind kind_enum; + if (wait_kind == "kernel") + kind_enum = WaitKind::KERNEL_WAIT; + else if (wait_kind == "rocshmem" || wait_kind == "nvshmem") + kind_enum = WaitKind::ROCSHMEM_WAIT; + else if (wait_kind == "stream") + kind_enum = WaitKind::STREAM_WAIT; + else + NVTE_CHECK(false, "Invalid wait kind: ", wait_kind); + + te_rocshmem_wait_on_stream(sig_addr, kind_enum, cur_stream); +#else + NVTE_ERROR( + "Internal TE error: rocshmem_wait_on_current_stream cannot be initialized with valid PyTorch ", + "distributed process groups when TE is compiled without NVTE_ENABLE_ROCSHMEM!"); +#endif +} + +torch::Tensor create_rocshmem_tensor(const std::vector &shape, c10::ScalarType dtype) { +#ifdef NVTE_ENABLE_ROCSHMEM + auto option_gpu = + at::TensorOptions().dtype(dtype).device(at::kHIP) + .device_index(c10::hip::current_device()); + + auto size = torch::elementSize(dtype) * + std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<>()); + + void* ptr = te_rocshmem_malloc(size); + NVTE_CHECK(ptr != nullptr, "rocshmem_malloc failed for ", size, " bytes"); + + return at::from_blob( + ptr, shape, + [](void* p) { te_rocshmem_free(p); }, + option_gpu); +#else + NVTE_ERROR("Internal TE error: create_rocshmem_tensor cannot be initialized with valid PyTorch ", + "distributed process groups when TE is compiled without NVTE_ENABLE_ROCSHMEM!"); +#endif +} + +void rocshmem_send_on_current_stream(torch::Tensor src, torch::Tensor dst, int peer, + torch::Tensor signal) { +#ifdef NVTE_ENABLE_ROCSHMEM + void* src_ptr = reinterpret_cast(src.data_ptr()); + void* dst_ptr = reinterpret_cast(dst.data_ptr()); + uint64_t* sig_addr = reinterpret_cast(signal.data_ptr()); + size_t nelement = src.numel() * src.element_size(); + uint64_t sigval = 1; + + at::hip::HIPStream cur_stream = at::hip::getCurrentHIPStream(); + + te_rocshmem_putmem_signal(dst_ptr, src_ptr, nelement, sig_addr, sigval, peer, (hipStream_t)cur_stream); +#else + NVTE_ERROR( + "Internal TE error: rocshmem_send_on_current_stream cannot be initialized with valid PyTorch ", + "distributed process groups when TE is compiled without NVTE_ENABLE_ROCSHMEM!"); +#endif +} + +void rocshmem_finalize() { +#ifdef NVTE_ENABLE_ROCSHMEM + te_rocshmem_finalize(); +#else + NVTE_ERROR("Internal TE error: nvshmem_finalize cannot be initialized with valid PyTorch ", + "distributed process groups when TE is compiled without NVTE_ENABLE_ROCSHMEM!"); +#endif +} + +} // namespace transformer_engine::pytorch From 1c54796b98e6bfd8abee87bb35cf86567d2ea4e7 Mon Sep 17 00:00:00 2001 From: alextmagro Date: Tue, 11 Nov 2025 13:32:35 -0600 Subject: [PATCH 2/7] add rocshmem wait warning --- .../rocshmem_api/rocshmem_waitkernel.hip | 2 ++ .../pytorch/csrc/extensions/pybind.cpp | 21 +------------------ 2 files changed, 3 insertions(+), 20 deletions(-) diff --git a/transformer_engine/common/rocshmem_api/rocshmem_waitkernel.hip b/transformer_engine/common/rocshmem_api/rocshmem_waitkernel.hip index 88eead1b0..34d3f0e00 100644 --- a/transformer_engine/common/rocshmem_api/rocshmem_waitkernel.hip +++ b/transformer_engine/common/rocshmem_api/rocshmem_waitkernel.hip @@ -54,6 +54,8 @@ void te_rocshmem_wait_on_stream(uint64_t* sig_addr, // ### wait_until_on_stream not yet implemented for rocshmem ### // ### KernelWait is robust but slightly slower due to launch ### case WaitKind::ROCSHMEM_WAIT: + printf("WARNING: rocshmem wait is not implemented yet, defaulting to + kernel wait.\n"); // rocshmem__ulonglong_wait_until_on_stream(sig_addr, // ROCSHMEM_CMP_EQ, // wait_value, diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 730dc5cba..65805759a 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -304,26 +304,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Clean up and finalize the NVSHMEM communication backend and free associated resources", py::call_guard()); #else - // rocshmem functions - m.def("init_rocshmem_backend", &transformer_engine::pytorch::init_rocshmem_backend, - "Initialize ROCSHMEM backend with Pytorch distributed process groups", - py::call_guard()); - m.def("create_rocshmem_tensor", &transformer_engine::pytorch::create_rocshmem_tensor, - "Create a tensor in ROCSHMEM shared memory", py::call_guard()); - m.def("rocshmem_send_on_current_stream", - &transformer_engine::pytorch::rocshmem_send_on_current_stream, - "Asynchronously send tensor data to a remote PE using ROCSHMEM on the current HIP stream", - py::call_guard()); - m.def("rocshmem_wait_on_current_stream", - &transformer_engine::pytorch::rocshmem_wait_on_current_stream, - "Wait for a signal value to be updated by a remote PE using ROCSHMEM on the current HIP " - "stream", - py::call_guard()); - m.def("rocshmem_finalize", &transformer_engine::pytorch::rocshmem_finalize, - "Clean up and finalize the ROCSHMEM communication backend and free associated resources", - py::call_guard()); - - // nvshmem wrappers + // nvshmem/rocshmem wrappers m.def("init_nvshmem_backend", &transformer_engine::pytorch::init_rocshmem_backend, "Initialize ROCSHMEM backend with Pytorch distributed process groups", py::call_guard()); From ff54a6a9697bfbceaf27b7af175451041f78d40e Mon Sep 17 00:00:00 2001 From: alextmagro Date: Mon, 17 Nov 2025 10:30:32 -0600 Subject: [PATCH 3/7] CMake Cleanup --- transformer_engine/common/CMakeLists.txt | 1 - .../common/libtransformer_engine.version | 2 +- .../common/rocshmem_api/CMakeLists.txt | 25 +++++-------------- .../rocshmem_api/rocshmem_waitkernel.hip | 3 +-- 4 files changed, 8 insertions(+), 23 deletions(-) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 5e806446e..afb5e2677 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -400,7 +400,6 @@ else() "${ROCSHMEM_HOME}/lib/librocshmem.a" -Wl,--no-whole-archive ) - target_include_directories(transformer_engine PUBLIC ${ROCSHMEMAPI_INCLUDE_DIR}) endif() endif() diff --git a/transformer_engine/common/libtransformer_engine.version b/transformer_engine/common/libtransformer_engine.version index db84c29e6..d395e1f3a 100644 --- a/transformer_engine/common/libtransformer_engine.version +++ b/transformer_engine/common/libtransformer_engine.version @@ -19,7 +19,7 @@ *transformer_engine::CommOverlapCore*; *nvshmem_wait_on_stream*; *nvshmemi_init_thread*; - *rocshmem* + *te_rocshmem* }; local: *; }; diff --git a/transformer_engine/common/rocshmem_api/CMakeLists.txt b/transformer_engine/common/rocshmem_api/CMakeLists.txt index cbcbe204f..df911675e 100644 --- a/transformer_engine/common/rocshmem_api/CMakeLists.txt +++ b/transformer_engine/common/rocshmem_api/CMakeLists.txt @@ -3,16 +3,9 @@ cmake_minimum_required (VERSION 3.21) project(rocshmem LANGUAGES HIP) -find_package(hipblaslt REQUIRED) -find_package(hiprtc REQUIRED) find_package(hip REQUIRED) find_package(MPI REQUIRED) -if(NOT DEFINED ENV{NVTE_ROCM_ARCH}) - set(CMAKE_HIP_ARCHITECTURES gfx942 gfx950) -else() - set(CMAKE_HIP_ARCHITECTURES $ENV{NVTE_ROCM_ARCH}) -endif() if(DEFINED ENV{ROCSHMEM_HOME}) set(ROCSHMEM_HOME "$ENV{ROCSHMEM_HOME}" CACHE STRING "Location of ROCSHMEM installation") @@ -31,6 +24,12 @@ target_compile_options(rocshmemapi PRIVATE $<$:-fgpu-rdc> ) +set_target_properties(rocshmemapi PROPERTIES + CXX_STANDARD 17 + HIP_STANDARD 17 + POSITION_INDEPENDENT_CODE ON +) + target_include_directories(rocshmemapi PUBLIC "${ROCSHMEM_INCLUDE_DIR}" "${CMAKE_CURRENT_SOURCE_DIR}" @@ -42,16 +41,4 @@ target_link_libraries(rocshmemapi PUBLIC MPI::MPI_CXX hip::host hip::device - roctx64 - hiprtc - roc::hipblaslt ) - -set_target_properties(rocshmemapi PROPERTIES - CXX_STANDARD 17 - HIP_STANDARD 17 - POSITION_INDEPENDENT_CODE ON - HIP_SEPARABLE_COMPILATION ON -) - -set(ROCSHMEMAPI_INCLUDE_DIR "${ROCSHMEM_INCLUDE_DIR}" PARENT_SCOPE) diff --git a/transformer_engine/common/rocshmem_api/rocshmem_waitkernel.hip b/transformer_engine/common/rocshmem_api/rocshmem_waitkernel.hip index 34d3f0e00..9f7fe0e2e 100644 --- a/transformer_engine/common/rocshmem_api/rocshmem_waitkernel.hip +++ b/transformer_engine/common/rocshmem_api/rocshmem_waitkernel.hip @@ -54,8 +54,7 @@ void te_rocshmem_wait_on_stream(uint64_t* sig_addr, // ### wait_until_on_stream not yet implemented for rocshmem ### // ### KernelWait is robust but slightly slower due to launch ### case WaitKind::ROCSHMEM_WAIT: - printf("WARNING: rocshmem wait is not implemented yet, defaulting to - kernel wait.\n"); + printf("WARNING: rocshmem wait is not implemented yet, defaulting to kernel wait.\n"); // rocshmem__ulonglong_wait_until_on_stream(sig_addr, // ROCSHMEM_CMP_EQ, // wait_value, From 0a4c252923410ea804d5a847fc83b4c1e9462456 Mon Sep 17 00:00:00 2001 From: alextmagro Date: Mon, 17 Nov 2025 13:24:48 -0600 Subject: [PATCH 4/7] include/hipify refactor --- transformer_engine/common/CMakeLists.txt | 4 ++++ .../pytorch/csrc/extensions/rocshmem_comm.cpp | 6 ++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index afb5e2677..7fbff069f 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -238,6 +238,7 @@ else() IGNORES "*/amd_detail/*" IGNORES "*/aotriton/*" IGNORES "*/ck_fused_attn/*" + IGNORES "*/rocshmem_api/*" CUSTOM_MAP_FILE "${TE}/hipify_custom_map.json" NO_MATH_REPLACE ) @@ -385,6 +386,9 @@ target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAM else() option(NVTE_ENABLE_ROCSHMEM "Compile with ROCSHMEM library" OFF) if (NVTE_ENABLE_ROCSHMEM) + find_package(MPI REQUIRED) + target_link_libraries(transformer_engine PUBLIC MPI::MPI_CXX) + target_include_directories(transformer_engine PRIVATE ${MPI_CXX_INCLUDES}) add_subdirectory(rocshmem_api) if(DEFINED ENV{ROCSHMEM_HOME}) set(ROCSHMEM_HOME "$ENV{ROCSHMEM_HOME}" CACHE STRING "Location of ROCSHMEM installation") diff --git a/transformer_engine/pytorch/csrc/extensions/rocshmem_comm.cpp b/transformer_engine/pytorch/csrc/extensions/rocshmem_comm.cpp index 96a1bb4c5..b8db7a056 100644 --- a/transformer_engine/pytorch/csrc/extensions/rocshmem_comm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/rocshmem_comm.cpp @@ -8,12 +8,10 @@ #ifdef NVTE_ENABLE_ROCSHMEM #include #include -#endif - -#include -#include +#include #include #include +#endif namespace transformer_engine::pytorch { From b95cad1fca3d50992c385e740b7dd0cdf55cd81a Mon Sep 17 00:00:00 2001 From: alextmagro Date: Mon, 17 Nov 2025 13:31:17 -0600 Subject: [PATCH 5/7] make IMPORTED rocshmem library --- transformer_engine/common/CMakeLists.txt | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 7fbff069f..573cef17a 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -395,13 +395,23 @@ else() else() set(ROCSHMEM_HOME "/opt/rocm" CACHE STRING "Location of ROCSHMEM installation (default)") endif() + set(ROCSHMEM_LIBRARY_PATH "${ROCSHMEM_HOME}/lib/librocshmem.a") + if (EXISTS ${ROCSHMEM_LIBRARY_PATH}) + add_library(rocshmem STATIC IMPORTED) + set_target_properties(rocshmem PROPERTIES + IMPORTED_LOCATION "${ROCSHMEM_LIBRARY_PATH}" + IMPORTED_LINK_INTERFACE_LANGUAGES "CXX" + ) + else() + message(FATAL_ERROR "ROCSHMEM library not found at ${ROCSHMEM_LIBRARY_PATH}. PLease set ROCSHMEM_HOME.") + endif() target_link_options(transformer_engine PRIVATE -fgpu-rdc ) target_link_libraries(transformer_engine PUBLIC -Wl,--whole-archive rocshmemapi - "${ROCSHMEM_HOME}/lib/librocshmem.a" + rocshmem -Wl,--no-whole-archive ) endif() From 22f7bf5e6193ecc42d7b6207f5e3c7d140404085 Mon Sep 17 00:00:00 2001 From: alextmagro Date: Tue, 18 Nov 2025 16:44:22 -0600 Subject: [PATCH 6/7] Move rocshmem_home logic to subproject --- transformer_engine/common/CMakeLists.txt | 19 ---------- .../common/rocshmem_api/CMakeLists.txt | 36 ++++++++++++------- 2 files changed, 23 insertions(+), 32 deletions(-) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 573cef17a..c3fbab436 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -386,32 +386,13 @@ target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAM else() option(NVTE_ENABLE_ROCSHMEM "Compile with ROCSHMEM library" OFF) if (NVTE_ENABLE_ROCSHMEM) - find_package(MPI REQUIRED) - target_link_libraries(transformer_engine PUBLIC MPI::MPI_CXX) - target_include_directories(transformer_engine PRIVATE ${MPI_CXX_INCLUDES}) add_subdirectory(rocshmem_api) - if(DEFINED ENV{ROCSHMEM_HOME}) - set(ROCSHMEM_HOME "$ENV{ROCSHMEM_HOME}" CACHE STRING "Location of ROCSHMEM installation") - else() - set(ROCSHMEM_HOME "/opt/rocm" CACHE STRING "Location of ROCSHMEM installation (default)") - endif() - set(ROCSHMEM_LIBRARY_PATH "${ROCSHMEM_HOME}/lib/librocshmem.a") - if (EXISTS ${ROCSHMEM_LIBRARY_PATH}) - add_library(rocshmem STATIC IMPORTED) - set_target_properties(rocshmem PROPERTIES - IMPORTED_LOCATION "${ROCSHMEM_LIBRARY_PATH}" - IMPORTED_LINK_INTERFACE_LANGUAGES "CXX" - ) - else() - message(FATAL_ERROR "ROCSHMEM library not found at ${ROCSHMEM_LIBRARY_PATH}. PLease set ROCSHMEM_HOME.") - endif() target_link_options(transformer_engine PRIVATE -fgpu-rdc ) target_link_libraries(transformer_engine PUBLIC -Wl,--whole-archive rocshmemapi - rocshmem -Wl,--no-whole-archive ) endif() diff --git a/transformer_engine/common/rocshmem_api/CMakeLists.txt b/transformer_engine/common/rocshmem_api/CMakeLists.txt index df911675e..4ab049006 100644 --- a/transformer_engine/common/rocshmem_api/CMakeLists.txt +++ b/transformer_engine/common/rocshmem_api/CMakeLists.txt @@ -1,23 +1,12 @@ # Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. # License for AMD contributions = MIT. See LICENSE for more information + cmake_minimum_required (VERSION 3.21) project(rocshmem LANGUAGES HIP) find_package(hip REQUIRED) find_package(MPI REQUIRED) - -if(DEFINED ENV{ROCSHMEM_HOME}) - set(ROCSHMEM_HOME "$ENV{ROCSHMEM_HOME}" CACHE STRING "Location of ROCSHMEM installation") -else() - set(ROCSHMEM_HOME "/opt/rocm" CACHE STRING "Location of ROCSHMEM installation (default)") -endif() - -set(ROCSHMEM_INCLUDE_DIR "${ROCSHMEM_HOME}/include/rocshmem") -if(NOT EXISTS "${ROCSHMEM_INCLUDE_DIR}") - set(ROCSHMEM_INCLUDE_DIR "${ROCSHMEM_HOME}/include") -endif() - add_library(rocshmemapi OBJECT rocshmem_waitkernel.hip) target_compile_options(rocshmemapi PRIVATE @@ -30,6 +19,27 @@ set_target_properties(rocshmemapi PROPERTIES POSITION_INDEPENDENT_CODE ON ) +if(DEFINED ENV{ROCSHMEM_HOME}) + set(ROCSHMEM_HOME "$ENV{ROCSHMEM_HOME}" CACHE STRING "Location of ROCSHMEM installation") +else() + set(ROCSHMEM_HOME "/opt/rocm" CACHE STRING "Location of ROCSHMEM installation (default)") +endif() +set(ROCSHMEM_LIBRARY_PATH "${ROCSHMEM_HOME}/lib/librocshmem.a") +if (EXISTS ${ROCSHMEM_LIBRARY_PATH}) + add_library(rocshmem STATIC IMPORTED) + set_target_properties(rocshmem PROPERTIES + IMPORTED_LOCATION "${ROCSHMEM_LIBRARY_PATH}" + IMPORTED_LINK_INTERFACE_LANGUAGES "CXX" + ) +else() + message(FATAL_ERROR "ROCSHMEM library not found at ${ROCSHMEM_LIBRARY_PATH}. PLease set ROCSHMEM_HOME.") +endif() + +set(ROCSHMEM_INCLUDE_DIR "${ROCSHMEM_HOME}/include/rocshmem") +if(NOT EXISTS "${ROCSHMEM_INCLUDE_DIR}") + set(ROCSHMEM_INCLUDE_DIR "${ROCSHMEM_HOME}/include") +endif() + target_include_directories(rocshmemapi PUBLIC "${ROCSHMEM_INCLUDE_DIR}" "${CMAKE_CURRENT_SOURCE_DIR}" @@ -37,7 +47,7 @@ target_include_directories(rocshmemapi PUBLIC ) target_link_libraries(rocshmemapi PUBLIC - "${ROCSHMEM_HOME}/lib/librocshmem.a" + rocshmem MPI::MPI_CXX hip::host hip::device From e82a4b154925f7ddb8a95c8a6f55b0dc0e80fb3b Mon Sep 17 00:00:00 2001 From: alextmagro Date: Wed, 19 Nov 2025 13:49:35 -0600 Subject: [PATCH 7/7] change rocshmem_api from object to static --- transformer_engine/common/rocshmem_api/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/rocshmem_api/CMakeLists.txt b/transformer_engine/common/rocshmem_api/CMakeLists.txt index 4ab049006..fc15f1529 100644 --- a/transformer_engine/common/rocshmem_api/CMakeLists.txt +++ b/transformer_engine/common/rocshmem_api/CMakeLists.txt @@ -7,7 +7,7 @@ project(rocshmem LANGUAGES HIP) find_package(hip REQUIRED) find_package(MPI REQUIRED) -add_library(rocshmemapi OBJECT rocshmem_waitkernel.hip) +add_library(rocshmemapi STATIC rocshmem_waitkernel.hip) target_compile_options(rocshmemapi PRIVATE $<$:-fgpu-rdc>