Skip to content

Commit 348a536

Browse files
committed
cp change-block-thread, pragma-unroll, mv-if-check
1 parent 3a74536 commit 348a536

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1457,7 +1457,7 @@ __host__ __device__ constexpr static U arrayConvert(T const& input) {
14571457
// (k-1)*rows_in_input all map to row 0 in the original matrix. Thus, to know where to read in the
14581458
// source matrix, we simply take the modulus of the expanded index.
14591459

1460-
constexpr static int EXPAND_THREADS_PER_BLOCK = 256;
1460+
constexpr static int EXPAND_THREADS_PER_BLOCK = 128;
14611461

14621462
template <class InputActivationsType, class ExpandedActivationsType,
14631463
TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType,
@@ -1697,7 +1697,7 @@ void expandInputRowsKernelLauncher(
16971697

16981698
static int64_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
16991699
// Note: Launching 8 blocks per SM can fully leverage the memory bandwidth (tested on B200).
1700-
int64_t const blocks = std::min(smCount * 8, std::max(num_rows * k, num_padding_tokens));
1700+
int64_t const blocks = std::min(smCount * 16, std::max(num_rows * k, num_padding_tokens));
17011701
int64_t const threads = EXPAND_THREADS_PER_BLOCK;
17021702

17031703
auto func = [&]() {
@@ -1813,6 +1813,10 @@ void finalizeMoeRoutingKernel(
18131813
ScaleBiasType const* bias, float const* scales, int const* unpermuted_row_to_permuted_row,
18141814
int const* token_selected_experts, int64_t const orig_cols, int64_t const experts_per_token_real_,
18151815
int const num_experts_per_node, int const start_expert_id) {
1816+
if constexpr (not (std::is_same_v<GemmOutputType, __nv_bfloat16> and std::is_same_v<OutputType, __nv_bfloat16>)) {
1817+
printf("finalizeMoeRoutingKernel see unsupported dtype\n");
1818+
asm("trap;");
1819+
} else {
18161820
constexpr int experts_per_token = 8;
18171821
if (experts_per_token != experts_per_token_real_) { asm("trap;"); }
18181822

@@ -1847,16 +1851,19 @@ void finalizeMoeRoutingKernel(
18471851
for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) {
18481852
ComputeElem thread_output;
18491853
thread_output.fill(0);
1854+
1855+
#pragma unroll
18501856
for (int k_idx = 0; k_idx < experts_per_token; ++k_idx) {
18511857
int64_t const k_offset = original_row * experts_per_token + k_idx;
18521858
int64_t const expert_id = token_selected_experts[k_offset] - start_expert_id;
1853-
if (expert_id < 0 || expert_id >= num_experts_per_node) {
1854-
continue;
1855-
}
18561859

18571860
int64_t const expanded_original_row = original_row + k_idx * num_rows;
18581861
int64_t const expanded_permuted_row = unpermuted_row_to_permuted_row[expanded_original_row];
18591862

1863+
if (expert_id < 0 || expert_id >= num_experts_per_node) {
1864+
continue;
1865+
}
1866+
18601867
int64_t expanded_rows = num_rows * experts_per_token;
18611868
if (expanded_permuted_row < 0 || expanded_permuted_row >= expanded_rows) {
18621869
continue;
@@ -1884,6 +1891,7 @@ void finalizeMoeRoutingKernel(
18841891
asm volatile("griddepcontrol.launch_dependents;");
18851892
#endif
18861893
}
1894+
}
18871895

18881896
// Final kernel to unpermute and scale
18891897
// This kernel unpermutes the original data, does the k-way reduction and performs the final skip

0 commit comments

Comments
 (0)