Skip to content

Commit 9d3c675

Browse files
authored
[None][chore] Support larger topK for NVLinkOneSided AlltoAll. (NVIDIA#9816)
Signed-off-by: Bo Li <[email protected]>
1 parent 6a39bb9 commit 9d3c675

File tree

2 files changed: +121 -4 lines changed

cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu

Lines changed: 118 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,18 @@ namespace tensorrt_llm::kernels::moe_comm
4545
#define SWITCH_TOP_K(top_k, TOP_K, ...) \
4646
switch (top_k) \
4747
{ \
48+
case 16: \
49+
{ \
50+
constexpr int TOP_K = 16; \
51+
__VA_ARGS__; \
52+
break; \
53+
} \
54+
case 10: \
55+
{ \
56+
constexpr int TOP_K = 10; \
57+
__VA_ARGS__; \
58+
break; \
59+
} \
4860
case 8: \
4961
{ \
5062
constexpr int TOP_K = 8; \
@@ -611,6 +623,90 @@ __device__ void vectorized_combine_impl(
611623
// Load directly into the per-k accumulator; reduce across k below
612624
acc[k].load(recv_buffer + base_token + offset);
613625
}
626+
if constexpr (TOP_K == 16)
627+
{
628+
T* a0 = reinterpret_cast<T*>(&acc[0]);
629+
T* a1 = reinterpret_cast<T*>(&acc[1]);
630+
T* a2 = reinterpret_cast<T*>(&acc[2]);
631+
T* a3 = reinterpret_cast<T*>(&acc[3]);
632+
T* a4 = reinterpret_cast<T*>(&acc[4]);
633+
T* a5 = reinterpret_cast<T*>(&acc[5]);
634+
T* a6 = reinterpret_cast<T*>(&acc[6]);
635+
T* a7 = reinterpret_cast<T*>(&acc[7]);
636+
T* a8 = reinterpret_cast<T*>(&acc[8]);
637+
T* a9 = reinterpret_cast<T*>(&acc[9]);
638+
T* a10 = reinterpret_cast<T*>(&acc[10]);
639+
T* a11 = reinterpret_cast<T*>(&acc[11]);
640+
T* a12 = reinterpret_cast<T*>(&acc[12]);
641+
T* a13 = reinterpret_cast<T*>(&acc[13]);
642+
T* a14 = reinterpret_cast<T*>(&acc[14]);
643+
T* a15 = reinterpret_cast<T*>(&acc[15]);
644+
#pragma unroll
645+
for (int j = 0; j < elems_per_vec; ++j)
646+
{
647+
a0[j] += a1[j];
648+
a2[j] += a3[j];
649+
a4[j] += a5[j];
650+
a6[j] += a7[j];
651+
a8[j] += a9[j];
652+
a10[j] += a11[j];
653+
a12[j] += a13[j];
654+
a14[j] += a15[j];
655+
}
656+
#pragma unroll
657+
for (int j = 0; j < elems_per_vec; ++j)
658+
{
659+
a0[j] += a2[j];
660+
a4[j] += a6[j];
661+
a8[j] += a10[j];
662+
a12[j] += a14[j];
663+
}
664+
#pragma unroll
665+
for (int j = 0; j < elems_per_vec; ++j)
666+
{
667+
a0[j] += a4[j];
668+
a8[j] += a12[j];
669+
}
670+
#pragma unroll
671+
for (int j = 0; j < elems_per_vec; ++j)
672+
{
673+
a0[j] += a8[j];
674+
}
675+
}
676+
else if constexpr (TOP_K == 10)
677+
{
678+
T* a0 = reinterpret_cast<T*>(&acc[0]);
679+
T* a1 = reinterpret_cast<T*>(&acc[1]);
680+
T* a2 = reinterpret_cast<T*>(&acc[2]);
681+
T* a3 = reinterpret_cast<T*>(&acc[3]);
682+
T* a4 = reinterpret_cast<T*>(&acc[4]);
683+
T* a5 = reinterpret_cast<T*>(&acc[5]);
684+
T* a6 = reinterpret_cast<T*>(&acc[6]);
685+
T* a7 = reinterpret_cast<T*>(&acc[7]);
686+
T* a8 = reinterpret_cast<T*>(&acc[8]);
687+
T* a9 = reinterpret_cast<T*>(&acc[9]);
688+
#pragma unroll
689+
for (int j = 0; j < elems_per_vec; ++j)
690+
{
691+
a0[j] += a1[j];
692+
a2[j] += a3[j];
693+
a4[j] += a5[j];
694+
a6[j] += a7[j];
695+
a8[j] += a9[j];
696+
}
697+
#pragma unroll
698+
for (int j = 0; j < elems_per_vec; ++j)
699+
{
700+
a0[j] += a2[j];
701+
a4[j] += a6[j];
702+
}
703+
#pragma unroll
704+
for (int j = 0; j < elems_per_vec; ++j)
705+
{
706+
a0[j] += a4[j];
707+
a0[j] += a8[j];
708+
}
709+
}
614710

615711
// Reduce acc[TOP_K] into acc[0]
616712
if constexpr (TOP_K == 8)
@@ -643,6 +739,28 @@ __device__ void vectorized_combine_impl(
643739
a0[j] += a4[j];
644740
}
645741
}
742+
else if constexpr (TOP_K == 6)
743+
{
744+
T* a0 = reinterpret_cast<T*>(&acc[0]);
745+
T* a1 = reinterpret_cast<T*>(&acc[1]);
746+
T* a2 = reinterpret_cast<T*>(&acc[2]);
747+
T* a3 = reinterpret_cast<T*>(&acc[3]);
748+
T* a4 = reinterpret_cast<T*>(&acc[4]);
749+
T* a5 = reinterpret_cast<T*>(&acc[5]);
750+
#pragma unroll
751+
for (int j = 0; j < elems_per_vec; ++j)
752+
{
753+
a0[j] += a1[j];
754+
a2[j] += a3[j];
755+
a4[j] += a5[j];
756+
}
757+
#pragma unroll
758+
for (int j = 0; j < elems_per_vec; ++j)
759+
{
760+
a0[j] += a2[j];
761+
a0[j] += a4[j];
762+
}
763+
}
646764
else if constexpr (TOP_K == 4)
647765
{
648766
T* a0 = reinterpret_cast<T*>(&acc[0]);

cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -23,10 +23,9 @@ namespace tensorrt_llm::kernels::moe_comm
2323
{
2424

2525
// Configuration constants
26-
static constexpr int kMaxExperts = 256; // Maximum number of experts per rank
27-
static constexpr int kMaxTopK = 8; // Maximum top-k experts per token
28-
static constexpr int kMaxPayloads = 8; // Maximum number of different payload types
29-
static constexpr int kMaxRanks = 64; // Maximum supported EP size
26+
static constexpr int kMaxTopK = 16; // Maximum top-k experts per token
27+
static constexpr int kMaxPayloads = 4; // Maximum number of different payload types
28+
static constexpr int kMaxRanks = 64; // Maximum supported EP size
3029

3130
// Describes a single payload type to be communicated
3231
struct PayloadDescriptor

0 commit comments

Comments
 (0)