Skip to content

Commit c8b9998

Browse files
authored
[TRTLLM-8637][feat] Optimize the routing kernel for DeepseekV3 (MoE CUTLASS backend); Add support for KimiK2 and Qwen-next (MoE TRTLLM backend) (NVIDIA#7761)
Signed-off-by: Christina Zhang <[email protected]>
1 parent f7722e2 commit c8b9998

File tree

19 files changed

+1013
-854
lines changed

19 files changed

+1013
-854
lines changed

cpp/tensorrt_llm/kernels/moeTopKFuncs.cuh

Lines changed: 84 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ struct TopKRedType
5151
static __host__ __device__ inline TypeCmp makeCmpVal(T val, int32_t idx = 0)
5252
{
5353
auto valueBits = cub::Traits<T>::TwiddleIn(reinterpret_cast<typename cub::Traits<T>::UnsignedBits&>(val));
54-
TypeCmp compactTmp = reinterpret_cast<TypeCmp&>(valueBits);
54+
TypeCmp compactTmp = valueBits;
5555
compactTmp = (compactTmp << kMoveBits) | (0xFFFF & (kMaxIdx - idx));
5656
// Use 65535 minus idx to give higher priority to elements with smaller indices.
5757
return compactTmp;
@@ -162,9 +162,28 @@ struct Sort<4, RedType>
162162
}
163163
};
164164

165+
template <int K, typename Type>
166+
__forceinline__ __device__ void reduceTopK(cg::thread_block_tile<kWARP_SIZE> const& warp, Type (&out)[K],
167+
int32_t (&outIdx)[K], Type value, int32_t idx, Type const minValue, int actualK = K)
168+
{
169+
static_assert(K > 0, "Top K must have K > 0");
170+
static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE");
171+
using RedType = TopKRedType<Type>;
172+
RedType topK{value, idx};
173+
typename RedType::TypeCmp packedMax{};
174+
#pragma unroll
175+
for (int kk = 0; kk < actualK; ++kk) //@todo: check if actualK is correct
176+
{
177+
topK = kk > 0 && packedMax == topK.compValIdx ? RedType{minValue, idx} : topK;
178+
// get the next largest value
179+
packedMax = topK.reduce(warp);
180+
RedType::unpack(out[kk], outIdx[kk], packedMax);
181+
}
182+
};
183+
165184
template <int K, typename Type, int N, bool IsSorted = false>
166-
__device__ void reduceTopK(cg::thread_block_tile<kWARP_SIZE> const& warp, Type (&out)[K], int32_t (&outIdx)[K],
167-
Type (&value)[N], int32_t (&idx)[N], Type minValue)
185+
__device__ void reduceTopKFunc(cg::thread_block_tile<kWARP_SIZE> const& warp, Type (&out)[K], int32_t (&outIdx)[K],
186+
Type (&value)[N], int32_t (&idx)[N], Type minValue, int actualK = K)
168187
{
169188
static_assert(K > 0, "Top K must have K > 0");
170189
static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE");
@@ -184,7 +203,7 @@ __device__ void reduceTopK(cg::thread_block_tile<kWARP_SIZE> const& warp, Type (
184203
}
185204
typename RedType::TypeCmp packedMax{};
186205
#pragma unroll
187-
for (int kk = 0; kk < K; ++kk)
206+
for (int kk = 0; kk < actualK; ++kk)
188207
{
189208
bool update = kk > 0 && packedMax == topK[0].compValIdx;
190209
#pragma unroll
@@ -198,6 +217,67 @@ __device__ void reduceTopK(cg::thread_block_tile<kWARP_SIZE> const& warp, Type (
198217
}
199218
};
200219

220+
template <int K, typename Type, int N>
221+
__forceinline__ __device__ void reduceTopK(cg::thread_block_tile<kWARP_SIZE> const& warp, Type (&out)[K],
222+
int32_t (&outIdx)[K], Type (&value)[N], int32_t (&idx)[N], Type const minValue, int actualK = K)
223+
{
224+
static_assert(K > 0, "Top K must have K > 0");
225+
static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE");
226+
static_assert(N > 0, "Top K must have N > 0");
227+
static_assert(N <= 16, "Only support candidates number less than or equal to 16*32=512");
228+
static_assert(
229+
N <= 4 || N % 4 == 0, "Only support candidates number is a multiple of 4*32=128 or less than or equal to 4");
230+
using RedType = TopKRedType<Type>;
231+
232+
if constexpr (N <= 4)
233+
{
234+
reduceTopKFunc<K, Type, N>(warp, out, outIdx, value, idx, minValue, actualK);
235+
}
236+
else
237+
{
238+
239+
constexpr int numLoops = N / 4;
240+
constexpr int numResults = (numLoops * K - 1) / kWARP_SIZE + 1;
241+
242+
Type topKBufferValue[numResults];
243+
int32_t topKBufferIdx[numResults];
244+
int32_t laneIdx = threadIdx.x % kWARP_SIZE;
245+
246+
for (int ii = 0; ii < numResults; ++ii)
247+
{
248+
topKBufferValue[ii] = minValue;
249+
topKBufferIdx[ii] = ii * kWARP_SIZE - 1; //@todo: check if this is correct
250+
}
251+
for (int loop = 0; loop < numLoops; ++loop)
252+
{
253+
int start = loop * 4;
254+
Type topKValue[K];
255+
int32_t topKIdx[K];
256+
Type inValue[4];
257+
int32_t inIdx[4];
258+
for (int i = 0; i < 4; ++i)
259+
{
260+
inValue[i] = value[start + i];
261+
inIdx[i] = idx[start + i];
262+
}
263+
reduceTopKFunc<K, Type, 4>(warp, topKValue, topKIdx, inValue, inIdx, minValue, actualK);
264+
int inOffset = laneIdx % K;
265+
if (laneIdx >= loop * K && laneIdx < (loop + 1) * K)
266+
{
267+
topKBufferValue[0] = topKValue[inOffset];
268+
topKBufferIdx[0] = topKIdx[inOffset];
269+
}
270+
if (loop == numLoops - 1 && (laneIdx < (numLoops * K - kWARP_SIZE)))
271+
{
272+
topKBufferValue[1] = topKValue[inOffset];
273+
topKBufferIdx[1] = topKIdx[inOffset];
274+
}
275+
}
276+
277+
reduceTopKFunc<K, Type, numResults>(warp, out, outIdx, topKBufferValue, topKBufferIdx, minValue, actualK);
278+
}
279+
};
280+
201281
#undef TOPK_SWAP
202282

203283
} // namespace reduce_topk

0 commit comments

Comments
 (0)