8 changes: 8 additions & 0 deletions build_tools/pytorch.py
@@ -94,6 +94,14 @@ def setup_pytorch_extension(
libraries.append("nvshmem_host")
cxx_flags.append("-DNVTE_ENABLE_NVSHMEM")

if bool(int(os.getenv("NVTE_ENABLE_ROCSHMEM", "0"))):
cxx_flags.append("-DNVTE_ENABLE_ROCSHMEM")
mpi_home = Path(os.getenv("MPI_HOME", "/usr/lib/x86_64-linux-gnu/openmpi"))
include_dirs.append(mpi_home / "include")
library_dirs.append(mpi_home / "lib")
libraries.append("mpi_cxx")


# Construct PyTorch CUDA extension
sources = [str(path) for path in sources]
include_dirs = [str(path) for path in include_dirs]
7 changes: 7 additions & 0 deletions setup.py
@@ -73,6 +73,13 @@ def setup_common_extension() -> CMakeExtension:
cmake_flags.append("-DUSE_FUSED_ATTN_AOTRITON=OFF")
if int(os.getenv("NVTE_FUSED_ATTN_CK", "1"))==0 or int(os.getenv("NVTE_FUSED_ATTN", "1"))==0:
cmake_flags.append("-DUSE_FUSED_ATTN_CK=OFF")
if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", "0"))) and os.getenv("NVTE_ENABLE_ROCSHMEM") is None:
os.environ["NVTE_ENABLE_ROCSHMEM"] = '1'
os.environ["NVTE_ENABLE_NVSHMEM"] = '0'
print("Turning NVTE_ENABLE_ROCSHMEM on, disabling NVTE_ENABLE_NVSHMEM")
if bool(int(os.getenv("NVTE_ENABLE_ROCSHMEM", "0"))):
cmake_flags.append("-DNVTE_ENABLE_ROCSHMEM=ON")

else:
cmake_flags.append("-DUSE_ROCM=OFF")
cmake_flags = ["-DCMAKE_CUDA_ARCHITECTURES={}".format(archs)]
14 changes: 14 additions & 0 deletions transformer_engine/common/CMakeLists.txt
@@ -238,6 +238,7 @@ else()
IGNORES "*/amd_detail/*"
IGNORES "*/aotriton/*"
IGNORES "*/ck_fused_attn/*"
IGNORES "*/rocshmem_api/*"
CUSTOM_MAP_FILE "${TE}/hipify_custom_map.json"
NO_MATH_REPLACE
)
@@ -382,6 +383,19 @@ if(USE_CUDA)

# Hack to enable dynamic loading in cuDNN frontend
target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING)
else()
option(NVTE_ENABLE_ROCSHMEM "Compile with ROCSHMEM library" OFF)
if (NVTE_ENABLE_ROCSHMEM)
add_subdirectory(rocshmem_api)
target_link_options(transformer_engine PRIVATE
-fgpu-rdc
)
target_link_libraries(transformer_engine PUBLIC
-Wl,--whole-archive
rocshmemapi
-Wl,--no-whole-archive
)
endif()
endif()

# Helper functions to make header files with C++ strings
3 changes: 2 additions & 1 deletion transformer_engine/common/libtransformer_engine.version
@@ -18,7 +18,8 @@
*transformer_engine::CommOverlapP2PBase*;
*transformer_engine::CommOverlapCore*;
*nvshmem_wait_on_stream*;
*nvshmemi_init_thread*
*nvshmemi_init_thread*;
*te_rocshmem*
};
local: *;
};
54 changes: 54 additions & 0 deletions transformer_engine/common/rocshmem_api/CMakeLists.txt
@@ -0,0 +1,54 @@
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
# License for AMD contributions = MIT. See LICENSE for more information

cmake_minimum_required (VERSION 3.21)
project(rocshmem LANGUAGES HIP)

find_package(hip REQUIRED)
find_package(MPI REQUIRED)

add_library(rocshmemapi STATIC rocshmem_waitkernel.hip)

target_compile_options(rocshmemapi PRIVATE
$<$<COMPILE_LANGUAGE:HIP>:-fgpu-rdc>
)

set_target_properties(rocshmemapi PROPERTIES
CXX_STANDARD 17
HIP_STANDARD 17
POSITION_INDEPENDENT_CODE ON
)

if(DEFINED ENV{ROCSHMEM_HOME})
set(ROCSHMEM_HOME "$ENV{ROCSHMEM_HOME}" CACHE STRING "Location of ROCSHMEM installation")
else()
set(ROCSHMEM_HOME "/opt/rocm" CACHE STRING "Location of ROCSHMEM installation (default)")
endif()
set(ROCSHMEM_LIBRARY_PATH "${ROCSHMEM_HOME}/lib/librocshmem.a")
if (EXISTS ${ROCSHMEM_LIBRARY_PATH})
add_library(rocshmem STATIC IMPORTED)
set_target_properties(rocshmem PROPERTIES
IMPORTED_LOCATION "${ROCSHMEM_LIBRARY_PATH}"
IMPORTED_LINK_INTERFACE_LANGUAGES "CXX"
)
else()
message(FATAL_ERROR "ROCSHMEM library not found at ${ROCSHMEM_LIBRARY_PATH}. Please set ROCSHMEM_HOME.")
endif()

set(ROCSHMEM_INCLUDE_DIR "${ROCSHMEM_HOME}/include/rocshmem")
if(NOT EXISTS "${ROCSHMEM_INCLUDE_DIR}")
set(ROCSHMEM_INCLUDE_DIR "${ROCSHMEM_HOME}/include")
endif()

target_include_directories(rocshmemapi PUBLIC
"${ROCSHMEM_INCLUDE_DIR}"
"${CMAKE_CURRENT_SOURCE_DIR}"
"${MPI_INCLUDE_PATH}"
)

target_link_libraries(rocshmemapi PUBLIC
rocshmem
MPI::MPI_CXX
hip::host
hip::device
)
115 changes: 115 additions & 0 deletions transformer_engine/common/rocshmem_api/rocshmem_waitkernel.hip
@@ -0,0 +1,115 @@
/*************************************************************************
* Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
* License for AMD contributions = MIT. See LICENSE for more information
*************************************************************************/

#include <hip/hip_runtime.h>
#include <rocshmem.hpp>

#include "../util/logging_hip.h"
#include "rocshmem_waitkernel.hpp"

using namespace rocshmem;

__global__ void wait_until_on_stream_and_reset(uint64_t *wait_flag,
uint64_t wait_value,
uint64_t signal_reset) {
rocshmem_ulonglong_wait_until((unsigned long long*)wait_flag,
ROCSHMEM_CMP_EQ,
(unsigned long long)wait_value);
}

__global__ void rocshmem_putmem_signal_kernel(void* dst_ptr, const void* src_ptr,
size_t nelement, uint64_t* sig_addr,
uint64_t sigval, int peer) {
if (threadIdx.x == 0 && blockIdx.x == 0) {
rocshmem_putmem(dst_ptr, src_ptr, nelement, peer);
rocshmem_fence();
rocshmem_ulonglong_p((unsigned long long*)sig_addr,
(unsigned long long)sigval,
peer);
}
}

void te_rocshmem_putmem_signal(void* dst_ptr, const void* src_ptr, size_t nelement,
uint64_t* sig_addr, uint64_t sigval, int peer,
hipStream_t cur_stream) {
hipLaunchKernelGGL(rocshmem_putmem_signal_kernel,
dim3(1), dim3(1), 0, cur_stream,
dst_ptr, src_ptr, nelement, sig_addr,
sigval, peer);
}

void te_rocshmem_wait_on_stream(uint64_t* sig_addr,
WaitKind wait_kind,
hipStream_t cur_stream) {
uint64_t wait_value = 1;
uint64_t signal_reset = 0;

NVTE_CHECK(wait_kind >= WaitKind::KERNEL_WAIT &&
wait_kind <= WaitKind::STREAM_WAIT,
"Invalid wait kind");

switch (wait_kind) {
// ### wait_until_on_stream not yet implemented for rocshmem ###
// ### KernelWait is robust but slightly slower due to launch ###
case WaitKind::ROCSHMEM_WAIT:
printf("WARNING: rocshmem wait is not implemented yet, defaulting to kernel wait.\n");
// rocshmem_ulonglong_wait_until_on_stream(sig_addr,
// ROCSHMEM_CMP_EQ,
// wait_value,
Review comment (Collaborator): Let's add a warning here saying that rocshmem_wait is not supported on ROCm.

// cur_stream);
// hipStreamWriteValue64(cur_stream,
// reinterpret_cast<hipDeviceptr_t>(sig_addr),
// signal_reset, 0);
// break;
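// Intentional fall-through to KERNEL_WAIT while the rocshmem stream-wait path is unavailable.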
case WaitKind::KERNEL_WAIT:
hipLaunchKernelGGL(wait_until_on_stream_and_reset,
dim3(1), dim3(1), 0, cur_stream,
sig_addr, wait_value, signal_reset);
hipStreamWriteValue64(cur_stream,
reinterpret_cast<hipDeviceptr_t>(sig_addr),
signal_reset, 0);
break;
case WaitKind::STREAM_WAIT:
hipStreamWaitValue64(cur_stream,
reinterpret_cast<hipDeviceptr_t>(sig_addr),
wait_value, hipStreamWaitValueGte);
hipStreamWriteValue64(cur_stream,
reinterpret_cast<hipDeviceptr_t>(sig_addr),
signal_reset, 0);
break;
}
}

int te_rocshmem_init_thread(int required, int* provided) {
if (required == 0 && provided == nullptr) {
rocshmem_init();
return 0;
} else {
return rocshmem_init_thread(required, provided);
}
}

void te_rocshmem_finalize() {
rocshmem_finalize();
}

int te_rocshmem_my_pe() {
return rocshmem_my_pe();
}

int te_rocshmem_n_pes() {
return rocshmem_n_pes();
}

void* te_rocshmem_malloc(size_t size) {
return rocshmem_malloc(size);
}

void te_rocshmem_free(void* ptr) {
rocshmem_free(ptr);
}

void te_rocshmem_wait_until(uint64_t* signal_addr, uint64_t expected_value,
hipStream_t stream);
33 changes: 33 additions & 0 deletions transformer_engine/common/rocshmem_api/rocshmem_waitkernel.hpp
@@ -0,0 +1,33 @@
/*************************************************************************
* Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
* License for AMD contributions = MIT. See LICENSE for more information
*************************************************************************/

#pragma once

#include <cstdint>
#include <hip/hip_runtime.h>  // hipStream_t is used in the declarations below

enum class WaitKind : uint8_t {
KERNEL_WAIT = 0,
ROCSHMEM_WAIT = 1,
STREAM_WAIT = 2
};

void te_rocshmem_wait_on_stream(uint64_t *sig_addr, WaitKind wait_kind, hipStream_t cur_stream);

void te_rocshmem_putmem_signal(void* dst_ptr, const void* src_ptr, size_t nelement,
uint64_t* sig_addr, uint64_t sigval, int peer, hipStream_t cur_stream);

/*
These are minimal wrappers around rocshmem functions. Because PyTorch is built as a C++
extension, rocshmem is a static library, and rocshmem does not provide separate host and
device libraries, these wrappers live in common, which handles device code properly.
*/
int te_rocshmem_init_thread(int required, int* provided);
void te_rocshmem_finalize();
int te_rocshmem_my_pe();
int te_rocshmem_n_pes();
void* te_rocshmem_malloc(size_t size);
void te_rocshmem_free(void* ptr);
void te_rocshmem_wait_until(uint64_t* signal_addr, uint64_t expected_value,
hipStream_t stream);
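
For reference, a minimal host-side sketch of how these wrappers could be driven across two PEs; the buffer size, stream setup, and the choice of WaitKind::KERNEL_WAIT are illustrative assumptions, not part of this change.

#include <cstdint>
#include <hip/hip_runtime.h>
#include "rocshmem_waitkernel.hpp"

int main() {
    // Initialize rocshmem through the wrapper (required == 0 takes the plain rocshmem_init() path).
    te_rocshmem_init_thread(0, nullptr);
    int my_pe = te_rocshmem_my_pe();
    int n_pes = te_rocshmem_n_pes();

    // Symmetric payload buffer and signal word (hypothetical sizes).
    const size_t payload_bytes = 1 << 20;
    void* buf = te_rocshmem_malloc(payload_bytes);
    uint64_t* sig = static_cast<uint64_t*>(te_rocshmem_malloc(sizeof(uint64_t)));

    hipStream_t stream;
    hipStreamCreate(&stream);

    if (my_pe == 0 && n_pes > 1) {
        // Push the payload to PE 1 and set its signal word to 1 on this stream.
        te_rocshmem_putmem_signal(buf, buf, payload_bytes, sig, /*sigval=*/1, /*peer=*/1, stream);
    } else if (my_pe == 1) {
        // Block the stream until the signal word becomes 1, then reset it to 0.
        te_rocshmem_wait_on_stream(sig, WaitKind::KERNEL_WAIT, stream);
    }

    hipStreamSynchronize(stream);
    te_rocshmem_free(sig);
    te_rocshmem_free(buf);
    te_rocshmem_finalize();
    return 0;
}

The WaitKind argument selects between a busy-wait kernel (KERNEL_WAIT) and a hipStreamWaitValue64-based wait (STREAM_WAIT); ROCSHMEM_WAIT currently falls back to the kernel path.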
14 changes: 14 additions & 0 deletions transformer_engine/pytorch/csrc/extensions.h
@@ -395,6 +395,20 @@ void nvshmem_send_on_current_stream(at::Tensor src, at::Tensor dst, int peer, at
void nvshmem_wait_on_current_stream(at::Tensor signal, const std::string &wait_kind);

void nvshmem_finalize();
#else
/***************************************************************************************************
* ROCSHMEM APIs
**************************************************************************************************/

void init_rocshmem_backend(c10d::ProcessGroup *process_group);

at::Tensor create_rocshmem_tensor(const std::vector<int64_t> &shape, c10::ScalarType dtype);

void rocshmem_send_on_current_stream(at::Tensor src, at::Tensor dst, int peer, at::Tensor signal);

void rocshmem_wait_on_current_stream(at::Tensor signal, const std::string &wait_kind);

void rocshmem_finalize();
#endif

} // namespace transformer_engine::pytorch
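
The definitions behind these declarations are not shown in this diff. As a rough sketch (not the PR's actual implementation), they could be layered on the common-library wrappers roughly as follows; the stream lookup via c10::hip::getCurrentHIPStream, the include path, the signal convention, and the wait_kind string mapping are all assumptions.

// Hypothetical sketch only; the real implementations are not part of this diff.
#include <string>
#include <ATen/ATen.h>
#include <c10/hip/HIPStream.h>
#include "../../common/rocshmem_api/rocshmem_waitkernel.hpp"

namespace transformer_engine::pytorch {

void rocshmem_send_on_current_stream(at::Tensor src, at::Tensor dst, int peer, at::Tensor signal) {
    hipStream_t stream = c10::hip::getCurrentHIPStream().stream();
    // Assumed convention: dst and signal wrap symmetric (rocshmem_malloc'ed) addresses.
    te_rocshmem_putmem_signal(dst.data_ptr(), src.data_ptr(), src.nbytes(),
                              reinterpret_cast<uint64_t*>(signal.data_ptr()),
                              /*sigval=*/1, peer, stream);
}

void rocshmem_wait_on_current_stream(at::Tensor signal, const std::string& wait_kind) {
    hipStream_t stream = c10::hip::getCurrentHIPStream().stream();
    // Assumed mapping of the string argument onto WaitKind.
    WaitKind kind = wait_kind == "stream" ? WaitKind::STREAM_WAIT
                  : wait_kind == "rocshmem" ? WaitKind::ROCSHMEM_WAIT
                  : WaitKind::KERNEL_WAIT;
    te_rocshmem_wait_on_stream(reinterpret_cast<uint64_t*>(signal.data_ptr()), kind, stream);
}

}  // namespace transformer_engine::pytorch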
19 changes: 19 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -303,6 +303,25 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("nvshmem_finalize", &transformer_engine::pytorch::nvshmem_finalize,
"Clean up and finalize the NVSHMEM communication backend and free associated resources",
py::call_guard<py::gil_scoped_release>());
#else
// nvshmem/rocshmem wrappers
m.def("init_nvshmem_backend", &transformer_engine::pytorch::init_rocshmem_backend,
"Initialize ROCSHMEM backend with Pytorch distributed process groups",
py::call_guard<py::gil_scoped_release>());
m.def("create_nvshmem_tensor", &transformer_engine::pytorch::create_rocshmem_tensor,
"Create a tensor in ROCSHMEM shared memory", py::call_guard<py::gil_scoped_release>());
m.def("nvshmem_send_on_current_stream",
&transformer_engine::pytorch::rocshmem_send_on_current_stream,
"Asynchronously send tensor data to a remote PE using ROCSHMEM on the current HIP stream",
py::call_guard<py::gil_scoped_release>());
m.def("nvshmem_wait_on_current_stream",
&transformer_engine::pytorch::rocshmem_wait_on_current_stream,
"Wait for a signal value to be updated by a remote PE using ROCSHMEM on the current HIP "
"stream",
py::call_guard<py::gil_scoped_release>());
m.def("nvshmem_finalize", &transformer_engine::pytorch::rocshmem_finalize,
"Clean up and finalize the ROCSHMEM communication backend and free associated resources",
py::call_guard<py::gil_scoped_release>());
#endif

// multi-tensor functions