Commit eeb56c2

timlee0212 and liji-nv authored
[None][feat] MNNVLAllreduce Kernel Refactor (#8018)
Signed-off-by: Shiyu Li <[email protected]>
Co-authored-by: Jin Li <[email protected]>
1 parent ed81173 commit eeb56c2

File tree

9 files changed: +1613 -930 lines changed

Lines changed: 284 additions & 0 deletions
@@ -0,0 +1,284 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Helper functions for lamport-based synchronization

#ifndef TRTLLM_CUDA_LAMPORT_UTILS_CUH
#define TRTLLM_CUDA_LAMPORT_UTILS_CUH

#include <array>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <type_traits>

#include <cooperative_groups.h>

#include "tensorrt_llm/common/cudaTypeUtils.cuh"

namespace tensorrt_llm::common
{

constexpr uint16_t kNEGZERO_FP16 = 0x8000U;

template <typename T>
union Fp16BitCast
{
    T mFp;
    uint16_t mInt;

    constexpr Fp16BitCast()
        : mInt(0)
    {
    }

    constexpr Fp16BitCast(T val)
        : mFp(val)
    {
    }

    constexpr Fp16BitCast(uint16_t val)
        : mInt(val)
    {
    }
};

template <typename T>
static constexpr __device__ __host__ T negZero()
{
    if constexpr (std::is_same_v<T, float>)
    {
        return -0.0F;
    }
    else if constexpr (std::is_same_v<T, __nv_bfloat16> || std::is_same_v<T, __nv_half>)
    {
        return Fp16BitCast<T>(kNEGZERO_FP16).mFp;
    }
    else
    {
        static_assert(sizeof(T) == 0, "negativeZero not specialized for this type");
    }
    return T{}; // Never reached, but needed for compilation
}

template <typename T>
static inline __device__ bool isNegZero(T val)
{

    if constexpr (std::is_same_v<T, float>)
    {
        return val == 0.F && signbit(val);
    }
    else if constexpr (std::is_same_v<T, __nv_bfloat16> || std::is_same_v<T, __nv_half>)
    {
        return Fp16BitCast<T>(val).mInt == kNEGZERO_FP16;
    }
    else
    {
        static_assert(sizeof(T) == 0, "isNegZero not specialized for this type");
    }
    return false; // Never reached, but needed for compilation
}

template <typename PackedType, typename T>
constexpr __device__ __host__ PackedType getPackedLamportInit()
{
    static_assert(sizeof(PackedType) % sizeof(T) == 0, "PackedType size must be divisible by T size");
    constexpr int kNumElements = sizeof(PackedType) / sizeof(T);

    union PackedT
    {
        PackedType mPacked;
        std::array<T, kNumElements> mElements;

        constexpr PackedT()
            : mElements{}
        {
            for (int i = 0; i < kNumElements; i++)
            {
                mElements[i] = negZero<T>();
            }
        }
    };

    PackedT initValue{};
    return initValue.mPacked;
}

// A helper class to get the correct base pointer for a given layout
struct LamportBufferLayout
{
    uint32_t numStages = 1;
    uint32_t bytesPerBuffer = 0;
    static constexpr uint32_t sNumLamportBuffers = 3;

    // Implicitly inlined
    [[nodiscard]] __device__ __host__ size_t getTotalBytes() const
    {
        return numStages * static_cast<size_t>(bytesPerBuffer / numStages) * sNumLamportBuffers;
    }

    // Implicitly inlined
    [[nodiscard]] __device__ __host__ void* getStagePtr(
        void* bufferBasePtr, uint32_t lamportIndex, uint32_t stageIndex) const
    {
        // Typecast to avoid warnings
        return reinterpret_cast<void*>(reinterpret_cast<char*>(bufferBasePtr)
            + static_cast<size_t>(
                (lamportIndex * numStages + stageIndex) * static_cast<size_t>(bytesPerBuffer / numStages)));
    }
};
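// (Illustrative example of the layout arithmetic above, not from the original source: with
//  bytesPerBuffer = 6 MiB and numStages = 2, each stage occupies 3 MiB, getTotalBytes()
//  returns 2 stages * 3 MiB * 3 buffers = 18 MiB, and getStagePtr(base, 1, 0) points to
//  base + (1 * 2 + 0) * 3 MiB = base + 6 MiB.)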
// Current Index
// Dirty Index
// bytes_per_buffer
// Dirty num_stages
// Dirty bytes_to_clear = {stage0, stage1, stage2, stage3} # We fix this to 4 stages
// offset_access_ptr
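// (Illustrative word layout of bufferFlags, inferred from the LamportFlags constructor and
//  waitAndUpdate below: words [0..3] = {current index, dirty index, bytes_per_buffer,
//  dirty num_stages}, words [4..7] = dirty bytes_to_clear per stage, and word [8] = the
//  grid-wide arrival counter behind offset_access_ptr, driven by ctaArrive()/waitAndUpdate().)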

namespace cg = cooperative_groups;

// PackedType is the one used in kernel for Lamport buffer (LDG.128 or LDG.64)
template <typename PackedType = float4>
__device__ struct __attribute__((aligned(32))) LamportFlags
{
public:
    __device__ explicit LamportFlags(uint32_t* bufferFlags, uint32_t numStages = 1)
        : mBufferFlagsPtr(bufferFlags)
        , mFlagAccessPtr(&bufferFlags[8])
    {
        mCurBufferLayout.numStages = numStages;
        uint4 flag = reinterpret_cast<uint4*>(bufferFlags)[0];
        mCurrentIndex = flag.x;
        mDirtyIndex = flag.y;
        // Buffer size is unchanged as the flag should be coupled to each buffer
        mCurBufferLayout.bytesPerBuffer = flag.z;
        mDirtyBufferLayout.bytesPerBuffer = flag.z;
        mDirtyBufferLayout.numStages = flag.w;
        *reinterpret_cast<uint4*>(&mBytesToClear) = reinterpret_cast<uint4*>(bufferFlags)[1];
    }

    // Return the base pointer of the lamport buffer indexed by mCurrentIndex and the stageIdx
    [[nodiscard]] __device__ void* getCurLamportBuf(void* bufferBasePtr, int stageIdx = 0) const
    {
        return mCurBufferLayout.getStagePtr(bufferBasePtr, mCurrentIndex, stageIdx);
    }

    // Fill the dirty lamport buffer with the init value; Use stageIdx to select the stage to clear, -1 to clear all
    // FIXME: Current kernel may use less stages than the dirty numStages; How to guarantee the correctness?
    // CAUTION: This function requires all threads in the grid to participate and ASSUME 1D thread block layout!
    __device__ void clearDirtyLamportBuf(void* bufferBasePtr, int stageIdx = -1)
    {
        // Rasterize the threads to 1D for flexible clearing

        uint32_t globalCtaIdx = blockIdx.x * gridDim.y + blockIdx.y;
        uint32_t globalTid = globalCtaIdx * blockDim.x + threadIdx.x;
        uint32_t numThreads = gridDim.x * gridDim.y * blockDim.x;

        if (stageIdx == -1)
        {
            // Clear all stages
            for (uint32_t i = 0; i < mDirtyBufferLayout.numStages; i++)
            {
                clearPackedBuf(bufferBasePtr, globalTid, numThreads, mBytesToClear[i], mDirtyIndex, i);
            }
        }
        else if (stageIdx < mDirtyBufferLayout.numStages)
        {
            clearPackedBuf(bufferBasePtr, globalTid, numThreads, mBytesToClear[stageIdx], mDirtyIndex, stageIdx);
        }
    }

    __device__ void ctaArrive()
    {
        int tid{0};
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))

        cg::cluster_group cluster = cg::this_cluster();
        // We update the atomic counter per cluster
        tid = cluster.thread_rank();
        cluster.sync();
#else
        tid = threadIdx.x;
        __syncthreads();
#endif
        if (tid == 0)
        {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
            asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(mFlagAccessPtr), "r"(1) : "memory");
#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700))
            asm volatile("red.release.global.gpu.add.u32 [%0], %1;" ::"l"(mFlagAccessPtr), "r"(1) : "memory");
#else
            atomicAdd(mFlagAccessPtr, 1);
#endif
        }
    }

    __device__ void waitAndUpdate(uint4 bytesToClearPerStage)
    {
        bool isLastCtaT0{false};
        int targetCount{0};
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
        cg::grid_group grid = cg::this_grid();
        // Use the first thread instead of the last thread as the last thread may exit early
        isLastCtaT0 = grid.thread_rank() == 0;
        targetCount = grid.num_clusters();
#else
        isLastCtaT0 = threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0;
        targetCount = gridDim.x * gridDim.y;
#endif
        if (isLastCtaT0)
        {
            uint4* flagPtr = reinterpret_cast<uint4*>(mBufferFlagsPtr);
            while (*reinterpret_cast<uint32_t volatile*>(mFlagAccessPtr) < targetCount)
            {
            }
            // 'Current' becomes 'Dirty'
            flagPtr[0] = {(mCurrentIndex + 1) % 3, // Current index
                mCurrentIndex,                     // Dirty index
                mCurBufferLayout.bytesPerBuffer,   // Buffer size
                mCurBufferLayout.numStages};       // Dirty - Number of stages
            flagPtr[1] = bytesToClearPerStage;
            *mFlagAccessPtr = 0;
        }
    }

private:
    uint32_t* mBufferFlagsPtr;
    uint32_t* mFlagAccessPtr;

    uint32_t mCurrentIndex, mDirtyIndex;
    // So that we can access it with uint4
    alignas(16) std::array<uint32_t, 4> mBytesToClear;
    LamportBufferLayout mCurBufferLayout, mDirtyBufferLayout;

    inline __device__ void clearPackedBuf(void* bufferBasePtr, uint32_t globalTid, uint32_t numThreads,
        uint32_t bytesToClear, uint8_t dirtyIndex, uint8_t stageIdx)
    {
        // Round up to the float4 boundary
        // For the same reason that the divUp is shadowed, we have to define it again here.
        uint32_t clearBoundary = (bytesToClear + sizeof(PackedType) - 1) / sizeof(PackedType);
        for (uint32_t packedIdx = globalTid; packedIdx < clearBoundary; packedIdx += numThreads)
        {
            reinterpret_cast<PackedType*>(
                mDirtyBufferLayout.getStagePtr(bufferBasePtr, dirtyIndex, stageIdx))[packedIdx]
                = getPackedLamportInit<PackedType, float>();
        }
    }
};

} // namespace tensorrt_llm::common

#endif // TRTLLM_CUDA_LAMPORT_UTILS_CUH
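
For reference, here is a minimal usage sketch of the new LamportFlags helper. It is not part of this commit: the kernel name, its parameters, the include path, and the launch setup are assumptions for illustration; only the LamportFlags calls and their semantics come from the header above.

// Hypothetical kernel driving LamportFlags (sketch only). On SM90+ the real kernels would need
// a cooperative/cluster launch so that cg::this_grid()/cg::this_cluster() are valid.
#include "tensorrt_llm/common/cudaLamportUtils.cuh" // assumed path of the header added above

using tensorrt_llm::common::LamportFlags;

__global__ void exampleLamportKernel(uint32_t* bufferFlags, void* lamportBufferBase, uint4 bytesToClearPerStage)
{
    // Read the flag words (current/dirty index, buffer size, bytes to clear) once per kernel.
    LamportFlags<float4> flags(bufferFlags, /*numStages=*/1);

    // Scrub the buffer that went dirty in a previous iteration back to the -0.0 init value.
    flags.clearDirtyLamportBuf(lamportBufferBase, /*stageIdx=*/-1);

    // ... write results into flags.getCurLamportBuf(lamportBufferBase) and/or poll it with
    // isNegZero() until peers' data has landed ...

    // Every CTA (cluster on SM90+) signals that it is done with the buffers,
    flags.ctaArrive();
    // then one thread waits for all arrivals, rotates 'current' to 'dirty', and records how
    // many bytes each stage must clear next time.
    flags.waitAndUpdate(bytesToClearPerStage);
}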
