
Commit c793b5e

Fixes for rocm 7.0 and dev
1 parent ac337b8 commit c793b5e


9 files changed: +60 additions, −48 deletions


examples/pytorch/comm_gemm_overlap/te_layer_with_overlap_profile.py

Lines changed: 5 additions & 8 deletions
@@ -185,32 +185,29 @@ def _get_layer_args(config, tp_group, tp_size, reference=False):
 
     return args, kwargs, input_shape
 
-def create_ub_cfgs(config_file:str, tp_size: int = 8):
+def create_ub_cfgs(config_file: str, tp_size: int = 8):
     import json
     with open(config_file, 'r') as f:
         data = json.load(f)
     cfgs = {}
     _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None
     layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"]
-
-    for name, method in data.items():
-        is_reduce_scatter = name in layers_reduce_scatter_overlap
-
-        layers_all_gather_overlap = [
+    layers_all_gather_overlap = [
         "qkv_fprop",
         "qkv_dgrad",
         "proj_dgrad",
         "fc1_fprop",
         "fc1_dgrad",
         "fc2_dgrad",
     ]
+
+    for name, method in data.items():
         if _MIN_STREAM_PRIORITY is None or _MAX_STREAM_PRIORITY is None:
             _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = tex.get_stream_priority_range()
 
-
         cfg = {
             "method": method,
-            "is_reduce_scatter": is_reduce_scatter,
+            "is_reduce_scatter": name in layers_reduce_scatter_overlap,
             "num_sm": 1 if method == "ring_exchange" else 16,
             "cga_size": 1 if method == "ring_exchange" else 2,
             "set_sm_margin": False,

transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp

Lines changed: 22 additions & 16 deletions
@@ -21,6 +21,12 @@
 #define HALF_BYTES 2
 #define UB_MAX_SM 32
 
+#ifdef __HIP_PLATFORM_AMD__
+#define half_dtype hip_bfloat16
+#define __nv_fp8_e5m2 te_hip_fp8_e5m2
+#define __nv_fp8_e4m3 te_hip_fp8_e4m3
+#endif
+
 using namespace std::placeholders;
 
 namespace transformer_engine {
@@ -448,7 +454,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
                                        TensorWrapper &pre_gelu_out, TensorWrapper &workspace,
                                        bool grad, bool accumulate, bool use_split_accumulator,
                                        TensorWrapper &rs_output, cudaStream_t stream_main) {
-  printf("split_overlap_rs_pipeline");
+  printf("split_overlap_rs_pipeline\n");
   // Get GEMM dimensions
   int ori_sms = _ub_comm->sms;
   _ub_comm->use_ce = _use_ce;
@@ -596,7 +602,7 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
                                        CommOverlapType comm_type, int num_max_streams,
                                        int comm_cga_size, int gemm_priority, int comm_priority,
                                        int num_comm_sm, bool set_sm_margin, bool use_ce,
-                                       bool atomic_gemm, bool aggregate, bool use_rd = false)
+                                       bool atomic_gemm, bool aggregate, bool use_rd)
     : CommOverlapCore(myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size,
                       allgather_handle, barrier_handle, tp_size, num_max_streams, comm_cga_size,
                       gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce,
@@ -798,7 +804,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
                                          TensorWrapper &workspace, bool grad, bool accumulate,
                                          bool use_split_accumulator, TensorWrapper &B_copy,
                                          cudaStream_t stream_main) {
-  printf("split_overlap_ag");
+  printf("split_overlap_ag\n");
   int ori_sms = _ub_comm->sms;
   _ub_comm->use_ce = _use_ce;
   _ub_comm->sms = _num_comm_sm;
@@ -960,12 +966,12 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
 ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG
 ** outputs in each rank to be in the contiguous memory space after all ring exchange phases.
 */
-void CommOverlapP2PBase::split_overlap_ag_rd(TensorWrapper &A, bool transa, TensorWrapper &B,
-                                             bool transb, TensorWrapper &D, TensorWrapper &bias,
-                                             TensorWrapper &pre_gelu_out, TensorWrapper &workspace,
-                                             bool grad, bool accumulate, bool use_split_accumulator,
-                                             TensorWrapper &B_copy, cudaStream_t stream_main) {
-  printf("split_overlap_ag_rd");
+void CommOverlapP2PBase::split_overlap_ag_rd(const TensorWrapper &A, bool transa, const TensorWrapper &B,
+                                             bool transb, TensorWrapper &D, TensorWrapper &bias,
+                                             TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad,
+                                             bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy,
+                                             cudaStream_t stream_main) {
+  printf("split_overlap_ag_rd\n");
   int ori_sms = _ub_comm->sms;
   _ub_comm->use_ce = _use_ce;
   _ub_comm->sms = _num_comm_sm;
@@ -1025,12 +1031,12 @@ void CommOverlapP2PBase::split_overlap_ag_rd(TensorWrapper &A, bool transa, Tens
     // GEMM
     char *input_b_chunk_ptr = input_b_ptr + send_offset;
     auto input_b_chunk =
-        TensorWrapper(reinterpret_cast<void *>(input_b_chunk_ptr), {n_chunk * 2, k}, B.dtype(),
+        TensorWrapper(reinterpret_cast<void *>(input_b_chunk_ptr), std::vector<size_t>{n_chunk * 2, k}, B.dtype(),
                       nullptr, nullptr, B.scale_inv());
 
     char *output_chunk_ptr = output_ptr + (send_chunk_id * output_chunk_bytes);
     auto output_chunk = TensorWrapper(reinterpret_cast<void *>(output_chunk_ptr),
-                                      {n_chunk * 2, m}, D.dtype(), D.amax(), D.scale(), nullptr);
+                                      std::vector<size_t>{n_chunk * 2, m}, D.dtype(), D.amax(), D.scale(), nullptr);
 
     char *aux_chunk_ptr =
         (do_gelu) ? pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes) : nullptr;
@@ -1084,12 +1090,12 @@ void CommOverlapP2PBase::split_overlap_ag_rd(TensorWrapper &A, bool transa, Tens
       cudaStream_t compute_stream = _stream_compute[chunk_id % _stream_compute.size()];
 
       auto input_b_chunk = TensorWrapper(_ubufs[chunk_id].dptr(),
-                                         {n_chunk, k}, B.dtype(),
+                                         std::vector<size_t>{n_chunk, k}, B.dtype(),
                                          nullptr, nullptr, B.scale_inv());
 
       char* output_chunk_ptr = output_ptr + (chunk_id * output_chunk_bytes);
       auto output_chunk = TensorWrapper(reinterpret_cast<void *>(output_chunk_ptr),
-                                        {n_chunk, m},
+                                        std::vector<size_t>{n_chunk, m},
                                         D.dtype(), D.amax(), D.scale(), nullptr);
 
       char *aux_chunk_ptr =
@@ -1140,12 +1146,12 @@ void CommOverlapP2PBase::split_overlap_ag_rd(TensorWrapper &A, bool transa, Tens
       cudaStream_t compute_stream = _stream_compute[new_chunk_id % _stream_compute.size()];
 
       auto input_b_chunk = TensorWrapper(_ubufs[new_chunk_id].dptr(),
-                                         {n_chunk, k}, B.dtype(),
+                                         std::vector<size_t>{n_chunk, k}, B.dtype(),
                                          nullptr, nullptr, B.scale_inv());
 
       char* output_chunk_ptr = output_ptr + (new_chunk_id * output_chunk_bytes);
       auto output_chunk = TensorWrapper(reinterpret_cast<void *>(output_chunk_ptr),
-                                        {n_chunk, m},
+                                        std::vector<size_t>{n_chunk, m},
                                         D.dtype(), D.amax(), D.scale(), nullptr);
 
       char *aux_chunk_ptr =
@@ -1271,7 +1277,7 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
                                           TensorWrapper &workspace, bool grad, bool accumulate,
                                           bool use_split_accumulator, TensorWrapper &rs_output,
                                           cudaStream_t stream_main) {
-  printf("split_overlap_rs_p2p");
+  printf("split_overlap_rs_p2p\n");
   int ori_sms = _ub_comm->sms;
   _ub_comm->use_ce = _use_ce;
   _ub_comm->sms = _num_comm_sm;
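The reworked split_overlap_ag_rd keeps the chunked layout described in the comment above: B is pre-copied into _ubufs[rank_id], and each ring-exchange step GEMMs one n_chunk slice against A, writing it at a chunk-sized offset into D so the gathered output ends up contiguous. A rough NumPy sketch of that per-chunk addressing (shapes and names are illustrative only, not TE's API):

# Rough sketch: each exchanged chunk of B (n_chunk x k) produces an
# (n_chunk x m) slab of the output, written at row chunk_id * n_chunk,
# mirroring output_ptr + chunk_id * output_chunk_bytes in the C++ code.
import numpy as np

tp_size, n_chunk, k, m = 4, 8, 16, 32
A = np.random.rand(m, k)                                   # shared weight
B_chunks = [np.random.rand(n_chunk, k) for _ in range(tp_size)]
D = np.empty((tp_size * n_chunk, m))

for chunk_id, B_chunk in enumerate(B_chunks):              # one GEMM per ring step
    D[chunk_id * n_chunk:(chunk_id + 1) * n_chunk] = B_chunk @ A.T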

transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu

Lines changed: 11 additions & 8 deletions
@@ -8,16 +8,17 @@
 #include <cuda_fp8.h>
 #include <cuda_runtime.h>
 
-#if __CUDA_ARCH__ >= 800
-#define half_dtype nv_bfloat16
-#else
-#define half_dtype half
-#endif
 
 #ifdef __HIP_PLATFORM_AMD__
 #define half_dtype hip_bfloat16
 #define __nv_fp8_e5m2 te_hip_fp8_e5m2
 #define __nv_fp8_e4m3 te_hip_fp8_e4m3
+#else
+#if __CUDA_ARCH__ >= 800
+#define half_dtype nv_bfloat16
+#else
+#define half_dtype half
+#endif
 #endif
 
 #include <assert.h>
@@ -2094,7 +2095,8 @@ void allgather2_userbuff_inplace(const int handler, const int offset, const int
     }
   }
 #else
-  if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
+  int threads = comm->threads;
+  if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) {
     callranks_agMC(2) callranks_agMC(4) callranks_agMC(8) callranks_agMC(16) callranks_agMC(32)
   } else {
     callranks_ag(2) callranks_ag(4) callranks_ag(8) callranks_ag(16) callranks_ag(32)
@@ -2150,7 +2152,7 @@ void reducescatter2_userbuff_inplace(const int handler, const int offset, const
   }
 #else
   int threads = comm->threads;
-  if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
+  if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) {
     callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8) callranks_rsMC(16) callranks_rsMC(32)
   } else {
     callranks_rs(2) callranks_rs(4) callranks_rs(8) callranks_rs(16) callranks_rs(32)
@@ -2666,6 +2668,7 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds
         cudaLaunchKernelExC(&cfg, reinterpret_cast<void *>(kuserbuffers_pushsend), kernelArgs));
 #else
     cudaLaunchKernel(reinterpret_cast<void *>(kuserbuffers_pushsend), sms, threads, kernelArgs, 0, stream));
+#endif
   }
 }
 
@@ -2812,7 +2815,7 @@ void userbuffers_sendrecv_multiatomic(const int srchandler, const int dsthandler
   void *flagptr_send = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 0);
   void *flagptr_recv = GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 0);
 
-#ifndef
+#ifndef __HIP_PLATFORM_AMD__
   SETUP_LAUNCH_CONFIG(comm->sms, 1024, stream);
 #else
   int sms = comm->sms;

transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h

Lines changed: 6 additions & 1 deletion
@@ -58,6 +58,7 @@ class CommOverlapCore {
   int _comm_priority;
   bool _atomic_gemm{false};
   bool _is_p2p{false};
+  bool _use_rd{false};
 
   TensorWrapper _ubuf;
   TensorWrapper _counter;
@@ -93,6 +94,8 @@
 
   bool is_p2p_overlap() { return _is_p2p; }
 
+  bool is_use_rd() { return _use_rd; }
+
   bool is_fp8_ubuf() { return _ubuf.element_size() == 1; }
 
   virtual void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B,
@@ -199,7 +202,9 @@ class CommOverlapBase : public CommOverlapCore {
                             TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
                             TensorWrapper &workspace, bool grad, bool accumulate,
                             bool use_split_accumulator, TensorWrapper &B_copy,
-                            cudaStream_t stream_main) override;
+                            cudaStream_t stream_main) override {
+    NVTE_ERROR("Operation not supported.");
+  };
 
   /*
   ** Split FPROP GEMM + ReduceScatter

transformer_engine/common/util/pybind_helper.h

Lines changed: 4 additions & 0 deletions
@@ -14,7 +14,11 @@
 #include <transformer_engine/fused_attn.h>
 #include <transformer_engine/transformer_engine.h>
 
+#ifdef __HIP_PLATFORM_AMD__
+#include "hip_runtime.h"
+#else
 #include "cuda_runtime.h"
+#endif
 
 // Define fused-attention handles separately for USE_ROCM
 #ifndef USE_ROCM

transformer_engine/pytorch/csrc/extensions.h

Lines changed: 1 addition & 9 deletions
@@ -13,14 +13,6 @@
 
 #include "common.h"
 
-#ifdef USE_ROCM
-namespace transformer_engine {
-//dummy CommOverlapCore, CommOverlapType in rocm
-class CommOverlapCore{};
-class CommOverlapType{};
-}
-#endif
-
 namespace transformer_engine::pytorch {
 
 /***************************************************************************************************
@@ -456,7 +448,7 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm
                  int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2,
                  int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 3,
                  bool set_sm_margin = true, bool atomic_gemm = false, bool use_ce = true,
-                 bool aggregate = false);
+                 bool aggregate = false, bool use_rd = false);
 
   ~CommOverlapP2P() {}

transformer_engine/pytorch/csrc/extensions/gemm.cpp

Lines changed: 7 additions & 4 deletions
@@ -187,7 +187,6 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
                     std::move(swizzle_scaling_factors(B_tensor, !transb)));
 
   if (comm_overlap) {
-#ifndef USE_ROCM
     // Prepare extra output tensor
     TensorWrapper extra_output_tensor;
     if (extra_output.has_value()) {
@@ -213,6 +212,13 @@
                                             accumulate, use_split_accumulator,
                                             extra_output_tensor, main_stream);
         });
+      } else if (comm_overlap->is_use_rd()) {
+        NVTE_SCOPED_GIL_RELEASE({
+          comm_overlap->split_overlap_ag_rd(A_tensor, transa, B_tensor, transb, D_tensor,
+                                            bias_tensor, te_pre_gelu_out, te_workspace, grad,
+                                            accumulate, use_split_accumulator,
+                                            extra_output_tensor, main_stream);
+        });
       } else {
         NVTE_SCOPED_GIL_RELEASE({
           comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor,
@@ -238,9 +244,6 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
         });
       }
     }
-#else
-    NVTE_ERROR("ROCm TE does not support comm_overlap\n");
-#endif //!USE_ROCM
   } else {
     // Launch GEMM
     NVTE_SCOPED_GIL_RELEASE({
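The new branch makes the binding prefer split_overlap_ag_rd whenever the overlap object reports use_rd, falling back to the existing split_overlap_ag otherwise. A minimal Python-style sketch of that selection order (pseudocode only; the real dispatch is the C++ above, and calling is_use_rd() from Python is an assumption):

# Pseudocode mirroring the else-if chain added in gemm.cpp.
def pick_ag_overlap(comm_overlap) -> str:
    if comm_overlap.is_use_rd():     # ring-direct all-gather path added here
        return "split_overlap_ag_rd"
    return "split_overlap_ag"        # default split all-gather overlap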

transformer_engine/pytorch/csrc/extensions/pybind.cpp

Lines changed: 2 additions & 2 deletions
@@ -403,9 +403,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
            py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS, py::arg("comm_cga_size") = 1,
            py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0, py::arg("num_comm_sm") = 1,
            py::arg("set_sm_margin") = false, py::arg("atomic_gemm") = false,
-           py::arg("use_ce") = true, py::arg("aggregate") = false, py::arg("use_rd" = false))
+           py::arg("use_ce") = true, py::arg("aggregate") = false, py::arg("use_rd") = false)
       .def("copy_into_buffer", &CommOverlapP2P::copy_into_buffer, py::arg("input"),
            py::arg("local_chunk") = false)
       .def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("local_chunk") = false,
-           py::arg("shape") = std::nullopt,);
+           py::arg("shape") = std::nullopt);
 }
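With the py::arg fix, use_rd becomes a regular keyword argument of the CommOverlapP2P binding. A sketch of the keyword defaults as declared above (the required positional constructor arguments are not part of this diff and are omitted; the NVTE_COMM_OVERLAP_MAX_STREAMS value shown is an assumption):

# Keyword defaults of the CommOverlapP2P binding after this change.
p2p_overlap_kwargs = dict(
    num_max_streams=3,      # stands in for NVTE_COMM_OVERLAP_MAX_STREAMS (assumed value)
    comm_cga_size=1,
    gemm_priority=0,
    comm_priority=0,
    num_comm_sm=1,
    set_sm_margin=False,
    atomic_gemm=False,
    use_ce=True,
    aggregate=False,
    use_rd=False,           # newly exposed keyword
)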

transformer_engine/pytorch/module/base.py

Lines changed: 2 additions & 0 deletions
@@ -308,6 +308,7 @@ def get_default_config(name):
         "comm_priority": _MAX_STREAM_PRIORITY,
         "gemm_priority": _MIN_STREAM_PRIORITY,
         "pipeline_rs_overlap_first_gemm": False,
+        "use_rd": False,
     }
     return default_cfg
 
@@ -326,6 +327,7 @@ def add_ub(
     comm_priority: int = 0,
     gemm_priority: int = 0,
     pipeline_rs_overlap_first_gemm: bool = False,
+    use_rd: bool = False,
 ) -> None:
     if atomic_gemm:
         warnings.warn(
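On the Python side, use_rd now flows through get_default_config() and add_ub(), so it can be set per layer in the ub_cfgs dict passed to initialize_ub(). A sketch, assuming the usual initialize_ub(shape, tp_size, ..., ub_cfgs=...) entry point and an illustrative buffer shape:

# Sketch only: initialize_ub's full signature and the buffer shape here are
# assumptions, not part of this commit.
import torch
from transformer_engine.pytorch.module.base import initialize_ub

ub_cfgs = {
    "qkv_fprop": {
        "method": "ring_exchange",
        "use_rd": False,     # new key; defaults to False via get_default_config()
    },
}
initialize_ub(
    [4096, 1024],            # illustrative (sequence*batch, hidden) shape
    tp_size=8,
    dtype=torch.bfloat16,
    ub_cfgs=ub_cfgs,
)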
