@@ -570,6 +570,7 @@ template <typename T,
570
570
int NUM_EXPERTS,
571
571
int WARPS_PER_CTA,
572
572
int BYTES_PER_LDG,
573
+ bool Norm_Weights = false ,
573
574
typename IdxT = int >
574
575
__launch_bounds__ (WARPS_PER_CTA * WARP_SIZE) __global__
575
576
void topk_gating_softmax (const T* input,
@@ -628,6 +629,7 @@ __launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
628
629
// We compute row offset for each thread sub-group
629
630
const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW;
630
631
const int thread_row = warp_base_row + thread_row_in_warp;
632
+ const int thread_row_in_cta = thread_row - cta_base_row;
631
633
632
634
// Threads with indices out of bounds should early exit here.
633
635
if (thread_row >= num_rows) return ;
@@ -643,6 +645,9 @@ __launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
643
645
const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
644
646
const T* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
645
647
648
+ T weight_sum = static_cast <T>(0 );
649
+ extern __shared__ T row_output[];
650
+
646
651
// Determine the pointer type to use to read in the data depending on the
647
652
// BYTES_PER_LDG template param. In theory, this can support all powers of 2
648
653
// up to 16.
@@ -765,7 +770,14 @@ __launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
765
770
// global memory. (This will be a single thread per row of the
766
771
// input/output matrices.)
767
772
const int idx = k * thread_row + k_idx;
768
- output[idx] = T (max_val);
773
+ if constexpr (Norm_Weights) {
774
+ const int idx_in_cta = k * thread_row_in_cta + k_idx;
775
+ row_output[idx_in_cta] = T (max_val);
776
+ weight_sum += T (max_val);
777
+ }
778
+ else {
779
+ output[idx] = T (max_val);
780
+ }
769
781
indices[idx] = should_process_row ? expert : NUM_EXPERTS;
770
782
source_rows[idx] = k_idx * num_rows + thread_row;
771
783
}
@@ -788,6 +800,16 @@ __launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
788
800
}
789
801
}
790
802
}
803
+ if constexpr (Norm_Weights) {
804
+ #pragma unroll
805
+ for (int k_idx = 0 ; k_idx < k; ++k_idx) {
806
+ if (thread_group_idx == 0 ) {
807
+ const int idx = k * thread_row + k_idx;
808
+ const int idx_in_cta = k * thread_row_in_cta + k_idx;
809
+ output[idx] = row_output[idx_in_cta] / weight_sum;
810
+ }
811
+ }
812
+ }
791
813
}
792
814
793
815
namespace detail {
@@ -807,7 +829,7 @@ struct TopkConstants {
807
829
};
808
830
} // namespace detail
809
831
810
- template <typename T, int EXPERTS, int WARPS_PER_TB, typename IdxT = int >
832
+ template <typename T, int EXPERTS, int WARPS_PER_TB, bool Norm_Weights = false , typename IdxT = int >
811
833
void topk_gating_softmax_launcher_helper (const T* input,
812
834
T* output,
813
835
IdxT* indices,
@@ -826,8 +848,9 @@ void topk_gating_softmax_launcher_helper(const T* input,
826
848
const int num_blocks = (num_warps + WARPS_PER_TB - 1 ) / WARPS_PER_TB;
827
849
828
850
dim3 block_dim (WARP_SIZE, WARPS_PER_TB);
829
- topk_gating_softmax<T, VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG>
830
- <<<num_blocks, block_dim, 0 , stream>>>(
851
+ static constexpr int ROWS_PER_CTA = WARPS_PER_TB * ROWS_PER_WARP;
852
+ topk_gating_softmax<T, VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG, Norm_Weights>
853
+ <<<num_blocks, block_dim, ROWS_PER_CTA * k * sizeof (T), stream>>>(
831
854
input, output, num_rows, indices, source_row, k);
832
855
}
833
856
0 commit comments