
Commit 7b0f471

Revert "temp rm all"

This reverts commit ec89f0c.
1 parent 783120b

File tree: 1 file changed

csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh

Lines changed: 110 additions & 24 deletions
@@ -24,6 +24,7 @@
 #include <numeric>
 #include <random>
 #include <sstream>
+#include <type_traits>

 #include "tensorrt_llm/common/memoryUtils.h"
 #include "tensorrt_llm/common/workspace.h"
@@ -865,7 +866,7 @@ void threeStepBuildExpertMapsSortFirstToken(
 // ============================== Infer GEMM sizes =================================
 // TODO Could linear search be better for small # experts
 template <class T>
-__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices,
+__device__ inline int64_t findTotalEltsLessThanTarget_v1(T const* sorted_indices,
                                                        int64_t const arr_length, T const target) {
   int64_t low = 0, high = arr_length - 1, target_location = -1;
   while (low <= high) {
@@ -881,6 +882,49 @@ __device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices,
   return target_location + 1;
 }

+template <class T>
+__device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices, int64_t const arr_length, T const target) {
+  constexpr int ARR_LENGTH_CONST = 128;
+  if (arr_length != ARR_LENGTH_CONST) {
+    asm("trap;");
+  }
+
+  constexpr unsigned full_mask = 0xffffffffu;
+  constexpr int WARP_SZ = 32;
+  const int lane_id = threadIdx.x & (WARP_SZ - 1);
+
+  int local_count = 0;
+#pragma unroll
+  for (int k = 0; k < ARR_LENGTH_CONST / WARP_SZ; ++k) {
+    const int idx = lane_id + k * WARP_SZ;
+    T v = sorted_indices[idx];
+    local_count += (v < target) ? 1 : 0;
+  }
+
+#pragma unroll
+  for (int offset = 16; offset > 0; offset >>= 1) {
+    local_count += __shfl_down_sync(full_mask, local_count, offset);
+  }
+  int total = __shfl_sync(full_mask, local_count, 0);
+
+  return (int64_t)total;
+}
+
+template <class T>
+__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target) {
+  // return findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target);
+
+  return findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target);
+
+  // int64_t out_v1 = findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target);
+  // int64_t out_v2 = findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target);
+  // if (out_v1 != out_v2) {
+  //   printf("different output! v1=%lld v2=%lld\n", out_v1, out_v2);
+  //   asm("trap;");
+  // }
+  // return out_v1;
+}
+
 template <class T>
 using sizeof_bits = cutlass::sizeof_bits<
     typename cutlass_kernels::TllmToCutlassTypeAdapter<std::remove_cv_t<T>>::type>;
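
Note on the change above: with the array length pinned to 128 and one 32-lane warp, each lane inspects four elements and counts those below the target, a shuffle-down reduction sums the per-lane counts, and lane 0's total is broadcast back to the warp; for a sorted array this is the same quantity the binary search in findTotalEltsLessThanTarget_v1 returns. The standalone harness below is a minimal sketch, not part of this commit, that checks that equivalence against std::lower_bound on the host; the file name, kernel name, and test values are illustrative assumptions.

// check_warp_count.cu -- illustrative harness only; mirrors the commented-out
// v1-vs-v2 consistency check in the dispatcher above.
#include <algorithm>
#include <cstdio>
#include <cuda_runtime.h>

__global__ void warpCountLessThan(int const* sorted, int target, int* out) {
  constexpr int kArrLen = 128;  // fixed length, as in findTotalEltsLessThanTarget_v2
  constexpr int kWarp = 32;
  int const lane = threadIdx.x & (kWarp - 1);
  int local = 0;
#pragma unroll
  for (int k = 0; k < kArrLen / kWarp; ++k) {
    local += (sorted[lane + k * kWarp] < target) ? 1 : 0;  // each lane covers 4 elements
  }
#pragma unroll
  for (int offset = 16; offset > 0; offset >>= 1) {
    local += __shfl_down_sync(0xffffffffu, local, offset);  // warp-wide sum into lane 0
  }
  int const total = __shfl_sync(0xffffffffu, local, 0);     // broadcast lane 0's total
  if (lane == 0) *out = total;
}

int main() {
  int h[128];
  for (int i = 0; i < 128; ++i) h[i] = 2 * i;  // sorted input
  int *d_arr = nullptr, *d_out = nullptr;
  cudaMalloc(&d_arr, sizeof(h));
  cudaMalloc(&d_out, sizeof(int));
  cudaMemcpy(d_arr, h, sizeof(h), cudaMemcpyHostToDevice);
  int const targets[] = {-1, 0, 7, 128, 255, 1000};
  for (int target : targets) {
    warpCountLessThan<<<1, 32>>>(d_arr, target, d_out);
    int got = -1;
    cudaMemcpy(&got, d_out, sizeof(int), cudaMemcpyDeviceToHost);
    int const expect = int(std::lower_bound(h, h + 128, target) - h);  // # elements < target
    printf("target=%4d warp=%3d reference=%3d %s\n", target, got, expect,
           got == expect ? "OK" : "MISMATCH");
  }
  cudaFree(d_arr);
  cudaFree(d_out);
  return 0;
}
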
@@ -1414,20 +1458,23 @@ __host__ __device__ constexpr static U arrayConvert(T const& input) {
 // (k-1)*rows_in_input all map to row 0 in the original matrix. Thus, to know where to read in the
 // source matrix, we simply take the modulus of the expanded index.

-constexpr static int EXPAND_THREADS_PER_BLOCK = 256;
+constexpr static int EXPAND_THREADS_PER_BLOCK = 128;

 template <class InputActivationsType, class ExpandedActivationsType,
           TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType,
           bool PRE_QUANT_AWQ>
 __global__ void expandInputRowsKernel(
     InputActivationsType const* unpermuted_input, ExpandedActivationsType* permuted_output,
     float const* unpermuted_scales, float* permuted_scales,
-    int const* permuted_row_to_unpermuted_row, int64_t const num_tokens, int64_t const hidden_size,
+    int const* permuted_row_to_unpermuted_row, int64_t const num_tokens, int64_t const hidden_size_real_,
     int64_t const k, float const* fc1_act_global_scale, bool use_per_expert_act_scale,
     int64_t const* expert_first_token_offset,
     TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat,
     TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf,
     int64_t const num_experts_per_node, InputActivationsType const* prequant_scales = nullptr) {
+  constexpr int hidden_size = 7168;
+  if (hidden_size != hidden_size_real_) { asm("trap;"); }
+
   static_assert(BlockScalingType == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE ||
                     !PRE_QUANT_AWQ,
                 "AWQ and Block Scaling are mutually exclusive");
@@ -1503,8 +1550,8 @@ __global__ void expandInputRowsKernel(
       permuted_row * hidden_size / ELEM_PER_THREAD;

   int64_t const start_offset = threadIdx.x;
-  int64_t const stride = EXPAND_THREADS_PER_BLOCK;
-  int64_t const num_elems_in_col = hidden_size / ELEM_PER_THREAD;
+  constexpr int64_t stride = EXPAND_THREADS_PER_BLOCK;
+  constexpr int64_t num_elems_in_col = hidden_size / ELEM_PER_THREAD;
   assert(hidden_size % ELEM_PER_THREAD == 0);
   assert(hidden_size % VecSize == 0);

@@ -1519,6 +1566,7 @@ __global__ void expandInputRowsKernel(
   float global_scale_val = fc1_act_global_scale ? fc1_act_global_scale[act_scale_idx] : 1.0f;
   int64_t num_tokens_before_expert = expert_first_token_offset[expert];

+#pragma unroll
   for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) {
     auto in_vec = source_row_ptr[elem_index];
     if constexpr (need_nvfp4_quant || need_mxfp8_quant) {
@@ -1650,7 +1698,7 @@ void expandInputRowsKernelLauncher(

   static int64_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
   // Note: Launching 8 blocks per SM can fully leverage the memory bandwidth (tested on B200).
-  int64_t const blocks = std::min(smCount * 8, std::max(num_rows * k, num_padding_tokens));
+  int64_t const blocks = std::min(smCount * 16, std::max(num_rows * k, num_padding_tokens));
   int64_t const threads = EXPAND_THREADS_PER_BLOCK;

   auto func = [&]() {
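
Note on the changes above: the kernel pins hidden_size to 7168 as a constexpr, keeps the runtime value only to trap if it ever disagrees, and turns stride and num_elems_in_col into compile-time constants so the #pragma unroll on the per-row loop has a fixed trip count; on the launcher side, EXPAND_THREADS_PER_BLOCK drops from 256 to 128 and the blocks-per-SM cap doubles from 8 to 16. The sketch below shows the general constexpr-plus-trap-guard pattern with a host-side dispatch and a generic fallback; the names are hypothetical, and unlike this commit it routes the constant through a template parameter rather than hardcoding it inside the kernel body.

// Illustrative sketch: compile-time specialization of a per-row loop on a known
// hidden size, guarded by a trap, with a generic fallback chosen on the host.
#include <cstdint>
#include <cuda_runtime.h>

constexpr int kThreads = 128;

template <int kHiddenSize>
__global__ void scaleRowsFixed(float* data, int64_t hidden_size_real, float scale) {
  // The compile-time constant must match the runtime value; abort loudly otherwise
  // instead of silently reading or writing out of bounds.
  if (hidden_size_real != kHiddenSize) { asm("trap;"); }
  float* row = data + int64_t(blockIdx.x) * kHiddenSize;
  constexpr int kPerThread = kHiddenSize / kThreads;  // trip count known at compile time
#pragma unroll
  for (int i = 0; i < kPerThread; ++i) {
    row[threadIdx.x + i * kThreads] *= scale;
  }
}

__global__ void scaleRowsGeneric(float* data, int64_t hidden_size, float scale) {
  float* row = data + int64_t(blockIdx.x) * hidden_size;
  for (int64_t i = threadIdx.x; i < hidden_size; i += kThreads) {
    row[i] *= scale;
  }
}

void launchScaleRows(float* data, int64_t num_rows, int64_t hidden_size, float scale,
                     cudaStream_t stream) {
  unsigned const grid = static_cast<unsigned>(num_rows);
  if (hidden_size == 7168) {
    // Specialized path: fully unrolled, all addressing folded at compile time.
    scaleRowsFixed<7168><<<grid, kThreads, 0, stream>>>(data, hidden_size, scale);
  } else {
    scaleRowsGeneric<<<grid, kThreads, 0, stream>>>(data, hidden_size, scale);
  }
}
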
@@ -1748,11 +1796,20 @@ constexpr static int FINALIZE_THREADS_PER_BLOCK = 256;
 // This kernel unpermutes the original data, does the k-way reduction and performs the final skip
 // connection.
 template <typename OutputType, class GemmOutputType, class ScaleBiasType, ScaleMode SCALE_MODE>
-__global__ void finalizeMoeRoutingKernel(
+__global__
+__maxnreg__(64)
+void finalizeMoeRoutingKernel(
     GemmOutputType const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
     ScaleBiasType const* bias, float const* scales, int const* unpermuted_row_to_permuted_row,
-    int const* token_selected_experts, int64_t const orig_cols, int64_t const experts_per_token,
+    int const* token_selected_experts, int64_t const orig_cols, int64_t const experts_per_token_real_,
     int const num_experts_per_node, int const start_expert_id) {
+  if constexpr (not (std::is_same_v<GemmOutputType, __nv_bfloat16> and std::is_same_v<OutputType, __nv_bfloat16>)) {
+    printf("finalizeMoeRoutingKernel see unsupported dtype\n");
+    asm("trap;");
+  } else {
+  constexpr int experts_per_token = 8;
+  if (experts_per_token != experts_per_token_real_) { asm("trap;"); }
+
   int64_t const original_row = blockIdx.x;
   int64_t const num_rows = gridDim.x;
   auto const offset = original_row * orig_cols;
@@ -1784,43 +1841,67 @@ __global__ void finalizeMoeRoutingKernel(
   for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) {
     ComputeElem thread_output;
     thread_output.fill(0);
+
+    int4 input_val_buf[experts_per_token];
+    uint32_t enable_input_buf = 0;
+
+#pragma unroll
     for (int k_idx = 0; k_idx < experts_per_token; ++k_idx) {
       int64_t const k_offset = original_row * experts_per_token + k_idx;
       int64_t const expert_id = token_selected_experts[k_offset] - start_expert_id;
-      if (expert_id < 0 || expert_id >= num_experts_per_node) {
-        continue;
-      }

       int64_t const expanded_original_row = original_row + k_idx * num_rows;
       int64_t const expanded_permuted_row = unpermuted_row_to_permuted_row[expanded_original_row];

+      if (expert_id < 0 || expert_id >= num_experts_per_node) {
+        continue;
+      }
+
       int64_t expanded_rows = num_rows * experts_per_token;
       if (expanded_permuted_row < 0 || expanded_permuted_row >= expanded_rows) {
         continue;
       }

-      float const row_scale = (SCALE_MODE == ScaleMode::NO_SCALE) ? 1.f : scales[k_offset];
-
       auto const* expanded_permuted_rows_row_ptr =
           expanded_permuted_rows_v + expanded_permuted_row * num_elems_in_col;

-      ComputeElem expert_result =
-          arrayConvert<InputElem, ComputeElem>(expanded_permuted_rows_row_ptr[elem_index]);
-      if (bias) {
-        auto const* bias_ptr = bias_v + expert_id * num_elems_in_col;
-        expert_result = expert_result + arrayConvert<BiasElem, ComputeElem>(bias_ptr[elem_index]);
-      }
+      // ComputeElem expert_result =
+      //     arrayConvert<InputElem, ComputeElem>(expanded_permuted_rows_row_ptr[elem_index]);
+      static_assert(sizeof(expanded_permuted_rows_row_ptr[0]) == sizeof(int4));
+      input_val_buf[k_idx] = *reinterpret_cast<const int4*>(expanded_permuted_rows_row_ptr + elem_index);
+      enable_input_buf |= 1 << k_idx;
+    }
+
+#pragma unroll
+    for (int k_idx = 0; k_idx < experts_per_token; ++k_idx) {
+      if (not (enable_input_buf & (1 << k_idx))) continue;
+
+      int64_t const k_offset = original_row * experts_per_token + k_idx;
+      float const row_scale = (SCALE_MODE == ScaleMode::NO_SCALE) ? 1.f : scales[k_offset];
+
+      int4 input_val = input_val_buf[k_idx];
+      ComputeElem expert_result = arrayConvert<InputElem, ComputeElem>(*reinterpret_cast<const InputElem*>(&input_val));
+      // if (bias) {
+      //   auto const* bias_ptr = bias_v + expert_id * num_elems_in_col;
+      //   expert_result = expert_result + arrayConvert<BiasElem, ComputeElem>(bias_ptr[elem_index]);
+      // }

       thread_output = thread_output + row_scale * expert_result;
     }

-    OutputElem output_elem = arrayConvert<ComputeElem, OutputElem>(thread_output);
-    reduced_row_ptr_v[elem_index] = output_elem;
+    // OutputElem output_elem = arrayConvert<ComputeElem, OutputElem>(thread_output);
+    // reduced_row_ptr_v[elem_index] = output_elem;
+    // TODO alignment issue?
+    __align__(16) OutputElem output_elem_original = arrayConvert<ComputeElem, OutputElem>(thread_output);
+    int4 output_elem = *reinterpret_cast<int4*>(&output_elem_original);
+    static_assert(sizeof(reduced_row_ptr_v[0]) == sizeof(int4));
+    *reinterpret_cast<int4*>(reduced_row_ptr_v + elem_index) = output_elem;
   }
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
   asm volatile("griddepcontrol.launch_dependents;");
 #endif
 }
+}

 // Final kernel to unpermute and scale
 // This kernel unpermutes the original data, does the k-way reduction and performs the final skip
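
Note on the changes above: the k-way reduction is split into two phases. The first pass over k issues all eight 16-byte (int4) loads into a register buffer and records which slots are valid in a bitmask; the second pass converts, scales, and accumulates them, which lets the loads overlap instead of each one waiting behind the previous expert's convert-and-add. The final store likewise becomes a single aligned 16-byte write, and __maxnreg__(64) caps register usage so the extra buffer does not hurt occupancy. The kernel is also restricted here to bf16 in and out, the bias path is commented out, and experts_per_token is pinned to 8 with a trap guard. Below is a simplified standalone sketch of that gather-then-reduce access pattern (bf16 rows, k = 8, 8 elements per thread); names and shapes are illustrative, not from this file, and it assumes hidden is a multiple of 8 with 16-byte-aligned buffers.

// Illustrative sketch: gather k candidate rows with vectorized int4 loads into
// registers first, then convert and accumulate in fp32, then write one int4.
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <cstdint>

constexpr int kExpertsPerToken = 8;  // mirrors experts_per_token = 8 above
constexpr int kElemPerThread = 8;    // 8 x bf16 = 16 bytes = one int4

__global__ __maxnreg__(64) void reduceTopKRows(
    __nv_bfloat16 const* rows,   // [num_candidates, hidden]
    int const* row_ids,          // [num_tokens, kExpertsPerToken]; out-of-range id = skip
    float const* row_scales,     // [num_tokens, kExpertsPerToken]
    __nv_bfloat16* out,          // [num_tokens, hidden]
    int64_t hidden, int num_candidates) {
  int64_t const token = blockIdx.x;
  for (int64_t vec = threadIdx.x; vec < hidden / kElemPerThread; vec += blockDim.x) {
    int4 buf[kExpertsPerToken];
    uint32_t valid = 0;

    // Phase 1: issue all loads up front; remember which slots hold real data.
#pragma unroll
    for (int k = 0; k < kExpertsPerToken; ++k) {
      int const r = row_ids[token * kExpertsPerToken + k];
      if (r < 0 || r >= num_candidates) continue;
      auto const* src = rows + int64_t(r) * hidden + vec * kElemPerThread;
      buf[k] = *reinterpret_cast<int4 const*>(src);
      valid |= 1u << k;
    }

    // Phase 2: convert, scale, and accumulate in fp32.
    float acc[kElemPerThread] = {};
#pragma unroll
    for (int k = 0; k < kExpertsPerToken; ++k) {
      if (!(valid & (1u << k))) continue;
      float const s = row_scales[token * kExpertsPerToken + k];
      auto const* v = reinterpret_cast<__nv_bfloat16 const*>(&buf[k]);
#pragma unroll
      for (int e = 0; e < kElemPerThread; ++e) acc[e] += s * __bfloat162float(v[e]);
    }

    // Single 16-byte store of the reduced vector.
    __align__(16) __nv_bfloat16 out_vec[kElemPerThread];
#pragma unroll
    for (int e = 0; e < kElemPerThread; ++e) out_vec[e] = __float2bfloat16(acc[e]);
    *reinterpret_cast<int4*>(out + token * hidden + vec * kElemPerThread) =
        *reinterpret_cast<int4 const*>(out_vec);
  }
}
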
@@ -2130,10 +2211,13 @@ template <class T, class GemmOutputType, class ScaleBiasType, class ActFn,
 __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
                                    float const* fp8_quant, ScaleBiasType const* bias_ptr,
                                    bool bias_is_broadcast, int64_t const* expert_first_token_offset,
-                                   int num_experts_per_node, int64_t inter_size,
+                                   int num_experts_per_node, int64_t inter_size_real_,
                                    float const* fc2_act_global_scale, bool use_per_expert_act_scale,
                                    TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_act_sf_flat,
                                    ActivationParams activation_params) {
+  constexpr int inter_size = 2048;
+  if (inter_size != inter_size_real_) { asm("trap;"); }
+
 #ifdef ENABLE_FP4
   constexpr bool IsNVFP4 =
       std::is_same_v<T, __nv_fp4_e2m1> &&
@@ -2218,16 +2302,18 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
   auto output_vec = reinterpret_cast<OutputElem*>(safe_inc_ptr(output, output_offset));
   auto bias_ptr_vec = reinterpret_cast<BiasElem const*>(bias_ptr + bias_offset);
   int64_t const start_offset = tid;
-  int64_t const stride = ACTIVATION_THREADS_PER_BLOCK;
+  constexpr int64_t stride = ACTIVATION_THREADS_PER_BLOCK;
   assert(inter_size % ACTIVATION_ELEM_PER_THREAD == 0);
-  int64_t const num_elems_in_col = inter_size / ACTIVATION_ELEM_PER_THREAD;
+  constexpr int64_t num_elems_in_col = inter_size / ACTIVATION_ELEM_PER_THREAD;
   assert(gated_off % ACTIVATION_ELEM_PER_THREAD == 0);
   int64_t const gated_off_vec = gated_off / ACTIVATION_ELEM_PER_THREAD;

   ActFn fn{};
   fn.alpha = gate_alpha;
   fn.beta = gate_beta;
   fn.limit = gate_limit;
+
+#pragma unroll
   for (int64_t elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) {
     auto fc1_value =
         arrayConvert<GemmResultElem, ComputeElem>(gemm_result_vec[elem_index + gated_off_vec]);
