@@ -1457,7 +1457,7 @@ __host__ __device__ constexpr static U arrayConvert(T const& input) {
1457
1457
// (k-1)*rows_in_input all map to row 0 in the original matrix. Thus, to know where to read in the
1458
1458
// source matrix, we simply take the modulus of the expanded index.
1459
1459
1460
- constexpr static int EXPAND_THREADS_PER_BLOCK = 256 ;
1460
+ constexpr static int EXPAND_THREADS_PER_BLOCK = 128 ;
1461
1461
1462
1462
template <class InputActivationsType , class ExpandedActivationsType ,
1463
1463
TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType,
@@ -1697,7 +1697,7 @@ void expandInputRowsKernelLauncher(
1697
1697
1698
1698
static int64_t const smCount = tensorrt_llm::common::getMultiProcessorCount ();
1699
1699
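// Note: Launching 8 blocks of 256 threads per SM can fully leverage the memory bandwidth (tested on B200).
// With EXPAND_THREADS_PER_BLOCK = 128 this becomes 16 blocks per SM, i.e. the same 2048 threads per SM.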
1700
- int64_t const blocks = std::min (smCount * 8 , std::max (num_rows * k, num_padding_tokens));
1700
+ int64_t const blocks = std::min (smCount * 16 , std::max (num_rows * k, num_padding_tokens));
1701
1701
int64_t const threads = EXPAND_THREADS_PER_BLOCK;
1702
1702
1703
1703
auto func = [&]() {
@@ -1813,6 +1813,10 @@ void finalizeMoeRoutingKernel(
1813
1813
ScaleBiasType const * bias, float const * scales, int const * unpermuted_row_to_permuted_row,
1814
1814
int const * token_selected_experts, int64_t const orig_cols, int64_t const experts_per_token_real_,
1815
1815
int const num_experts_per_node, int const start_expert_id) {
1816
+ if constexpr (not (std::is_same_v<GemmOutputType, __nv_bfloat16> and std::is_same_v<OutputType, __nv_bfloat16>)) {
1817
+ printf (" finalizeMoeRoutingKernel see unsupported dtype\n " );
1818
+ asm (" trap;" );
1819
+ } else {
1816
1820
constexpr int experts_per_token = 8 ;
1817
1821
if (experts_per_token != experts_per_token_real_) { asm (" trap;" ); }
1818
1822
@@ -1847,16 +1851,19 @@ void finalizeMoeRoutingKernel(
1847
1851
for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) {
1848
1852
ComputeElem thread_output;
1849
1853
thread_output.fill (0 );
1854
+
1855
+ #pragma unroll
1850
1856
for (int k_idx = 0 ; k_idx < experts_per_token; ++k_idx) {
1851
1857
int64_t const k_offset = original_row * experts_per_token + k_idx;
1852
1858
int64_t const expert_id = token_selected_experts[k_offset] - start_expert_id;
1853
- if (expert_id < 0 || expert_id >= num_experts_per_node) {
1854
- continue ;
1855
- }
1856
1859
1857
1860
int64_t const expanded_original_row = original_row + k_idx * num_rows;
1858
1861
int64_t const expanded_permuted_row = unpermuted_row_to_permuted_row[expanded_original_row];
1859
1862
1863
+ if (expert_id < 0 || expert_id >= num_experts_per_node) {
1864
+ continue ;
1865
+ }
1866
+
1860
1867
int64_t expanded_rows = num_rows * experts_per_token;
1861
1868
if (expanded_permuted_row < 0 || expanded_permuted_row >= expanded_rows) {
1862
1869
continue ;
@@ -1884,6 +1891,7 @@ void finalizeMoeRoutingKernel(
1884
1891
asm volatile (" griddepcontrol.launch_dependents;" );
1885
1892
#endif
1886
1893
}
1894
+ }
1887
1895
1888
1896
// Final kernel to unpermute and scale
1889
1897
// This kernel unpermutes the original data, does the k-way reduction and performs the final skip
0 commit comments