
Commit 6e1aee6

[fix] Performance Optimization for MNNVL TwoShot Kernel (NVIDIA#5934)
Signed-off-by: Shiyu Li <[email protected]>
Co-authored-by: Zongfei Jing <[email protected]>
1 parent fe070a0 commit 6e1aee6

File tree: 4 files changed (+130, -79 lines)

cpp/tensorrt_llm/kernels/communicationKernels/mnnvlTwoShotAllreduceKernels.cu

Lines changed: 92 additions & 56 deletions
@@ -61,6 +61,31 @@ inline __device__ __nv_bfloat16 fromFloat<__nv_bfloat16>(float val)
     return __float2bfloat16(val);
 }
 
+__device__ float4 loadfloat4(void const* ptr)
+{
+
+    float return_value[4];
+
+    asm volatile("ld.volatile.global.v4.f32 {%0, %1, %2, %3}, [%4];\n"
+        : "=f"(return_value[0]), "=f"(return_value[1]), "=f"(return_value[2]), "=f"(return_value[3])
+        : "l"(ptr));
+
+    return *(float4*) return_value;
+}
+
+__device__ __inline__ float2 loadfloat2(void const* ptr)
+{
+
+    float return_value[2];
+
+    asm volatile("ld.volatile.global.v2.f32 {%0, %1}, [%2];\n"
+        : "=f"(return_value[0]), "=f"(return_value[1])
+        : "l"(ptr)
+        : "memory");
+
+    return *(float2*) return_value;
+}
+
 template <int WORLD_SIZE, typename T>
 __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_ptrs, T* mcast_ptr, int num_tokens,
     int buffer_M, int token_dim, int rank, uint32_t* buffer_flags, bool wait_for_results)
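The helpers above issue inline-PTX volatile vector loads, so every poll of the Lamport buffer rereads global memory instead of a value the compiler may have kept in a register. Below is a minimal standalone sketch of the same idea; the names ld_volatile_f2, is_neg_zero, and poll_once are illustrative and not part of the commit, and the single-kernel setup only demonstrates the sentinel test, not the multi-rank exchange.

// Sketch only: a volatile v2.f32 load plus a negative-zero sentinel check,
// mirroring the loadfloat2/isNegZero pair used by the kernel.
#include <cstdio>
#include <cuda_runtime.h>

__device__ __inline__ float2 ld_volatile_f2(void const* ptr)
{
    float v[2];
    asm volatile("ld.volatile.global.v2.f32 {%0, %1}, [%2];\n"
                 : "=f"(v[0]), "=f"(v[1])
                 : "l"(ptr)
                 : "memory");
    return *reinterpret_cast<float2*>(v);
}

__device__ __inline__ bool is_neg_zero(float x)
{
    // -0.0f has only the sign bit set, which distinguishes it from real data.
    return __float_as_uint(x) == 0x80000000u;
}

__global__ void poll_once(float const* buf, int* ready)
{
    float2 v = ld_volatile_f2(buf);
    *ready = !is_neg_zero(v.x); // sentinel (-0.0f) means "not written yet"
}

int main()
{
    float* d_buf;
    int* d_ready;
    cudaMalloc(&d_buf, 2 * sizeof(float));
    cudaMalloc(&d_ready, sizeof(int));
    float init[2] = {-0.0f, -0.0f}; // Lamport-style sentinel fill
    cudaMemcpy(d_buf, init, sizeof(init), cudaMemcpyHostToDevice);
    poll_once<<<1, 1>>>(d_buf, d_ready);
    int ready = 0;
    cudaMemcpy(&ready, d_ready, sizeof(int), cudaMemcpyDeviceToHost);
    printf("ready = %d (expected 0 while the sentinel is present)\n", ready);
    cudaFree(d_buf);
    cudaFree(d_ready);
    return 0;
}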
@@ -74,20 +99,13 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
     cudaGridDependencySynchronize();
 #endif
 
+    // [input_ptr, clear_ptr, buffer_size, access_counter]
+    uint4 flag = reinterpret_cast<uint4*>(buffer_flags)[0];
+    // Each buffer is M * N and we have 2 buffers in each group, one for reduce-scatter and one for allgather
+    uint32_t buffer_group_size = flag.z << 1;
+    uint32_t input_offset = flag.x * buffer_group_size;
+    uint32_t clear_offset = flag.y * buffer_group_size;
     uint32_t* offset_access_ptr = &buffer_flags[3];
-    // Buffer size is M * N, and we need two buffers for reduce-scatter and allgather
-    uint32_t buffer_size = (buffer_flags[2] << 1);
-    uint32_t input_offset = buffer_flags[0] * buffer_size;
-    uint32_t clear_offset = buffer_flags[1] * buffer_size;
-
-    if (wait_for_results)
-    {
-        __syncthreads();
-        if (threadIdx.x == 0)
-        {
-            atomicAdd(offset_access_ptr, 1);
-        }
-    }
 
     if (elt < token_dim)
     {
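The kernel now reads the four flag words as a single uint4 snapshot (input slot, clear slot, per-buffer size, block access counter) instead of four separate uint32_t loads, and derives both offsets from that snapshot. A minimal host-side sketch of the arithmetic, assuming an example buffer size of 4096 elements (the values are illustrative, not taken from the commit):

// Sketch of the buffer_flags math: slots hold two buffers each
// (reduce-scatter + allgather) and rotate modulo 3 after every kernel.
#include <cstdint>
#include <cstdio>

int main()
{
    uint32_t flags[4] = {1, 0, 4096, 0};        // {input slot, clear slot, buffer size, counter}
    uint32_t buffer_group_size = flags[2] << 1; // two buffers per slot
    uint32_t input_offset = flags[0] * buffer_group_size;
    uint32_t clear_offset = flags[1] * buffer_group_size;
    printf("input_offset=%u clear_offset=%u\n", input_offset, clear_offset);

    // After a kernel finishes, the last block advances both slots (triple buffering).
    flags[0] = (flags[0] + 1) % 3;
    flags[1] = (flags[1] + 1) % 3;
    return 0;
}

Reading the snapshot once also lets the final block write back (flag.x + 1) % 3 and (flag.y + 1) % 3 from registers rather than rereading buffer_flags from global memory.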
@@ -101,17 +119,16 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
 
         // Reduce and broadcast
 
-        int global_token = token * WORLD_SIZE + rank;
-        if (global_token < num_tokens)
+        if ((token % WORLD_SIZE) == rank)
         {
-
+            int local_token = token / WORLD_SIZE;
             float accum = 0.f;
 
             T values[WORLD_SIZE];
 
             for (int r = 0; r < WORLD_SIZE; r++)
             {
-                input_ptrs[rank][clear_offset + token * token_dim * WORLD_SIZE + r * token_dim + elt]
+                input_ptrs[rank][clear_offset + local_token * token_dim * WORLD_SIZE + r * token_dim + elt]
                     = fromFloat<T>(-0.f);
             }
 
@@ -121,7 +138,7 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
             for (int r = 0; r < WORLD_SIZE; r++)
             {
                 T volatile* lamport_ptr = (T volatile*) &input_ptrs[rank][input_offset
-                    + token * token_dim * WORLD_SIZE + r * token_dim + elt];
+                    + local_token * token_dim * WORLD_SIZE + r * token_dim + elt];
                 values[r] = *lamport_ptr;
                 valid &= !isNegZero(values[r]);
            }
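These two hunks change the token-to-rank mapping: the loop now iterates over global token indices, a rank reduces the tokens for which token % WORLD_SIZE == rank, and local_token = token / WORLD_SIZE indexes that rank's staging rows. A small sketch of the new index math, assuming WORLD_SIZE = 4 and 8 tokens (values chosen for illustration only):

// Sketch: which rank owns each token and where it lands in that rank's staging buffer.
#include <cstdio>

int main()
{
    const int WORLD_SIZE = 4;
    const int num_tokens = 8;
    for (int token = 0; token < num_tokens; ++token)
    {
        int owner = token % WORLD_SIZE;       // rank that reduces this token
        int local_token = token / WORLD_SIZE; // row inside that rank's staging buffer
        printf("token %d -> rank %d, local_token %d\n", token, owner, local_token);
    }
    return 0;
}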
@@ -132,7 +149,7 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
             {
                 accum += toFloat<T>(values[r]);
             }
-            mcast_ptr[input_offset + buffer_M * token_dim + global_token * token_dim + elt] = fromFloat<T>(accum);
+            mcast_ptr[input_offset + buffer_M * token_dim + token * token_dim + elt] = fromFloat<T>(accum);
         }
     }
 
@@ -145,23 +162,50 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
     // Optionally wait for results if the next layer isn't doing the Lamport check
     if (wait_for_results)
     {
-        T volatile* lamport_ptr
-            = (T volatile*) &input_ptrs[rank][input_offset + buffer_M * token_dim + token * token_dim + elt];
-        T val = *lamport_ptr;
-        while (isNegZero(val))
-            val = *lamport_ptr;
-
-        // Copy if requested
-        if (output_ptr)
-            output_ptr[token * token_dim + elt] = val;
-        if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0)
+        // Update the atomic counter to indicate the block has read the offsets
+        __syncthreads();
+
+        if (threadIdx.x == 0)
+        {
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
+            asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
+#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+            asm volatile("red.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
+#else
+            atomicAdd(offset_access_ptr, 1);
+#endif
+        }
+        // Only use a set of CTAs for lamport sync, reargange the grid
+        constexpr int ELTS_PER_LOAD = sizeof(float2) / sizeof(T);
+        // blockDim.x / ELTS_PER_LOAD should be at least the size of a warp (32)
+        if (threadIdx.x < (blockDim.x / ELTS_PER_LOAD))
+        {
+            uint64_t current_pos = blockIdx.x * token_dim + blockIdx.y * blockDim.x + threadIdx.x * ELTS_PER_LOAD;
+
+            void* lamport_ptr = (void*) &input_ptrs[rank][input_offset + buffer_M * token_dim + current_pos];
+            // We have 2 assumptions here:
+            // 1. The write is atomic in 8B granularity -> Each buffer in the buffer group should be aligned to 8B
+            // 2. The num_token * token_dim is divisible by ELTS_PER_LOAD (4 for BF16 and 2 for FP32)
+            float2 val = loadfloat2(lamport_ptr);
+            while (isNegZero(*(T*) &val))
+            {
+                val = loadfloat2(lamport_ptr);
+            }
+            if (output_ptr)
+            {
+                *((float2*) &output_ptr[current_pos]) = val;
+            }
+        }
+
+        // Update the buffer flags
+        if (threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == 0)
         {
             // Make sure all blocks have finished reading the offsets, 2-D grid
             while (*reinterpret_cast<uint32_t volatile*>(offset_access_ptr) < gridDim.x * gridDim.y)
             {
             }
-            buffer_flags[0] = (buffer_flags[0] + 1) % 3;
-            buffer_flags[1] = (buffer_flags[1] + 1) % 3;
+            buffer_flags[0] = (flag.x + 1) % 3;
+            buffer_flags[1] = (flag.y + 1) % 3;
             *(offset_access_ptr) = 0;
         }
     }
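The per-block counter bump only needs a fire-and-forget increment, so the hunk above dispatches on architecture: a red.async release reduction on compute capability 10.0+, a plain red reduction on 9.0+, and atomicAdd elsewhere. A minimal sketch of the same ladder in isolation; the helper name bump_counter is illustrative, the PTX strings are copied verbatim from the hunk, and compiling with different -arch values exercises the different branches:

// Sketch: arch-dispatched "add 1, don't need the old value" counter update.
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

__device__ __forceinline__ void bump_counter(uint32_t* counter)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
    // Blackwell-class: asynchronous release reduction.
    asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(counter), "r"(1) : "memory");
#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    // Hopper-class: plain reduction, still no return value.
    asm volatile("red.global.gpu.add.u32 [%0], %1;" ::"l"(counter), "r"(1) : "memory");
#else
    // Older architectures: fall back to a regular atomic add.
    atomicAdd(counter, 1u);
#endif
}

__global__ void one_bump_per_block(uint32_t* counter)
{
    __syncthreads();
    if (threadIdx.x == 0)
    {
        bump_counter(counter);
    }
}

int main()
{
    uint32_t* d_counter;
    cudaMalloc(&d_counter, sizeof(uint32_t));
    cudaMemset(d_counter, 0, sizeof(uint32_t));
    one_bump_per_block<<<8, 128>>>(d_counter);
    uint32_t host_count = 0;
    cudaMemcpy(&host_count, d_counter, sizeof(uint32_t), cudaMemcpyDeviceToHost);
    printf("counter = %u (expected 8, one per block)\n", host_count);
    cudaFree(d_counter);
    return 0;
}

Unlike atomicAdd, the red forms do not return the previous value, which is exactly what a "this block is done" counter needs.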
@@ -251,18 +295,6 @@ __device__ void copy_f4_ldg(T_IN* dst, T_IN const* src)
     *dst4 = *src4;
 }
 
-__device__ float4 loadfloat4(void const* ptr)
-{
-
-    float return_value[4];
-
-    asm volatile("ld.volatile.global.v4.f32 {%0, %1, %2, %3}, [%4];\n"
-        : "=f"(return_value[0]), "=f"(return_value[1]), "=f"(return_value[2]), "=f"(return_value[3])
-        : "l"(ptr));
-
-    return *(float4*) return_value;
-}
-
 template <typename T>
 inline __device__ T add(T a, T b)
 {
@@ -322,19 +354,14 @@ __global__ void __launch_bounds__(128, 1)
     int offsets[NUM_INPUTS][DIM / (1 * ELTS_PER_THREAD * NUM_THREADS)];
 
     uint32_t* offset_access_ptr = &buffer_flags[3];
+    uint4 flag = reinterpret_cast<uint4*>(buffer_flags)[0];
     // Buffer size is M * N, and we need two buffers for reduce-scatter and allgather
-    uint32_t buffer_size = buffer_flags[2];
-    uint32_t buffer_offset = buffer_flags[0] * (buffer_size << 1);
+    uint32_t buffer_size = flag.z;
+    uint32_t buffer_offset = flag.x * (buffer_size << 1);
     T_IN const* input = &buffer_input[buffer_offset + buffer_size];
 
     cudaTriggerProgrammaticLaunchCompletion();
 
-    __syncthreads();
-    if (threadIdx.x == 0)
-    {
-        atomicAdd(offset_access_ptr, 1);
-    }
-
     for (int i = 0; i < NUM_INPUTS; i++)
     {
         for (int j = 0; j < DIM / (1 * ELTS_PER_THREAD * NUM_THREADS); j++)
@@ -361,7 +388,17 @@ __global__ void __launch_bounds__(128, 1)
     }
 
     __pipeline_commit();
-
+    __syncthreads();
+    if (threadIdx.x == 0)
+    {
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
+        asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
+#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+        asm volatile("red.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
+#else
+        atomicAdd(offset_access_ptr, 1);
+#endif
+    }
     // Load all inputs
     bool valid = false;
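Here the counter update moves from kernel entry to just after __pipeline_commit(), so the flag handshake overlaps the asynchronous loads that are already in flight instead of delaying them. A minimal sketch of that general commit-then-do-other-work pattern with the CUDA pipeline primitives; the kernel name overlap_demo and the 128-float shape are illustrative and unrelated to the fused kernel's real layout:

// Sketch: overlap independent bookkeeping with an in-flight async copy.
#include <cstdio>
#include <cuda_pipeline.h>
#include <cuda_runtime.h>

__global__ void overlap_demo(float* dst, float const* src, unsigned int* counter)
{
    __shared__ float smem[128];

    // Stage the load asynchronously, then commit the batch.
    __pipeline_memcpy_async(&smem[threadIdx.x], &src[threadIdx.x], sizeof(float));
    __pipeline_commit();

    // Independent work (here: a per-block counter bump) runs while the copy proceeds.
    __syncthreads();
    if (threadIdx.x == 0)
    {
        atomicAdd(counter, 1u);
    }

    // Only now block on the copy and consume the staged data.
    __pipeline_wait_prior(0);
    __syncthreads();
    dst[threadIdx.x] = smem[threadIdx.x] * 2.0f;
}

int main()
{
    const int n = 128;
    float *d_src, *d_dst;
    unsigned int* d_counter;
    cudaMalloc(&d_src, n * sizeof(float));
    cudaMalloc(&d_dst, n * sizeof(float));
    cudaMalloc(&d_counter, sizeof(unsigned int));
    cudaMemset(d_src, 0, n * sizeof(float));
    cudaMemset(d_counter, 0, sizeof(unsigned int));
    overlap_demo<<<1, n>>>(d_dst, d_src, d_counter);
    cudaDeviceSynchronize();
    printf("status: %s\n", cudaGetErrorString(cudaGetLastError()));
    cudaFree(d_src);
    cudaFree(d_dst);
    cudaFree(d_counter);
    return 0;
}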

@@ -494,14 +531,13 @@ __global__ void __launch_bounds__(128, 1)
     if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0)
     {
         // Make sure all blocks have finished accessing the buffer
-        while (*reinterpret_cast<uint32_t volatile*>(offset_access_ptr) != gridDim.x * gridDim.y)
+        while (*reinterpret_cast<uint32_t volatile*>(offset_access_ptr) < gridDim.x * gridDim.y)
         {
         }
-        buffer_flags[0] = (buffer_flags[0] + 1) % 3;
-        buffer_flags[1] = (buffer_flags[1] + 1) % 3;
+        buffer_flags[0] = (flag.x + 1) % 3;
+        buffer_flags[1] = (flag.y + 1) % 3;
         *(offset_access_ptr) = 0;
     }
-    __syncthreads();
 #endif
 }

cpp/tensorrt_llm/runtime/mcastDeviceMemory.cpp

Lines changed: 20 additions & 13 deletions
@@ -50,7 +50,7 @@ McastDeviceMemory::McastDeviceMemory(
     , mMcHandle(0)
 {
 
-    cudaSetDevice(mDeviceIdx);
+    TLLM_CUDA_CHECK(cudaSetDevice(mDeviceIdx));
     // Check if the device support multicasting
     int multicast_supported{0};
     TLLM_CU_CHECK(cuDeviceGetAttribute(&multicast_supported, CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, mDeviceIdx));
@@ -82,34 +82,41 @@ McastDeviceMemory::McastDeviceMemory(
     {
         allocNvlsMcastMem(mSignalPadOffset + kSIGNAL_PAD_SIZE);
     }
-    mSignalPadsDev.resize(mGroupSize);
+    // Initialize signal pads
+    mSignalPads.resize(mGroupSize);
     for (size_t i = 0; i < mGroupSize; i++)
     {
-        mSignalPadsDev[i] = mUcPtrs[i] + mSignalPadOffset;
+        mSignalPads[i] = mUcPtrs[i] + mSignalPadOffset;
         if (i == mGroupRank)
        {
-            cuMemsetD8(mSignalPadsDev[i], 0, kSIGNAL_PAD_SIZE);
+            cuMemsetD8(mSignalPads[i], 0, kSIGNAL_PAD_SIZE);
         }
     }
+    // Copy host array of pointers to device array
+    TLLM_CUDA_CHECK(cudaMalloc(&mSignalPadsDev, mGroupSize * sizeof(CUdeviceptr)));
+    TLLM_CUDA_CHECK(cudaMalloc(&mUcPtrsDev, mGroupSize * sizeof(CUdeviceptr)));
+    TLLM_CUDA_CHECK(
+        cudaMemcpy(mSignalPadsDev, mSignalPads.data(), mGroupSize * sizeof(CUdeviceptr), cudaMemcpyHostToDevice));
+    TLLM_CUDA_CHECK(cudaMemcpy(mUcPtrsDev, mUcPtrs.data(), mGroupSize * sizeof(CUdeviceptr), cudaMemcpyHostToDevice));
 }
 
 McastDeviceMemory::~McastDeviceMemory()
 {
     tensorrt_llm::common::unregisterMcastDevMemBuffer(this);
+    TLLM_CUDA_CHECK(cudaFree(mSignalPadsDev));
+    TLLM_CUDA_CHECK(cudaFree(mUcPtrsDev));
+
     if (mIsMNNvlink)
     {
         for (uint32_t rank = 0; rank < mGroupSize; rank++)
         {
-            if (rank == mGroupRank)
-            {
-                cuMemRelease(mUcHandles[rank]);
-            }
-            else
-            {
-                mUcHandles[rank] = 0;
-            }
+            TLLM_CU_CHECK(cuMemUnmap(mUcPtrs[rank], mAllocationSize));
+            // We need to release the handle on each rank
+            TLLM_CU_CHECK(cuMemRelease(mUcHandles[rank]));
         }
-        cuMemRelease(mMcHandle);
+        TLLM_CU_CHECK(cuMemUnmap(mMcPtr, mAllocationSize));
+        TLLM_CU_CHECK(cuMemAddressFree(mMcPtr, mAllocationSize));
+        TLLM_CU_CHECK(cuMemRelease(mMcHandle));
     }
     else
     {
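The constructor now keeps the per-rank pointer tables in two places: host std::vector storage for CPU-side bookkeeping, and a cudaMalloc'd void** copy that kernels can index directly (as the allreduce kernel does with input_ptrs[rank]). A minimal sketch of that host-to-device pointer-table pattern, assuming four plain device buffers in place of the real per-rank unicast mappings; the names touch_all and dev_table are illustrative:

// Sketch: copy a host array of device pointers into a device-resident pointer table.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void touch_all(float** table, int num_bufs)
{
    // Each participating thread writes through the device-resident table,
    // analogous to indexing input_ptrs[rank] inside the allreduce kernel.
    if (threadIdx.x < num_bufs)
    {
        table[threadIdx.x][0] = static_cast<float>(threadIdx.x);
    }
}

int main()
{
    const int num_bufs = 4;
    float* host_ptrs[num_bufs];
    for (int i = 0; i < num_bufs; ++i)
    {
        cudaMalloc(&host_ptrs[i], 16 * sizeof(float));
    }

    // The table itself lives in device memory; the entries are device pointers.
    float** dev_table;
    cudaMalloc(&dev_table, num_bufs * sizeof(float*));
    cudaMemcpy(dev_table, host_ptrs, num_bufs * sizeof(float*), cudaMemcpyHostToDevice);

    touch_all<<<1, 32>>>(dev_table, num_bufs);
    cudaDeviceSynchronize();
    printf("status: %s\n", cudaGetErrorString(cudaGetLastError()));

    for (int i = 0; i < num_bufs; ++i)
    {
        cudaFree(host_ptrs[i]);
    }
    cudaFree(dev_table);
    return 0;
}

This is why the destructor gains the matching cudaFree calls and the header's getters can return void** directly instead of reinterpreting host vector storage.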

cpp/tensorrt_llm/runtime/mcastDeviceMemory.h

Lines changed: 12 additions & 4 deletions
@@ -44,16 +44,18 @@ class McastDeviceMemory
 
     McastDeviceMemory(size_t bufSize, uint32_t groupSize, uint32_t groupRank, int deviceIdx, bool mnNvlink);
 
+    // We don't register the pointer in these two functions since we don't expect any python-level code would call
+    // to obtain the raw pointers.
     //! Get the raw array of signal pad pointers to all ranks (including self)
     void** getSignalPadPtrsDev()
     {
-        return reinterpret_cast<void**>(mSignalPadsDev.data());
+        return mSignalPadsDev;
     }
 
     //! Get the raw array of unicast pointers to all ranks (including self)
     void** getBufferPtrsDev()
     {
-        return reinterpret_cast<void**>(mUcPtrs.data());
+        return mUcPtrsDev;
     }
 
     //! Get the raw unicast pointer to a given rank
@@ -93,11 +95,17 @@ class McastDeviceMemory
     size_t mAllocationSize;
 
     CUdeviceptr mMcPtr;
-    std::vector<CUdeviceptr> mUcPtrs;
-    std::vector<CUdeviceptr> mSignalPadsDev;
     CUmemGenericAllocationHandle mMcHandle;
     std::vector<CUmemGenericAllocationHandle> mUcHandles;
 
+    // Host array of pointers
+    std::vector<CUdeviceptr> mUcPtrs;
+    std::vector<CUdeviceptr> mSignalPads;
+
+    // Device array of pointers
+    void** mUcPtrsDev;
+    void** mSignalPadsDev;
+
     // For intra-node mcast
     tensorrt_llm::runtime::IpcNvlsHandle* mNvlsHandle;

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 6 additions & 6 deletions
@@ -798,12 +798,12 @@ def _run_MoE(hidden_states, hidden_states_fp4, do_finalize):
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
 
-        # Note: this fusion pattern is only supported for TRTLLM-nvfp4 backend now
-        do_finalize = not (hidden_states.shape[0]
-                           <= self.moe_allreduce.max_token
-                           and self.fusion_config.POST_MOE_FUSION
-                           and self.model_config.moe_backend == 'TRTLLM'
-                           and self.mlp.experts.has_nvfp4)
+        # Note: this fusion pattern is only supported for single-node TRTLLM-nvfp4 backend now
+        do_finalize = self.mapping.is_multi_node() or (
+            not (hidden_states.shape[0] <= self.moe_allreduce.max_token
+                 and self.fusion_config.POST_MOE_FUSION
+                 and self.model_config.moe_backend == "TRTLLM"
+                 and self.mlp.experts.has_nvfp4))
 
         hidden_states = _run_MoE(hidden_states,
                                  hidden_states_fp4=None,
