498 changes: 498 additions & 0 deletions examples/pytorch/comm_gemm_overlap/te_layer_with_overlap_profile.py

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions examples/pytorch/comm_gemm_overlap/ub_config.json
@@ -0,0 +1,6 @@
{
"proj_fprop" : "pipeline",
"fc2_fprop" : "ring_exchange",
"qkv_fprop" : "ring_exchange",
"fc1_fprop" : "recursive_doubling"
}
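
For reference, here is a minimal sketch of how a per-GEMM overlap config like this could be consumed — assuming the profiling example maps each entry onto the `method` field of the `ub_cfgs` dict passed to `te.module.base.initialize_ub`, as the existing comm-gemm-overlap examples do. The shape and tensor-parallel size below are placeholders, and support for the `"recursive_doubling"` method string is assumed to come from this PR:

```python
import json

import torch
import transformer_engine.pytorch as te

# Hypothetical loader: {"proj_fprop": "pipeline", ...} -> {"proj_fprop": {"method": "pipeline"}, ...}
with open("examples/pytorch/comm_gemm_overlap/ub_config.json") as f:
    overlap_methods = json.load(f)
ub_cfgs = {name: {"method": method} for name, method in overlap_methods.items()}

# Placeholder problem size: [seq_len * batch_size, hidden_size] with tensor-parallel size 8.
te.module.base.initialize_ub(
    [2048 * 2, 4096],
    8,
    use_fp8=False,
    dtype=torch.bfloat16,
    ub_cfgs=ub_cfgs,
)
```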
4 changes: 3 additions & 1 deletion hipify_custom_map.json
@@ -5,7 +5,9 @@
"util/cuda_runtime.h" : "util/hip_runtime.h",
"ATen/cudnn/Handle.h" : "ATen/miopen/Handle.h",
"CUfunc_cache" : "hipFuncCache_t",
"<nvtx3/nvToolsExt.h>" : "<roctracer/roctx.h>"
"<nvtx3/nvToolsExt.h>" : "<roctracer/roctx.h>",
"cudaLaunchKernel": "hipLaunchKernel",
"CUmemGenericAllocationHandle": "hipMemGenericAllocationHandle_t"
}
}
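
For illustration only (this is not the project's actual hipify driver), the entries above act as extra literal CUDA-to-HIP substitutions layered on top of hipify's built-in rules. A rough sketch of that idea, with hypothetical helpers and an assumed (possibly nested) JSON schema:

```python
import json


def load_custom_map(path: str) -> dict:
    """Collect string->string pairs from the (possibly nested) custom-map JSON."""
    def _flatten(mapping):
        for key, value in mapping.items():
            if isinstance(value, dict):
                yield from _flatten(value)
            else:
                yield key, value

    with open(path) as f:
        return dict(_flatten(json.load(f)))


def apply_custom_map(source: str, custom_map: dict) -> str:
    """Apply each mapping as a plain text replacement, e.g. cudaLaunchKernel -> hipLaunchKernel."""
    for cuda_token, hip_token in custom_map.items():
        source = source.replace(cuda_token, hip_token)
    return source
```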

7 changes: 7 additions & 0 deletions tests/pytorch/distributed/run_layer_with_overlap.py
@@ -1,5 +1,7 @@
#!/usr/bin/python3

# This file was modified for portability to AMDGPU
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
@@ -28,6 +30,11 @@
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)

import transformer_engine.pytorch.cpp_extensions as tex
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
if not tex.device_supports_multicast():
os.environ["UB_SKIPMC"] = "1"


class multi_module_model(torch.nn.Module):
def __init__(self, module, num_layers, *args, **kwargs):
26 changes: 16 additions & 10 deletions transformer_engine/common/CMakeLists.txt
@@ -217,7 +217,11 @@ else()
fused_rope/fused_rope.cu
recipe/current_scaling.cu
recipe/delayed_scaling.cu
recipe/fp8_block_scaling.cu)
recipe/fp8_block_scaling.cu
comm_gemm_overlap/userbuffers/ipcsocket.cc
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
comm_gemm_overlap/userbuffers/userbuffers.cu
comm_gemm_overlap/comm_gemm_overlap.cpp)

# process source code files
set(TE ${CMAKE_CURRENT_SOURCE_DIR}/../..)
@@ -261,17 +265,19 @@ if (USE_CUDA)
target_include_directories(transformer_engine PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}")
endif()

# Compiling Userbuffers with native MPI bootstrapping requires linking against MPI
# Changed
option(NVTE_UB_WITH_MPI "Bootstrap Userbuffers with MPI" OFF)
if (NVTE_UB_WITH_MPI)
find_package(MPI REQUIRED)
target_link_libraries(transformer_engine PUBLIC MPI::MPI_CXX)
target_include_directories(transformer_engine PRIVATE ${MPI_CXX_INCLUDES})
target_compile_definitions(transformer_engine PUBLIC NVTE_UB_WITH_MPI)
endif()
# Compiling Userbuffers with native MPI bootstrapping requires linking against MPI
# Changed
option(NVTE_UB_WITH_MPI "Bootstrap Userbuffers with MPI" OFF)
if (NVTE_UB_WITH_MPI)
find_package(MPI REQUIRED)
target_link_libraries(transformer_engine PUBLIC MPI::MPI_CXX)
target_include_directories(transformer_engine PRIVATE ${MPI_CXX_INCLUDES})
target_compile_definitions(transformer_engine PUBLIC NVTE_UB_WITH_MPI)
endif()

if (USE_CUDA)
# Hack to enable dynamic loading in cuDNN frontend
target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING)
else()
258 changes: 257 additions & 1 deletion transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp

Large diffs are not rendered by default.

@@ -1,4 +1,6 @@
/*************************************************************************
* This file was modified for portability to AMDGPU
* Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
@@ -360,8 +362,12 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks,

NVTE_CHECK_CUDA(cudaMalloc(&(*comm)->flags, 2 * GPU_PAGE_SIZE));
NVTE_CHECK_CUDA(cudaMemset((*comm)->flags, 0, 2 * GPU_PAGE_SIZE));
#ifdef __HIP_PLATFORM_AMD__
(*comm)->flags = reinterpret_cast<int *>(
    (reinterpret_cast<uintptr_t>((*comm)->flags) + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK);
#else
(*comm)->flags =
reinterpret_cast<int *>(((CUdeviceptr)(*comm)->flags + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK);
#endif

using namespace std;


Large diffs are not rendered by default.

@@ -1,4 +1,6 @@
/*************************************************************************
* This file was modified for portability to AMDGPU
* Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
@@ -36,7 +38,8 @@ enum class CommOverlapAlgo {
SPLIT_PIPELINED_RS_P2P = 4,
ATOMIC_GEMM_RS = 5,
ATOMIC_GEMM_AG_P2P = 6,
ATOMIC_GEMM_RS_P2P = 7
ATOMIC_GEMM_RS_P2P = 7,
SPLIT_PIPELINED_AG_RD_P2P = 8
};

class CommOverlapCore {
@@ -57,6 +60,7 @@ class CommOverlapCore {
int _comm_priority;
bool _atomic_gemm{false};
bool _is_p2p{false};
bool _use_rd{false};

TensorWrapper _ubuf;
TensorWrapper _counter;
@@ -92,6 +96,8 @@ class CommOverlapCore {

bool is_p2p_overlap() { return _is_p2p; }

bool is_use_rd() { return _use_rd; }

bool is_fp8_ubuf() { return _ubuf.element_size() == 1; }

virtual void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B,
@@ -133,6 +139,14 @@ class CommOverlapCore {
cudaStream_t stream_main) {
NVTE_ERROR("Operation is not implemented.");
}

virtual void split_overlap_ag_rd(const TensorWrapper &A, bool transa, const TensorWrapper &B,
bool transb, TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad,
bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy,
cudaStream_t stream_main) {
NVTE_ERROR("Operation is not implemented.");
}
}; // CommOverlapCore

class CommOverlapBase : public CommOverlapCore {
@@ -181,6 +195,19 @@ class CommOverlapBase : public CommOverlapCore {
NVTE_ERROR("Operation not supported.");
}

/*
** Split AllGather + GEMM overlap via P2P communication with recursive doubling.
** This function assumes input_b is pre-copied to _ubufs[rank_id]; this is required for the
** all-gather output on each rank to land in contiguous memory after all exchange phases.
*/
void split_overlap_ag_rd(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb,
TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, TensorWrapper &B_copy,
cudaStream_t stream_main) override {
NVTE_ERROR("Operation not supported.");
};

/*
** Split FPROP GEMM + ReduceScatter
*/
@@ -205,6 +232,7 @@ class CommOverlapP2PBase : public CommOverlapCore {
bool _is_reduce_scatter{false};
bool _use_multiatomic_ag{false};
bool _aggregate;
bool use_rd;
int _next_rank;
int _prev_rank;
int _rank_round_tp;
@@ -224,7 +252,7 @@
CommOverlapType comm_type, int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS,
int comm_cga_size = 1, int gemm_priority = 0, int comm_priority = 0,
int num_comm_sm = 1, bool set_sm_margin = false, bool use_ce = true,
bool atomic_gemm = false, bool aggregate = false);
bool atomic_gemm = false, bool aggregate = false, bool use_rd = false);

virtual ~CommOverlapP2PBase();

@@ -260,6 +288,15 @@ class CommOverlapP2PBase : public CommOverlapCore {
bool use_split_accumulator, TensorWrapper &B_copy,
cudaStream_t stream_main) override;

/*
** Split AllGather + GEMM overlap via P2P communication with recursive doubling
*/
void split_overlap_ag_rd(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb,
TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, TensorWrapper &B_copy,
cudaStream_t stream_main) override;
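
For readers unfamiliar with the algorithm behind `SPLIT_PIPELINED_AG_RD_P2P`: recursive doubling completes an all-gather in log2(world_size) phases (versus world_size - 1 steps for ring exchange), with rank `r` exchanging everything it currently holds with rank `r XOR 2^phase`, so the data volume doubles each phase. The toy simulation below shows only that communication pattern — the actual userbuffers implementation additionally pipelines chunked GEMMs against each phase's P2P copies:

```python
def recursive_doubling_allgather(chunks):
    """Simulate an all-gather via recursive doubling across len(chunks) ranks.

    chunks[r] is rank r's local data. Returns, for every rank, the list of all
    chunks in rank order. The rank count must be a power of two. This models
    only the communication pattern, not the overlapped GEMM work.
    """
    world = len(chunks)
    assert world & (world - 1) == 0, "power-of-two rank count assumed"
    # gathered[r][src] is the chunk from `src` currently held by rank r.
    gathered = [{r: chunks[r]} for r in range(world)]

    step = 1
    while step < world:
        # In phase k, rank r exchanges everything it holds with rank r XOR 2^k,
        # so the amount of data moved doubles every phase.
        next_state = [dict(g) for g in gathered]
        for r in range(world):
            peer = r ^ step
            next_state[r].update(gathered[peer])
        gathered = next_state
        step <<= 1
    return [[g[src] for src in range(world)] for g in gathered]


# Example: 4 ranks finish in log2(4) = 2 phases.
print(recursive_doubling_allgather(["a", "b", "c", "d"]))
# -> [['a', 'b', 'c', 'd'], ['a', 'b', 'c', 'd'], ['a', 'b', 'c', 'd'], ['a', 'b', 'c', 'd']]
```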

/*
** Split ReduceScatter + GEMM using P2P communication
*/
7 changes: 3 additions & 4 deletions transformer_engine/common/util/cuda_runtime.cpp
@@ -27,7 +27,7 @@ namespace {
#include "string_path_cuda_include.h"

} // namespace
#endif // __HIP_PLATFORM_AMD__
#endif // #ifndef __HIP_PLATFORM_AMD__

int num_devices() {
auto query_num_devices = []() -> int {
@@ -103,7 +103,6 @@ int sm_count(int device_id) {
return cache[device_id];
}

#ifndef __HIP_PLATFORM_AMD__
void stream_priority_range(int *low_priority, int *high_priority, int device_id) {
static std::vector<std::pair<int, int>> cache(num_devices());
static std::vector<std::once_flag> flags(num_devices());
@@ -125,7 +124,7 @@ void stream_priority_range(int *low_priority, int *high_priority, int device_id)
}

bool supports_multicast(int device_id) {
#if CUDART_VERSION >= 12010
#if !defined(__HIP_PLATFORM_AMD__) && CUDART_VERSION >= 12010
// NOTE: This needs to be guarded at compile time because the
// CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED enum is not defined in earlier CUDA versions.
static std::vector<bool> cache(num_devices(), false);
Expand Down Expand Up @@ -155,7 +154,7 @@ bool supports_multicast(int device_id) {
#endif
}


#ifndef __HIP_PLATFORM_AMD__
const std::string &include_directory(bool required) {
static std::string path;

2 changes: 1 addition & 1 deletion transformer_engine/common/util/cuda_runtime.h
@@ -50,7 +50,6 @@ const std::string &sm_arch_name(int device_id = -1);
*/
int sm_count(int device_id = -1);

#ifndef __HIP_PLATFORM_AMD__
/* \brief Minimum and maximum stream priorities supported on device
*
* \param[in] device_id CUDA device (default is current device)
@@ -69,6 +68,7 @@ void stream_priority_range(int *low_priority, int *high_priority, int device_id
*/
bool supports_multicast(int device_id = -1);

#ifndef __HIP_PLATFORM_AMD__
/* \brief Path to CUDA Toolkit headers
*
* The path can be configured by setting NVTE_CUDA_INCLUDE_DIR in the
21 changes: 7 additions & 14 deletions transformer_engine/common/util/pybind_helper.h
@@ -10,14 +10,15 @@
#define TRANSFORMER_ENGINE_COMMON_UTIL_PYBIND_HELPER_H_

#include <pybind11/pybind11.h>
//TODO: rocm does not support comm gemm overlap yet
#ifndef USE_ROCM
#include <transformer_engine/comm_gemm_overlap.h>
#endif
#include <transformer_engine/fused_attn.h>
#include <transformer_engine/transformer_engine.h>

#ifdef __HIP_PLATFORM_AMD__
#include "hip_runtime.h"
#else
#include "cuda_runtime.h"
#endif

// Define fused-attention handles separately for USE_ROCM
#ifndef USE_ROCM
@@ -35,8 +36,6 @@
.value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend);
#endif

// Define comm overlap handles if not using ROCm
#ifndef USE_ROCM
#define NVTE_DECLARE_COMM_OVERLAP_HANDLES(m) \
pybind11::enum_<transformer_engine::CommOverlapType>(m, "CommOverlapType", \
pybind11::module_local()) \
Expand All @@ -53,7 +52,9 @@
transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS_P2P) \
.value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \
.value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \
.value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P); \
.value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P) \
.value("SPLIT_PIPELINED_AG_RD_P2P", \
transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_AG_RD_P2P); \
py::class_<transformer_engine::CommOverlapCore, \
std::shared_ptr<transformer_engine::CommOverlapCore>>(m, "CommOverlapCore", \
pybind11::module_local()) \
@@ -88,14 +89,6 @@
py::call_guard<py::gil_scoped_release>(), py::arg("device_id") = -1); \
m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \
py::call_guard<py::gil_scoped_release>());
#else
#define NVTE_DECLARE_COMM_OVERLAP_HANDLES(m) \
pybind11::class_<transformer_engine::CommOverlapType>(m, "CommOverlapType", \
pybind11::module_local()); \
py::class_<transformer_engine::CommOverlapCore, \
std::shared_ptr<transformer_engine::CommOverlapCore>>(m, "CommOverlapCore", \
pybind11::module_local());
#endif

#define NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) \
pybind11::enum_<transformer_engine::DType>(m, "DType", pybind11::module_local()) \
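
With the ROCm-only `#else` fallback removed, the full comm-overlap bindings are registered on both platforms. A quick hedged check from Python — assuming the enum is reachable through the same `tex` module the test script above imports; the exact import path may differ:

```python
import transformer_engine.pytorch.cpp_extensions as tex

# The recursive-doubling variant sits alongside the existing P2P algorithms.
algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_RD_P2P
print(int(algo))  # expected to be 8, matching the C++ enum above
```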
2 changes: 0 additions & 2 deletions transformer_engine/pytorch/csrc/common.h
@@ -29,9 +29,7 @@
#include <torch/torch.h>
#include <transformer_engine/activation.h>
#include <transformer_engine/cast.h>
#ifndef USE_ROCM
#include <transformer_engine/comm_gemm_overlap.h>
#endif
#include <transformer_engine/fused_attn.h>
#include <transformer_engine/fused_rope.h>
#include <transformer_engine/gemm.h>
12 changes: 1 addition & 11 deletions transformer_engine/pytorch/csrc/extensions.h
@@ -13,14 +13,6 @@

#include "common.h"

#ifdef USE_ROCM
namespace transformer_engine {
//dummy CommOverlapCore, CommOverlapType in rocm
class CommOverlapCore{};
class CommOverlapType{};
}
#endif

namespace transformer_engine::pytorch {

/***************************************************************************************************
@@ -399,7 +391,6 @@ void nvshmem_finalize();

} // namespace transformer_engine::pytorch

#ifndef USE_ROCM
/***************************************************************************************************
* Comm+GEMM Overlap Wrappers
**************************************************************************************************/
@@ -457,7 +448,7 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm
int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2,
int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 3,
bool set_sm_margin = true, bool atomic_gemm = false, bool use_ce = true,
bool aggregate = false);
bool aggregate = false, bool use_rd = false);

~CommOverlapP2P() {}

Expand All @@ -467,6 +458,5 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm
std::optional<std::vector<int64_t>> shape = std::nullopt);

}; // CommOverlapP2P
#endif // !USE_ROCM

#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_
@@ -5,7 +5,6 @@
*
* See LICENSE for license information.
************************************************************************/
#ifndef USE_ROCM
#include "../extensions.h"
#include "transformer_engine/transformer_engine.h"

@@ -227,14 +226,14 @@ CommOverlapP2P::CommOverlapP2P(const std::vector<size_t> &buffer_shape, at::Scal
te::CommOverlapType comm_type, int num_max_streams,
int comm_cga_size, int gemm_priority, int comm_priority,
int num_comm_sm, bool set_sm_margin, bool atomic_gemm, bool use_ce,
bool aggregate)
bool aggregate, bool use_rd)
: te::CommOverlapP2PBase(
buffer_shape, te::pytorch::GetTransformerEngineDType(buffer_dtype), helper->myrank,
helper->numranks, helper->mylocal, helper->numlocal, helper->mynode, helper->numnodes,
tp_size, std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5),
std::bind(&CommOverlapHelper::ub_barrier, helper, _1), comm_type, num_max_streams,
comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce,
atomic_gemm, aggregate) {}
atomic_gemm, aggregate, use_rd) {}

/*
** Copy input to _ubufs[0]
Expand Down Expand Up @@ -302,4 +301,3 @@ at::Tensor CommOverlapP2P::get_buffer(bool local_chunk, std::optional<std::vecto
const auto dtype = transformer_engine::pytorch::GetATenDType(_ubuf.dtype());
return torch::from_blob(ubuf_ptr, *shape, at::dtype(dtype).device(torch::kCUDA));
}
#endif // !USE_ROCM