Commit e4e40e8

fast forward to v2.9 + optimizations
1 parent 2c1717c

File tree: 9 files changed, +329 -49 lines changed

examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py

Lines changed: 9 additions & 3 deletions
@@ -68,7 +68,7 @@ def _parse_args(argv=None, namespace=None):
     )
     parser.add_argument("--seed", type=int, default=1234, help="RNG seed.")
     parser.add_argument(
-        "--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context."
+        "--fp8", action="store_true", default=False, help="Enables the te.autocast() context."
     )
     parser.add_argument(
         "--no-comm-overlap",
@@ -263,7 +263,13 @@ def dist_print(msg, end="\n", group=nccl_world, src=0, debug=False, error=False)
     te.module.base.initialize_ub(
         [batched_size, hidden_size],
         tp_size,
-        use_fp8=opts.fp8,
+        quantization_modes=[
+            (
+                te.module.base.UserBufferQuantizationMode.FP8
+                if opts.fp8
+                else te.module.base.UserBufferQuantizationMode.NONE
+            )
+        ],
         dtype=torch.bfloat16,
         bootstrap_backend=opts.bootstrap_backend,
     )
@@ -293,7 +299,7 @@ def dist_print(msg, end="\n", group=nccl_world, src=0, debug=False, error=False)
 
     dist_print(" |-- Forward pass", group=tp_group, debug=True)
     with torch.amp.autocast("cuda", dtype=torch.bfloat16):
-        with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world):
+        with te.autocast(enabled=opts.fp8, recipe=fp8_recipe, amax_reduction_group=nccl_world):
             y = model(x)
         if isinstance(y, tuple):
             out, *_ = y
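
Note: this example migration is the user-facing half of the v2.9 fast-forward. te.fp8_autocast(enabled=..., fp8_recipe=..., fp8_group=...) becomes te.autocast(enabled=..., recipe=..., amax_reduction_group=...), and initialize_ub() now takes a list of UserBufferQuantizationMode values instead of a use_fp8 flag. Below is a minimal sketch of the updated call pattern; it assumes the distributed setup of the full example (an initialized NCCL process group nccl_world, plus tp_size, batched_size, hidden_size, use_fp8, model, and x) and is a fragment of that flow, not a standalone script.

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import DelayedScaling

    fp8_recipe = DelayedScaling()

    # Old: te.module.base.initialize_ub(..., use_fp8=use_fp8, ...)
    # New: one quantization mode per userbuffer, FP8 or NONE.
    te.module.base.initialize_ub(
        [batched_size, hidden_size],
        tp_size,
        quantization_modes=[
            te.module.base.UserBufferQuantizationMode.FP8
            if use_fp8
            else te.module.base.UserBufferQuantizationMode.NONE
        ],
        dtype=torch.bfloat16,
    )

    # Old: te.fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world)
    with te.autocast(enabled=use_fp8, recipe=fp8_recipe, amax_reduction_group=nccl_world):
        y = model(x)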

transformer_engine/common/CMakeLists.txt

Lines changed: 7 additions & 10 deletions
@@ -162,7 +162,11 @@ list(APPEND transformer_engine_SOURCES
     fused_router/fused_topk_with_score_function.cu
     recipe/current_scaling.cu
     recipe/delayed_scaling.cu
-    recipe/fp8_block_scaling.cu)
+    recipe/fp8_block_scaling.cu
+    comm_gemm_overlap/userbuffers/ipcsocket.cc
+    comm_gemm_overlap/userbuffers/userbuffers-host.cpp
+    comm_gemm_overlap/userbuffers/userbuffers.cu
+    comm_gemm_overlap/comm_gemm_overlap.cpp)
 if(USE_CUDA)
 # Removed indent to minimize code diff with NV upstream
 # Files unique in cuda building
@@ -175,11 +179,7 @@ list(APPEND transformer_engine_SOURCES
     fused_attn/fused_attn_fp8.cu
     fused_attn/fused_attn.cpp
     fused_attn/utils.cu
-    util/cuda_nvml.cpp
-    comm_gemm_overlap/userbuffers/ipcsocket.cc
-    comm_gemm_overlap/userbuffers/userbuffers-host.cpp
-    comm_gemm_overlap/userbuffers/userbuffers.cu
-    comm_gemm_overlap/comm_gemm_overlap.cpp)
+    util/cuda_nvml.cpp)
 add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
 else()
   list(APPEND transformer_engine_SOURCES
@@ -189,10 +189,7 @@ else()
     fused_attn_rocm/utils.cpp
     gemm/rocm_gemm.cu
     amd_detail/system.cpp
-    comm_gemm_overlap/userbuffers/ipcsocket.cc
-    comm_gemm_overlap/userbuffers/userbuffers-host.cpp
-    comm_gemm_overlap/userbuffers/userbuffers.cu
-    comm_gemm_overlap/comm_gemm_overlap.cpp)
+    comm_gemm_overlap/rocm_comm_gemm_overlap.cpp)
 
 # process source code files
 set(TE ${CMAKE_CURRENT_SOURCE_DIR}/../..)

transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp

Lines changed: 20 additions & 4 deletions
@@ -21,6 +21,12 @@
 #define HALF_BYTES 2
 #define UB_MAX_SM 32
 
+#ifdef __HIP_PLATFORM_AMD__
+#define half_dtype hip_bfloat16
+#define __nv_fp8_e5m2 te_hip_fp8_e5m2
+#define __nv_fp8_e4m3 te_hip_fp8_e4m3
+#endif
+
 using namespace std::placeholders;
 
 namespace transformer_engine {
@@ -328,6 +334,7 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te
                                    bool accumulate, bool use_split_accumulator,
                                    CommOverlapType comm_type, TensorWrapper &rs_output,
                                    cudaStream_t stream_main) {
+  printf("bulk_overlap\n");
   int ori_sms = _ub_comm->sms;
   _ub_comm->use_ce = _use_ce;
   _ub_comm->sms = _num_comm_sm;
@@ -353,7 +360,7 @@
     char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
     reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf.scale_inv(), _ub_reg, 0,
                                                comm_elements, _ub_comm, _stream_comm,
-                                               (cudaEvent_t)_comm_launch_event);
+                                               (cudaEvent_t)_comm_launch_event);
   } else {
     reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm,
                                     (cudaEvent_t)_comm_launch_event);
@@ -385,6 +392,7 @@ void CommOverlapBase::atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa
                                              TensorWrapper &workspace, bool grad, bool accumulate,
                                              bool use_split_accumulator, TensorWrapper &rs_output,
                                              cudaStream_t stream_main) {
+  printf("atomic_gemm_overlap_rs\n");
   int ori_sms = _ub_comm->sms;
   _ub_comm->use_ce = _use_ce;
   _ub_comm->sms = _num_comm_sm;
@@ -481,6 +489,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
                                        TensorWrapper &pre_gelu_out, TensorWrapper &workspace,
                                        bool grad, bool accumulate, bool use_split_accumulator,
                                        TensorWrapper &rs_output, cudaStream_t stream_main) {
+  printf("split_overlap_rs\n");
   // Get GEMM dimensions
   int ori_sms = _ub_comm->sms;
   _ub_comm->use_ce = _use_ce;
@@ -619,6 +628,8 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
 
 void CommOverlapBase::bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream,
                                                cudaStream_t stream_main) {
+  printf("bulk_overlap_external_ag\n");
+
   int comm_bytes = _ubuf.bytes();
   int comm_bytes_per_rank = comm_bytes / _tp_size;
 
@@ -651,19 +662,20 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
                                        CommOverlapType comm_type, int num_max_streams,
                                        int comm_cga_size, int gemm_priority, int comm_priority,
                                        int num_comm_sm, bool set_sm_margin, bool use_ce,
-                                       bool atomic_gemm, bool aggregate)
+                                       bool atomic_gemm, bool aggregate, bool use_rd)
     : CommOverlapCore(myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size,
                       allgather_handle, barrier_handle, tp_size, num_max_streams, comm_cga_size,
                       gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce,
                       atomic_gemm) {
-  initialize(buffer_shape, buffer_dtype, comm_type, aggregate);
+  initialize(buffer_shape, buffer_dtype, comm_type, aggregate, use_rd);
 }
 
 void CommOverlapP2PBase::initialize(const std::vector<size_t> &buffer_shape, DType buffer_dtype,
-                                    CommOverlapType comm_type, bool aggregate) {
+                                    CommOverlapType comm_type, bool aggregate, bool use_rd) {
   _is_p2p = true;
   _is_reduce_scatter = comm_type == CommOverlapType::RS;
   _aggregate = aggregate;
+  _use_rd = use_rd;
 
   // Create workspace tensor with userbuffer
   NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!");
@@ -788,6 +800,7 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(
     const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D,
     TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad,
     bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, cudaStream_t stream_main) {
+  printf("atomic_gemm_overlap_ag\n");
   int ori_sms = _ub_comm->sms;
   _ub_comm->use_ce = _use_ce;
   _ub_comm->sms = _num_comm_sm;
@@ -890,6 +903,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
                                           TensorWrapper &workspace, bool grad, bool accumulate,
                                           bool use_split_accumulator, TensorWrapper &B_copy,
                                           cudaStream_t stream_main) {
+  printf("split_overlap_ag\n");
   int ori_sms = _ub_comm->sms;
   _ub_comm->use_ce = _use_ce;
   _ub_comm->sms = _num_comm_sm;
@@ -1057,6 +1071,7 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs(
     TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad,
     bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output,
     cudaStream_t stream_main) {
+  printf("atomic_gemm_overlap_rs\n");
   int ori_sms = _ub_comm->sms;
   _ub_comm->use_ce = _use_ce;
   _ub_comm->sms = _num_comm_sm;
@@ -1121,6 +1136,7 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
                                           TensorWrapper &workspace, bool grad, bool accumulate,
                                           bool use_split_accumulator, TensorWrapper &rs_output,
                                           cudaStream_t stream_main) {
+  printf("split_overlap_rs\n");
   int ori_sms = _ub_comm->sms;
   _ub_comm->use_ce = _use_ce;
   _ub_comm->sms = _num_comm_sm;
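
The functional change in this file, alongside printf tracing in each overlap entry point, is the new use_rd flag threaded through the CommOverlapP2PBase constructor into initialize() and stored as _use_rd. The code that consumes _use_rd is not part of this excerpt, so the sketch below is a hypothetical Python model of the presumed selection between the existing ring pipeline and the recursive-doubling all-gather added in rocm_comm_gemm_overlap.cpp; the class and method names here are illustrative only.

    class CommOverlapP2P:
        """Hypothetical model of the C++ flow above; the real consumer of
        _use_rd is not shown in this commit excerpt."""

        def __init__(self, aggregate: bool = False, use_rd: bool = False):
            self._aggregate = aggregate
            self._use_rd = use_rd  # mirrors the new member set in initialize()

        def split_overlap_ag(self) -> str:
            # Presumed routing: recursive doubling (log2(tp) steps) when
            # use_rd is set, otherwise the ring pipeline (tp - 1 steps).
            return "rocm_split_overlap_ag_rd" if self._use_rd else "split_overlap_ag"

    assert CommOverlapP2P(use_rd=True).split_overlap_ag() == "rocm_split_overlap_ag_rd"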

transformer_engine/common/comm_gemm_overlap/rocm_comm_gemm_overlap.cpp

Lines changed: 142 additions & 0 deletions
@@ -0,0 +1,142 @@
+/*************************************************************************
+ * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * License for AMD contributions = MIT. See LICENSE for more information
+ ************************************************************************/
+
+#include <transformer_engine/comm_gemm_overlap.h>
+#include <transformer_engine/gemm.h>
+#include <transformer_engine/transformer_engine.h>
+
+#include "common/common.h"
+#include "common/util/cuda_driver.h"
+#include "common/util/cuda_runtime.h"
+#include "common/util/logging.h"
+#include "common/util/system.h"
+#include "userbuffers/userbuffers.h"
+
+namespace transformer_engine {
+
+void CommOverlapP2PBase::rocm_split_overlap_ag_rd(const TensorWrapper &A, bool transa, const TensorWrapper &B,
+                                                  bool transb, TensorWrapper &D, TensorWrapper &bias,
+                                                  TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad,
+                                                  bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy,
+                                                  cudaStream_t stream_main) {
+  printf("rocm_split_overlap_ag_rd\n");
+  int ori_sms = _ub_comm->sms;
+  _ub_comm->use_ce = _use_ce;
+  _ub_comm->sms = _num_comm_sm;
+  _ub_comm->cga_size = _cga_size;
+  // Get GEMM dimensions between TN and NN input layouts
+  const size_t m = (transa) ? A.size(0) : A.size(1);
+  const size_t k = (transa) ? A.size(1) : A.size(0);
+  const size_t n_chunk = _ubufs[0].size(0);
+  const int comm_bytes = _ubufs[0].bytes();
+  const bool do_gelu = pre_gelu_out.numel() > 0;
+  const size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
+
+  // Check B copy sizing
+  if (B_copy.numel() > 0) {
+    NVTE_CHECK(B_copy.numel() == _ubuf.numel(), "Expected all-gathered B copy buffer with ",
+               _ubuf.numel(), " elements but got ", B_copy.numel());
+    NVTE_CHECK(B_copy.element_size() == _ubuf.element_size(),
+               "Expected all-gathered B copy buffer with ", _ubuf.element_size() * 8,
+               "-bit data type but got ", B_copy.element_size() * 8, "-bit");
+  }
+
+  NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
+  NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _start_compute, 0));
+  NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0));
+  for (size_t i = 0; i < _stream_compute.size(); i++) {
+    NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0));
+  }
+
+  int steps = 31 - __builtin_clz(_tp_size);
+
+  // Chunk dims
+  std::vector<size_t> input_b_chunk_shape =
+      (transb ? std::vector<size_t>{k, n_chunk} : std::vector<size_t>{n_chunk, k});
+  std::vector<size_t> output_chunk_shape = {n_chunk, m};
+  size_t input_b_chunk_size = n_chunk * k;
+  size_t output_chunk_size = n_chunk * m;
+
+  // GEMM
+  auto input_b_chunk =
+      get_buffer_chunk_like(B, input_b_chunk_size * _tp_id, input_b_chunk_shape);
+  auto output_chunk =
+      get_tensor_chunk(D, output_chunk_size * _tp_id, output_chunk_shape);
+  auto aux_chunk =
+      (do_gelu)
+          ? get_tensor_chunk(pre_gelu_out, output_chunk_size * _tp_id, {n_chunk, k})
+          : TensorWrapper(nullptr, std::vector<size_t>{0}, pre_gelu_out.dtype());
+  auto workspace_chunk = get_tensor_chunk(
+      workspace, (_tp_id % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
+
+  nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
+                   aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
+                   use_split_accumulator, _math_sms,
+                   _stream_compute[_tp_id % _stream_compute.size()]);
+
+  std::vector<size_t> owned_chunks;
+  owned_chunks.reserve(_tp_size);
+  owned_chunks.push_back(_tp_id);
+  size_t offset = 1;
+
+  for (int step = 0; step < steps; step++) {
+    int send_rank = (_tp_id + offset) % _tp_size;
+    int recv_rank = (_tp_id - offset + _tp_size) % _tp_size;
+
+    for (int i = 0; i < owned_chunks.size(); i++) {
+      size_t send_offset = owned_chunks[i] * comm_bytes;
+      userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset,
+                       comm_bytes, _ub_comm, send_rank, _stream_send[i % _stream_send.size()]);
+    }
+
+    std::vector<size_t> new_chunks;
+    for (size_t i = 0; i < owned_chunks.size(); i++) {
+      size_t new_chunk_id = (recv_rank + i * offset) % _tp_size;
+      if (new_chunk_id >= _tp_size ||
+          std::find(owned_chunks.begin(), owned_chunks.end(), new_chunk_id) != owned_chunks.end()) continue;
+      size_t recv_offset = new_chunk_id * comm_bytes;
+      size_t stream_id = new_chunks.size() % _stream_compute.size();
+
+      userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset,
+                       comm_bytes, _ub_comm, recv_rank, _stream_recv);
+
+      NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
+      NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[stream_id], _stop_recv, 0));
+
+      auto input_b_chunk = get_buffer_chunk_like(B, input_b_chunk_size * new_chunk_id, input_b_chunk_shape);
+      output_chunk = get_tensor_chunk(D, output_chunk_size * new_chunk_id, output_chunk_shape);
+      aux_chunk = (do_gelu) ? get_tensor_chunk(pre_gelu_out, output_chunk_size * new_chunk_id, {n_chunk, k})
+                            : TensorWrapper(nullptr, std::vector<size_t>{0}, pre_gelu_out.dtype());
+      workspace_chunk = get_tensor_chunk(workspace, stream_id * workspace_size_chunk, {workspace_size_chunk});
+
+      nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
+                       aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
+                       use_split_accumulator, _math_sms,
+                       _stream_compute[stream_id]);
+
+      new_chunks.push_back(new_chunk_id);
+    }
+    owned_chunks.insert(owned_chunks.end(), new_chunks.begin(), new_chunks.end());
+    offset <<= 1;
+  }
+
+  if (B_copy.numel() > 0) {
+    NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubuf.dptr(), _ubuf.bytes(),
+                                    cudaMemcpyDeviceToDevice, _stream_send[0]));
+  }
+
+  _ub_comm->sms = ori_sms;
+  for (size_t i = 0; i < _stream_compute.size(); i++) {
+    NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i]));
+    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0));
+  }
+  NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send[0]));
+  NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0));
+  NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
+  NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0));
+}  // rocm_split_overlap_ag_rd
+
+}  // namespace transformer_engine
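
The new path implements a recursive-doubling all-gather overlapped with GEMM: steps = 31 - __builtin_clz(_tp_size) is floor(log2(tp_size)), and in step s (offset = 2^s) each rank sends every chunk it owns to (rank + offset) % tp_size while receiving its peer's whole block from (rank - offset) % tp_size, doubling its owned set each step; each received chunk's GEMM is launched on a round-robin compute stream as soon as the corresponding recv event fires, so compute overlaps the remaining exchanges. The result is log2(tp_size) exchange steps instead of the ring pipeline's tp_size - 1. Below is a minimal Python model of that schedule (ownership bookkeeping only, not a line-for-line port of the kernel launches above); it assumes a power-of-two tp_size and uses the fact that a rank's owned set is always the contiguous block of chunk ids behind it, which is what the recursion produces.

    def ag_recursive_doubling(tp_size: int, tp_id: int) -> list[int]:
        # steps == floor(log2(tp_size)), matching 31 - __builtin_clz(_tp_size)
        steps = tp_size.bit_length() - 1
        owned = [tp_id]                  # chunk ids held by this rank
        offset = 1
        for _ in range(steps):
            # send everything we own to (tp_id + offset) % tp_size (not modeled)
            # and receive the peer's whole block from (tp_id - offset) % tp_size
            recv_rank = (tp_id - offset) % tp_size
            owned += [
                (recv_rank - i) % tp_size
                for i in range(offset)
                if (recv_rank - i) % tp_size not in owned
            ]
            offset <<= 1
        return owned

    for tp in (2, 4, 8, 16):
        for rank in range(tp):
            assert sorted(ag_recursive_doubling(tp, rank)) == list(range(tp))
    print("every rank holds all chunks after log2(tp_size) steps")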

transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp

Lines changed: 4 additions & 1 deletion
@@ -375,8 +375,11 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks,
       cudaMalloc(reinterpret_cast<void **>(&(*comm)->flags_baseptr), 2 * GPU_PAGE_SIZE));
   NVTE_CHECK_CUDA(cudaMemset((*comm)->flags_baseptr, 0, 2 * GPU_PAGE_SIZE));
   (*comm)->flags = reinterpret_cast<int *>(
+#ifdef __HIP_PLATFORM_AMD__
+      (reinterpret_cast<uintptr_t>((*comm)->flags) + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK);
+#else
       ((CUdeviceptr)(*comm)->flags_baseptr + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK);
-
+#endif
   using namespace std;
 
   sched_param param;
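
The ROCm branch swaps the CUdeviceptr cast for plain uintptr_t arithmetic but keeps the same align-up-to-page idiom: (addr + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK rounds an address up to the next GPU page boundary. A quick sketch of that idiom, assuming GPU_PAGE_MASK is the bitwise complement of GPU_PAGE_SIZE - 1 (the constants here are illustrative, not taken from the userbuffers headers):

    GPU_PAGE_SIZE = 1 << 16
    GPU_PAGE_MASK = ~(GPU_PAGE_SIZE - 1)

    def align_up(addr: int) -> int:
        # Rounds addr up to the next GPU page boundary;
        # already-aligned addresses are unchanged.
        return (addr + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK

    assert align_up(0) == 0
    assert align_up(1) == GPU_PAGE_SIZE
    assert align_up(GPU_PAGE_SIZE) == GPU_PAGE_SIZE
    assert align_up(GPU_PAGE_SIZE + 1) == 2 * GPU_PAGE_SIZE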
