@@ -45,6 +45,18 @@ namespace tensorrt_llm::kernels::moe_comm
4545#define SWITCH_TOP_K (top_k, TOP_K, ...) \
4646 switch (top_k) \
4747 { \
48+ case 16 : \
49+ { \
50+ constexpr int TOP_K = 16 ; \
51+ __VA_ARGS__; \
52+ break ; \
53+ } \
54+ case 10 : \
55+ { \
56+ constexpr int TOP_K = 10 ; \
57+ __VA_ARGS__; \
58+ break ; \
59+ } \
4860 case 8 : \
4961 { \
5062 constexpr int TOP_K = 8 ; \
@@ -611,6 +623,90 @@ __device__ void vectorized_combine_impl(
611623 // Load directly into the per-k accumulator; reduce across k below
612624 acc[k].load (recv_buffer + base_token + offset);
613625 }
626+ if constexpr (TOP_K == 16 )
627+ {
628+ T* a0 = reinterpret_cast <T*>(&acc[0 ]);
629+ T* a1 = reinterpret_cast <T*>(&acc[1 ]);
630+ T* a2 = reinterpret_cast <T*>(&acc[2 ]);
631+ T* a3 = reinterpret_cast <T*>(&acc[3 ]);
632+ T* a4 = reinterpret_cast <T*>(&acc[4 ]);
633+ T* a5 = reinterpret_cast <T*>(&acc[5 ]);
634+ T* a6 = reinterpret_cast <T*>(&acc[6 ]);
635+ T* a7 = reinterpret_cast <T*>(&acc[7 ]);
636+ T* a8 = reinterpret_cast <T*>(&acc[8 ]);
637+ T* a9 = reinterpret_cast <T*>(&acc[9 ]);
638+ T* a10 = reinterpret_cast <T*>(&acc[10 ]);
639+ T* a11 = reinterpret_cast <T*>(&acc[11 ]);
640+ T* a12 = reinterpret_cast <T*>(&acc[12 ]);
641+ T* a13 = reinterpret_cast <T*>(&acc[13 ]);
642+ T* a14 = reinterpret_cast <T*>(&acc[14 ]);
643+ T* a15 = reinterpret_cast <T*>(&acc[15 ]);
644+ #pragma unroll
645+ for (int j = 0 ; j < elems_per_vec; ++j)
646+ {
647+ a0[j] += a1[j];
648+ a2[j] += a3[j];
649+ a4[j] += a5[j];
650+ a6[j] += a7[j];
651+ a8[j] += a9[j];
652+ a10[j] += a11[j];
653+ a12[j] += a13[j];
654+ a14[j] += a15[j];
655+ }
656+ #pragma unroll
657+ for (int j = 0 ; j < elems_per_vec; ++j)
658+ {
659+ a0[j] += a2[j];
660+ a4[j] += a6[j];
661+ a8[j] += a10[j];
662+ a12[j] += a14[j];
663+ }
664+ #pragma unroll
665+ for (int j = 0 ; j < elems_per_vec; ++j)
666+ {
667+ a0[j] += a4[j];
668+ a8[j] += a12[j];
669+ }
670+ #pragma unroll
671+ for (int j = 0 ; j < elems_per_vec; ++j)
672+ {
673+ a0[j] += a8[j];
674+ }
675+ }
676+ else if constexpr (TOP_K == 10 )
677+ {
678+ T* a0 = reinterpret_cast <T*>(&acc[0 ]);
679+ T* a1 = reinterpret_cast <T*>(&acc[1 ]);
680+ T* a2 = reinterpret_cast <T*>(&acc[2 ]);
681+ T* a3 = reinterpret_cast <T*>(&acc[3 ]);
682+ T* a4 = reinterpret_cast <T*>(&acc[4 ]);
683+ T* a5 = reinterpret_cast <T*>(&acc[5 ]);
684+ T* a6 = reinterpret_cast <T*>(&acc[6 ]);
685+ T* a7 = reinterpret_cast <T*>(&acc[7 ]);
686+ T* a8 = reinterpret_cast <T*>(&acc[8 ]);
687+ T* a9 = reinterpret_cast <T*>(&acc[9 ]);
688+ #pragma unroll
689+ for (int j = 0 ; j < elems_per_vec; ++j)
690+ {
691+ a0[j] += a1[j];
692+ a2[j] += a3[j];
693+ a4[j] += a5[j];
694+ a6[j] += a7[j];
695+ a8[j] += a9[j];
696+ }
697+ #pragma unroll
698+ for (int j = 0 ; j < elems_per_vec; ++j)
699+ {
700+ a0[j] += a2[j];
701+ a4[j] += a6[j];
702+ }
703+ #pragma unroll
704+ for (int j = 0 ; j < elems_per_vec; ++j)
705+ {
706+ a0[j] += a4[j];
707+ a0[j] += a8[j];
708+ }
709+ }
614710
615711 // Reduce acc[TOP_K] into acc[0]
616712 if constexpr (TOP_K == 8 )
@@ -643,6 +739,28 @@ __device__ void vectorized_combine_impl(
643739 a0[j] += a4[j];
644740 }
645741 }
742+ else if constexpr (TOP_K == 6 )
743+ {
744+ T* a0 = reinterpret_cast <T*>(&acc[0 ]);
745+ T* a1 = reinterpret_cast <T*>(&acc[1 ]);
746+ T* a2 = reinterpret_cast <T*>(&acc[2 ]);
747+ T* a3 = reinterpret_cast <T*>(&acc[3 ]);
748+ T* a4 = reinterpret_cast <T*>(&acc[4 ]);
749+ T* a5 = reinterpret_cast <T*>(&acc[5 ]);
750+ #pragma unroll
751+ for (int j = 0 ; j < elems_per_vec; ++j)
752+ {
753+ a0[j] += a1[j];
754+ a2[j] += a3[j];
755+ a4[j] += a5[j];
756+ }
757+ #pragma unroll
758+ for (int j = 0 ; j < elems_per_vec; ++j)
759+ {
760+ a0[j] += a2[j];
761+ a0[j] += a4[j];
762+ }
763+ }
646764 else if constexpr (TOP_K == 4 )
647765 {
648766 T* a0 = reinterpret_cast <T*>(&acc[0 ]);
0 commit comments