Skip to content

Commit e7d2a61

Browse files
committed
Experimental rocSHMEM support
1 parent 0aef9a9 commit e7d2a61

File tree

10 files changed

+460
-1
lines changed

10 files changed

+460
-1
lines changed

build_tools/pytorch.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,14 @@ def setup_pytorch_extension(
9494
libraries.append("nvshmem_host")
9595
cxx_flags.append("-DNVTE_ENABLE_NVSHMEM")
9696

97+
if bool(int(os.getenv("NVTE_ENABLE_ROCSHMEM", 0))):
98+
cxx_flags.append("-DNVTE_ENABLE_ROCSHMEM")
99+
mpi_home = Path(os.getenv("MPI_HOME", "/usr/lib/x86_64-linux-gnu/openmpi"))
100+
include_dirs.append(mpi_home / "include")
101+
library_dirs.append(mpi_home / "lib")
102+
libraries.append("mpi_cxx")
103+
104+
97105
# Construct PyTorch CUDA extension
98106
sources = [str(path) for path in sources]
99107
include_dirs = [str(path) for path in include_dirs]

setup.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,13 @@ def setup_common_extension() -> CMakeExtension:
7373
cmake_flags.append("-DUSE_FUSED_ATTN_AOTRITON=OFF")
7474
if int(os.getenv("NVTE_FUSED_ATTN_CK", "1"))==0 or int(os.getenv("NVTE_FUSED_ATTN", "1"))==0:
7575
cmake_flags.append("-DUSE_FUSED_ATTN_CK=OFF")
76+
if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", "0"))) and os.getenv("NVTE_ENABLE_ROCSHMEM") is None:
77+
os.environ["NVTE_ENABLE_ROCSHMEM"] = '1'
78+
os.environ["NVTE_ENABLE_NVSHMEM"] = '0'
79+
print("Turning NVTE_ENABLE_ROCSHMEM on, disabling NVTE_ENABLE_NVSHMEM")
80+
if bool(int(os.getenv("NVTE_ENABLE_ROCSHMEM", "0"))):
81+
cmake_flags.append("-DNVTE_ENABLE_ROCSHMEM=ON")
82+
7683
else:
7784
cmake_flags.append("-DUSE_ROCM=OFF")
7885
cmake_flags = ["-DCMAKE_CUDA_ARCHITECTURES={}".format(archs)]

transformer_engine/common/CMakeLists.txt

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,26 @@ if(USE_CUDA)
381381

382382
# Hack to enable dynamic loading in cuDNN frontend
383383
target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING)
384+
else()
385+
option(NVTE_ENABLE_ROCSHMEM "Compile with ROCSHMEM library" OFF)
386+
if (NVTE_ENABLE_ROCSHMEM)
387+
add_subdirectory(rocshmem_api)
388+
if(DEFINED ENV{ROCSHMEM_HOME})
389+
set(ROCSHMEM_HOME "$ENV{ROCSHMEM_HOME}" CACHE STRING "Location of ROCSHMEM installation")
390+
else()
391+
set(ROCSHMEM_HOME "/opt/rocm" CACHE STRING "Location of ROCSHMEM installation (default)")
392+
endif()
393+
target_link_options(transformer_engine PRIVATE
394+
-fgpu-rdc
395+
)
396+
target_link_libraries(transformer_engine PUBLIC
397+
-Wl,--whole-archive
398+
rocshmemapi
399+
"${ROCSHMEM_HOME}/lib/librocshmem.a"
400+
-Wl,--no-whole-archive
401+
)
402+
target_include_directories(transformer_engine PUBLIC ${ROCSHMEMAPI_INCLUDE_DIR})
403+
endif()
384404
endif()
385405

386406
# Helper functions to make header files with C++ strings

transformer_engine/common/libtransformer_engine.version

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
*transformer_engine::CommOverlapP2PBase*;
1919
*transformer_engine::CommOverlapCore*;
2020
*nvshmem_wait_on_stream*;
21-
*nvshmemi_init_thread*
21+
*nvshmemi_init_thread*;
22+
*rocshmem*
2223
};
2324
local: *;
2425
};
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
# License for AMD contributions = MIT. See LICENSE for more information

# Builds `rocshmemapi`: a small HIP object library wrapping the rocSHMEM
# wait/putmem-signal kernels (rocshmem_waitkernel.hip) for linking into
# transformer_engine (see common/CMakeLists.txt, which consumes both this
# target and ROCSHMEMAPI_INCLUDE_DIR).
cmake_minimum_required (VERSION 3.21)
project(rocshmem LANGUAGES HIP)

# NOTE(review): hipblaslt and hiprtc are required and linked below; confirm
# rocshmem_waitkernel.hip actually needs them, or whether they are pulled in
# only as transitive rocSHMEM dependencies.
find_package(hipblaslt REQUIRED)
find_package(hiprtc REQUIRED)
find_package(hip REQUIRED)
find_package(MPI REQUIRED)

# GPU targets: overridable through the NVTE_ROCM_ARCH environment variable,
# otherwise default to gfx942 + gfx950.
if(NOT DEFINED ENV{NVTE_ROCM_ARCH})
    set(CMAKE_HIP_ARCHITECTURES gfx942 gfx950)
else()
    set(CMAKE_HIP_ARCHITECTURES $ENV{NVTE_ROCM_ARCH})
endif()

# rocSHMEM install prefix: ROCSHMEM_HOME env var when set, else /opt/rocm.
if(DEFINED ENV{ROCSHMEM_HOME})
    set(ROCSHMEM_HOME "$ENV{ROCSHMEM_HOME}" CACHE STRING "Location of ROCSHMEM installation")
else()
    set(ROCSHMEM_HOME "/opt/rocm" CACHE STRING "Location of ROCSHMEM installation (default)")
endif()

# Prefer the include/rocshmem subdirectory when it exists; fall back to the
# flat include/ layout otherwise.
set(ROCSHMEM_INCLUDE_DIR "${ROCSHMEM_HOME}/include/rocshmem")
if(NOT EXISTS "${ROCSHMEM_INCLUDE_DIR}")
    set(ROCSHMEM_INCLUDE_DIR "${ROCSHMEM_HOME}/include")
endif()

add_library(rocshmemapi OBJECT rocshmem_waitkernel.hip)

# -fgpu-rdc (relocatable device code): rocSHMEM device symbols are resolved
# at device-link time; the final link in common/CMakeLists.txt passes the
# matching -fgpu-rdc link option.
target_compile_options(rocshmemapi PRIVATE
    $<$<COMPILE_LANGUAGE:HIP>:-fgpu-rdc>
)

target_include_directories(rocshmemapi PUBLIC
    "${ROCSHMEM_INCLUDE_DIR}"
    "${CMAKE_CURRENT_SOURCE_DIR}"
    "${MPI_INCLUDE_PATH}"
)

# librocshmem.a is linked by absolute path (no CMake package is provided by
# the rocSHMEM install).
target_link_libraries(rocshmemapi PUBLIC
    "${ROCSHMEM_HOME}/lib/librocshmem.a"
    MPI::MPI_CXX
    hip::host
    hip::device
    roctx64
    hiprtc
    roc::hipblaslt
)

set_target_properties(rocshmemapi PROPERTIES
    CXX_STANDARD 17
    HIP_STANDARD 17
    POSITION_INDEPENDENT_CODE ON
    HIP_SEPARABLE_COMPILATION ON
)

# Export the resolved rocSHMEM include dir to the parent scope
# (consumed as ROCSHMEMAPI_INCLUDE_DIR by common/CMakeLists.txt).
set(ROCSHMEMAPI_INCLUDE_DIR "${ROCSHMEM_INCLUDE_DIR}" PARENT_SCOPE)
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
/*************************************************************************
2+
* Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3+
* License for AMD contributions = MIT. See LICENSE for more information
4+
*************************************************************************/
5+
6+
#include <hip/hip_runtime.h>
7+
#include <rocshmem.hpp>
8+
9+
#include "../util/logging_hip.h"
10+
#include "rocshmem_waitkernel.hpp"
11+
12+
using namespace rocshmem;
13+
14+
/*
 * Device-side wait: blocks a single-thread kernel launch until *wait_flag
 * equals wait_value (rocSHMEM CMP_EQ semantics).
 *
 * Despite the "_and_reset" name, the flag reset is NOT performed here: the
 * host resets the flag with hipStreamWriteValue64 right after this kernel
 * in te_rocshmem_wait_on_stream. `signal_reset` is therefore intentionally
 * unused; it is kept (and marked [[maybe_unused]]) so the launch site's
 * argument list stays unchanged.
 */
__global__ void wait_until_on_stream_and_reset(uint64_t *wait_flag,
                                               uint64_t wait_value,
                                               [[maybe_unused]] uint64_t signal_reset) {
    rocshmem_ulonglong_wait_until((unsigned long long*)wait_flag,
                                  ROCSHMEM_CMP_EQ,
                                  (unsigned long long)wait_value);
}
21+
22+
/*
 * Single-thread device kernel: put `nelement` bytes from src_ptr into
 * dst_ptr on PE `peer`, then write `sigval` into the peer's sig_addr.
 * rocshmem_fence() is issued between the put and the signal write so the
 * payload is ordered before the signal becomes visible.
 *
 * Guarded to thread (0,0) so only one thread issues the rocSHMEM calls,
 * regardless of launch configuration.
 */
__global__ void rocshmem_putmem_signal_kernel(void* dst_ptr, const void* src_ptr,
                                              size_t nelement, uint64_t* sig_addr,
                                              uint64_t sigval, int peer) {
    if (threadIdx.x == 0 && blockIdx.x == 0) {
        rocshmem_putmem(dst_ptr, src_ptr, nelement, peer);
        rocshmem_fence();
        rocshmem_ulonglong_p((unsigned long long*)sig_addr,
                             (unsigned long long)sigval,
                             peer);
    }
}
33+
34+
void te_rocshmem_putmem_signal(void* dst_ptr, const void* src_ptr, size_t nelement,
35+
uint64_t* sig_addr, uint64_t sigval, int peer,
36+
hipStream_t cur_stream) {
37+
hipLaunchKernelGGL(rocshmem_putmem_signal_kernel,
38+
dim3(1), dim3(1), 0, cur_stream,
39+
dst_ptr, src_ptr, nelement, sig_addr,
40+
sigval, peer);
41+
}
42+
43+
/*
 * Block `cur_stream` until *sig_addr is signalled (value 1), then reset the
 * flag to 0 on the same stream.
 *
 * Wait strategies:
 *  - KERNEL_WAIT:   launch a 1-thread kernel that spins via rocSHMEM.
 *  - ROCSHMEM_WAIT: not yet available (see commented-out code below);
 *                   deliberately falls through to KERNEL_WAIT.
 *  - STREAM_WAIT:   hipStreamWaitValue64 (note: Gte comparison here, while
 *                   the kernel path waits for CMP_EQ).
 *
 * NOTE(review): the hipStream* calls' hipError_t returns are ignored —
 * consider checking them. Also, the lower-bound half of the NVTE_CHECK is
 * always true since KERNEL_WAIT == 0 on a uint8_t-backed enum.
 */
void te_rocshmem_wait_on_stream(uint64_t* sig_addr,
                                WaitKind wait_kind,
                                hipStream_t cur_stream) {
    uint64_t wait_value = 1;
    uint64_t signal_reset = 0;

    NVTE_CHECK(wait_kind >= WaitKind::KERNEL_WAIT &&
               wait_kind <= WaitKind::STREAM_WAIT,
               "Invalid wait kind");

    switch (wait_kind) {
        // ### wait_until_on_stream not yet implemented for rocshmem ###
        // ### KernelWait is robust but slightly slower due to launch ###
        case WaitKind::ROCSHMEM_WAIT:
            // rocshmem__ulonglong_wait_until_on_stream(sig_addr,
            //                                  ROCSHMEM_CMP_EQ,
            //                                  wait_value,
            //                                  cur_stream);
            // hipStreamWriteValue64(cur_stream,
            //                      reinterpret_cast<hipDeviceptr_t>(sig_addr),
            //                      signal_reset, 0);
            // break;
            // (intentional fallthrough to KERNEL_WAIT until the above exists)
        case WaitKind::KERNEL_WAIT:
            hipLaunchKernelGGL(wait_until_on_stream_and_reset,
                               dim3(1), dim3(1), 0, cur_stream,
                               sig_addr, wait_value, signal_reset);
            // Kernel only waits; the reset is done here on the stream.
            hipStreamWriteValue64(cur_stream,
                                  reinterpret_cast<hipDeviceptr_t>(sig_addr),
                                  signal_reset, 0);
            break;
        case WaitKind::STREAM_WAIT:
            hipStreamWaitValue64(cur_stream,
                                 reinterpret_cast<hipDeviceptr_t>(sig_addr),
                                 wait_value, hipStreamWaitValueGte);
            hipStreamWriteValue64(cur_stream,
                                  reinterpret_cast<hipDeviceptr_t>(sig_addr),
                                  signal_reset, 0);
            break;
    }
}
83+
84+
int te_rocshmem_init_thread(int required, int* provided) {
85+
if (required == 0 && provided == nullptr) {
86+
rocshmem_init();
87+
return 0;
88+
} else {
89+
return rocshmem_init_thread(required, provided);
90+
}
91+
}
92+
93+
// Tear down the rocSHMEM runtime (thin wrapper; see header note on why
// these wrappers live in common).
void te_rocshmem_finalize() {
    rocshmem_finalize();
}
96+
97+
// Return this process's rocSHMEM PE (processing element) id.
int te_rocshmem_my_pe() {
    return rocshmem_my_pe();
}
100+
101+
// Return the total number of rocSHMEM PEs in the job.
int te_rocshmem_n_pes() {
    return rocshmem_n_pes();
}
104+
105+
// Allocate `size` bytes via rocshmem_malloc. May return nullptr on
// failure — callers should check.
void* te_rocshmem_malloc(size_t size) {
    return rocshmem_malloc(size);
}
108+
109+
// Free memory previously obtained from te_rocshmem_malloc.
void te_rocshmem_free(void* ptr) {
    rocshmem_free(ptr);
}
112+
113+
// NOTE(review): declaration only — no definition appears in this file, and
// te_rocshmem_wait_on_stream above covers the stream-wait use case.
// Presumably a leftover from an earlier revision; confirm whether any
// caller needs it before keeping it.
void te_rocshmem_wait_until(uint64_t* signal_addr, uint64_t expected_value,
                            hipStream_t stream);
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
/*************************************************************************
2+
* Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3+
* License for AMD contributions = MIT. See LICENSE for more information
4+
*************************************************************************/
5+
6+
#pragma once
7+
8+
#include <cstdint>
9+
10+
enum class WaitKind : uint8_t {
11+
KERNEL_WAIT = 0,
12+
ROCSHMEM_WAIT = 1,
13+
STREAM_WAIT = 2
14+
};
15+
16+
void te_rocshmem_wait_on_stream(uint64_t *sig_addr, WaitKind wait_kind, hipStream_t cur_stream);
17+
18+
void te_rocshmem_putmem_signal(void* dst_ptr, const void* src_ptr, size_t nelement,
19+
uint64_t* sig_addr, uint64_t sigval, int peer, hipStream_t cur_stream);
20+
21+
/*
22+
These are minimal wrappers around rocshmem functions. As pytorch is a cpp extension,
23+
rocshmem is a static library, and rocshmem does not have separate host / device libraries
24+
we need to move these to common, which handles device code properly.
25+
*/
26+
int te_rocshmem_init_thread(int required, int* provided);
27+
void te_rocshmem_finalize();
28+
int te_rocshmem_my_pe();
29+
int te_rocshmem_n_pes();
30+
void* te_rocshmem_malloc(size_t size);
31+
void te_rocshmem_free(void* ptr);
32+
void te_rocshmem_wait_until(uint64_t* signal_addr, uint64_t expected_value,
33+
hipStream_t stream);

transformer_engine/pytorch/csrc/extensions.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,20 @@ void nvshmem_send_on_current_stream(at::Tensor src, at::Tensor dst, int peer, at
395395
void nvshmem_wait_on_current_stream(at::Tensor signal, const std::string &wait_kind);
396396

397397
void nvshmem_finalize();
398+
#else
399+
/***************************************************************************************************
400+
* ROCSHMEM APIs
401+
**************************************************************************************************/
402+
403+
void init_rocshmem_backend(c10d::ProcessGroup *process_group);
404+
405+
at::Tensor create_rocshmem_tensor(const std::vector<int64_t> &shape, c10::ScalarType dtype);
406+
407+
void rocshmem_send_on_current_stream(at::Tensor src, at::Tensor dst, int peer, at::Tensor signal);
408+
409+
void rocshmem_wait_on_current_stream(at::Tensor signal, const std::string &wait_kind);
410+
411+
void rocshmem_finalize();
398412
#endif
399413

400414
} // namespace transformer_engine::pytorch

transformer_engine/pytorch/csrc/extensions/pybind.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,44 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
303303
m.def("nvshmem_finalize", &transformer_engine::pytorch::nvshmem_finalize,
304304
"Clean up and finalize the NVSHMEM communication backend and free associated resources",
305305
py::call_guard<py::gil_scoped_release>());
306+
#else
307+
// rocshmem functions
308+
m.def("init_rocshmem_backend", &transformer_engine::pytorch::init_rocshmem_backend,
309+
"Initialize ROCSHMEM backend with Pytorch distributed process groups",
310+
py::call_guard<py::gil_scoped_release>());
311+
m.def("create_rocshmem_tensor", &transformer_engine::pytorch::create_rocshmem_tensor,
312+
"Create a tensor in ROCSHMEM shared memory", py::call_guard<py::gil_scoped_release>());
313+
m.def("rocshmem_send_on_current_stream",
314+
&transformer_engine::pytorch::rocshmem_send_on_current_stream,
315+
"Asynchronously send tensor data to a remote PE using ROCSHMEM on the current HIP stream",
316+
py::call_guard<py::gil_scoped_release>());
317+
m.def("rocshmem_wait_on_current_stream",
318+
&transformer_engine::pytorch::rocshmem_wait_on_current_stream,
319+
"Wait for a signal value to be updated by a remote PE using ROCSHMEM on the current HIP "
320+
"stream",
321+
py::call_guard<py::gil_scoped_release>());
322+
m.def("rocshmem_finalize", &transformer_engine::pytorch::rocshmem_finalize,
323+
"Clean up and finalize the ROCSHMEM communication backend and free associated resources",
324+
py::call_guard<py::gil_scoped_release>());
325+
326+
// nvshmem wrappers
327+
m.def("init_nvshmem_backend", &transformer_engine::pytorch::init_rocshmem_backend,
328+
"Initialize ROCSHMEM backend with Pytorch distributed process groups",
329+
py::call_guard<py::gil_scoped_release>());
330+
m.def("create_nvshmem_tensor", &transformer_engine::pytorch::create_rocshmem_tensor,
331+
"Create a tensor in ROCSHMEM shared memory", py::call_guard<py::gil_scoped_release>());
332+
m.def("nvshmem_send_on_current_stream",
333+
&transformer_engine::pytorch::rocshmem_send_on_current_stream,
334+
"Asynchronously send tensor data to a remote PE using ROCSHMEM on the current HIP stream",
335+
py::call_guard<py::gil_scoped_release>());
336+
m.def("nvshmem_wait_on_current_stream",
337+
&transformer_engine::pytorch::rocshmem_wait_on_current_stream,
338+
"Wait for a signal value to be updated by a remote PE using ROCSHMEM on the current HIP "
339+
"stream",
340+
py::call_guard<py::gil_scoped_release>());
341+
m.def("nvshmem_finalize", &transformer_engine::pytorch::rocshmem_finalize,
342+
"Clean up and finalize the ROCSHMEM communication backend and free associated resources",
343+
py::call_guard<py::gil_scoped_release>());
306344
#endif
307345

308346
// multi-tensor functions

0 commit comments

Comments
 (0)