
Commit ec89f0c

temp rm all
1 parent 9d1456a commit ec89f0c

1 file changed: 24 additions, 110 deletions

csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh

Lines changed: 24 additions & 110 deletions
@@ -24,7 +24,6 @@
 #include <numeric>
 #include <random>
 #include <sstream>
-#include <type_traits>
 
 #include "tensorrt_llm/common/memoryUtils.h"
 #include "tensorrt_llm/common/workspace.h"
@@ -866,7 +865,7 @@ void threeStepBuildExpertMapsSortFirstToken(
 // ============================== Infer GEMM sizes =================================
 // TODO Could linear search be better for small # experts
 template <class T>
-__device__ inline int64_t findTotalEltsLessThanTarget_v1(T const* sorted_indices,
+__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices,
                                                          int64_t const arr_length, T const target) {
   int64_t low = 0, high = arr_length - 1, target_location = -1;
   while (low <= high) {
@@ -882,49 +881,6 @@ __device__ inline int64_t findTotalEltsLessThanTarget_v1(T const* sorted_indices
   return target_location + 1;
 }
 
-template <class T>
-__device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices, int64_t const arr_length, T const target) {
-  constexpr int ARR_LENGTH_CONST = 128;
-  if (arr_length != ARR_LENGTH_CONST) {
-    asm("trap;");
-  }
-
-  constexpr unsigned full_mask = 0xffffffffu;
-  constexpr int WARP_SZ = 32;
-  const int lane_id = threadIdx.x & (WARP_SZ - 1);
-
-  int local_count = 0;
-#pragma unroll
-  for (int k = 0; k < ARR_LENGTH_CONST / WARP_SZ; ++k) {
-    const int idx = lane_id + k * WARP_SZ;
-    T v = sorted_indices[idx];
-    local_count += (v < target) ? 1 : 0;
-  }
-
-#pragma unroll
-  for (int offset = 16; offset > 0; offset >>= 1) {
-    local_count += __shfl_down_sync(full_mask, local_count, offset);
-  }
-  int total = __shfl_sync(full_mask, local_count, 0);
-
-  return (int64_t)total;
-}
-
-template <class T>
-__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target) {
-  // return findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target);
-
-  return findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target);
-
-  // int64_t out_v1 = findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target);
-  // int64_t out_v2 = findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target);
-  // if (out_v1 != out_v2) {
-  //   printf("different output! v1=%lld v2=%lld\n", out_v1, out_v2);
-  //   asm("trap;");
-  // }
-  // return out_v1;
-}
-
 template <class T>
 using sizeof_bits = cutlass::sizeof_bits<
     typename cutlass_kernels::TllmToCutlassTypeAdapter<std::remove_cv_t<T>>::type>;
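Note: the deleted findTotalEltsLessThanTarget_v2 swapped the retained binary search for a single-warp linear count over a fixed 128-entry array. Below is a minimal standalone sketch of that idea; the names (countLessThanKernel, kArrLen) are illustrative and not part of this file, and it assumes exactly one full warp scans a sorted 128-element buffer.

#include <cstdint>
#include <cuda_runtime.h>

// One warp counts how many of the 128 sorted entries are strictly less than `target`.
// Each lane inspects 4 strided elements, then a shuffle reduction sums the per-lane
// counts into lane 0 and broadcasts the total back to every lane.
__global__ void countLessThanKernel(int64_t const* sorted, int64_t target, int64_t* out) {
  constexpr int kArrLen = 128;
  constexpr int kWarpSize = 32;
  constexpr unsigned kFullMask = 0xffffffffu;
  int const lane = threadIdx.x & (kWarpSize - 1);

  int local = 0;
  for (int k = 0; k < kArrLen / kWarpSize; ++k) {
    local += (sorted[lane + k * kWarpSize] < target) ? 1 : 0;
  }
  for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) {
    local += __shfl_down_sync(kFullMask, local, offset);
  }
  int const total = __shfl_sync(kFullMask, local, 0);  // broadcast lane 0's sum
  if (threadIdx.x == 0) {
    *out = total;
  }
}

Launched as countLessThanKernel<<<1, 32>>>(...), this returns the same count as the binary search whenever the array really has 128 elements; the in-kernel variant additionally relied on being called from a warp-uniform call site.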
@@ -1458,23 +1414,20 @@ __host__ __device__ constexpr static U arrayConvert(T const& input) {
 // (k-1)*rows_in_input all map to row 0 in the original matrix. Thus, to know where to read in the
 // source matrix, we simply take the modulus of the expanded index.
 
-constexpr static int EXPAND_THREADS_PER_BLOCK = 128;
+constexpr static int EXPAND_THREADS_PER_BLOCK = 256;
 
 template <class InputActivationsType, class ExpandedActivationsType,
           TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType,
           bool PRE_QUANT_AWQ>
 __global__ void expandInputRowsKernel(
     InputActivationsType const* unpermuted_input, ExpandedActivationsType* permuted_output,
     float const* unpermuted_scales, float* permuted_scales,
-    int const* permuted_row_to_unpermuted_row, int64_t const num_tokens, int64_t const hidden_size_real_,
+    int const* permuted_row_to_unpermuted_row, int64_t const num_tokens, int64_t const hidden_size,
     int64_t const k, float const* fc1_act_global_scale, bool use_per_expert_act_scale,
     int64_t const* expert_first_token_offset,
     TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat,
     TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf,
     int64_t const num_experts_per_node, InputActivationsType const* prequant_scales = nullptr) {
-  constexpr int hidden_size = 7168;
-  if (hidden_size != hidden_size_real_) { asm("trap;"); }
-
   static_assert(BlockScalingType == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE ||
                     !PRE_QUANT_AWQ,
                 "AWQ and Block Scaling are mutually exclusive");
@@ -1550,8 +1503,8 @@ __global__ void expandInputRowsKernel(
                             permuted_row * hidden_size / ELEM_PER_THREAD;
 
   int64_t const start_offset = threadIdx.x;
-  constexpr int64_t stride = EXPAND_THREADS_PER_BLOCK;
-  constexpr int64_t num_elems_in_col = hidden_size / ELEM_PER_THREAD;
+  int64_t const stride = EXPAND_THREADS_PER_BLOCK;
+  int64_t const num_elems_in_col = hidden_size / ELEM_PER_THREAD;
   assert(hidden_size % ELEM_PER_THREAD == 0);
   assert(hidden_size % VecSize == 0);
 
@@ -1566,7 +1519,6 @@ __global__ void expandInputRowsKernel(
   float global_scale_val = fc1_act_global_scale ? fc1_act_global_scale[act_scale_idx] : 1.0f;
   int64_t num_tokens_before_expert = expert_first_token_offset[expert];
 
-#pragma unroll
   for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) {
     auto in_vec = source_row_ptr[elem_index];
     if constexpr (need_nvfp4_quant || need_mxfp8_quant) {
@@ -1698,7 +1650,7 @@ void expandInputRowsKernelLauncher(
 
   static int64_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
   // Note: Launching 8 blocks per SM can fully leverage the memory bandwidth (tested on B200).
-  int64_t const blocks = std::min(smCount * 16, std::max(num_rows * k, num_padding_tokens));
+  int64_t const blocks = std::min(smCount * 8, std::max(num_rows * k, num_padding_tokens));
   int64_t const threads = EXPAND_THREADS_PER_BLOCK;
 
   auto func = [&]() {
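Note: the restored launch size follows the in-file comment: 8 blocks per SM are enough to saturate memory bandwidth for this copy-bound kernel, capped by the amount of work. A host-side sketch of that computation, assuming getMultiProcessorCount() is a thin wrapper over cudaDeviceGetAttribute (the helper name expandGridSize is hypothetical):

#include <algorithm>
#include <cstdint>
#include <cuda_runtime.h>

// Query the SM count of the current device once.
static int64_t multiProcessorCount() {
  int device = 0;
  int sm_count = 0;
  cudaGetDevice(&device);
  cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device);
  return sm_count;
}

// 8 blocks per SM keeps the memory system busy (per the in-file note, measured on B200);
// taking the min against the actual row count avoids launching blocks with no work.
static int64_t expandGridSize(int64_t num_rows, int64_t k, int64_t num_padding_tokens) {
  static int64_t const smCount = multiProcessorCount();
  return std::min(smCount * 8, std::max(num_rows * k, num_padding_tokens));
}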
@@ -1796,20 +1748,11 @@ constexpr static int FINALIZE_THREADS_PER_BLOCK = 256;
 // This kernel unpermutes the original data, does the k-way reduction and performs the final skip
 // connection.
 template <typename OutputType, class GemmOutputType, class ScaleBiasType, ScaleMode SCALE_MODE>
-__global__
-__maxnreg__(64)
-void finalizeMoeRoutingKernel(
+__global__ void finalizeMoeRoutingKernel(
     GemmOutputType const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
     ScaleBiasType const* bias, float const* scales, int const* unpermuted_row_to_permuted_row,
-    int const* token_selected_experts, int64_t const orig_cols, int64_t const experts_per_token_real_,
+    int const* token_selected_experts, int64_t const orig_cols, int64_t const experts_per_token,
     int const num_experts_per_node, int const start_expert_id) {
-  if constexpr (not (std::is_same_v<GemmOutputType, __nv_bfloat16> and std::is_same_v<OutputType, __nv_bfloat16>)) {
-    printf("finalizeMoeRoutingKernel see unsupported dtype\n");
-    asm("trap;");
-  } else {
-    constexpr int experts_per_token = 8;
-    if (experts_per_token != experts_per_token_real_) { asm("trap;"); }
-
   int64_t const original_row = blockIdx.x;
   int64_t const num_rows = gridDim.x;
   auto const offset = original_row * orig_cols;
@@ -1841,67 +1784,43 @@ if constexpr (not (std::is_same_v<GemmOutputType, __nv_bfloat16> and std::is_sam
   for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) {
     ComputeElem thread_output;
     thread_output.fill(0);
-
-    int4 input_val_buf[experts_per_token];
-    uint32_t enable_input_buf = 0;
-
-#pragma unroll
     for (int k_idx = 0; k_idx < experts_per_token; ++k_idx) {
       int64_t const k_offset = original_row * experts_per_token + k_idx;
       int64_t const expert_id = token_selected_experts[k_offset] - start_expert_id;
-
-      int64_t const expanded_original_row = original_row + k_idx * num_rows;
-      int64_t const expanded_permuted_row = unpermuted_row_to_permuted_row[expanded_original_row];
-
       if (expert_id < 0 || expert_id >= num_experts_per_node) {
         continue;
       }
 
+      int64_t const expanded_original_row = original_row + k_idx * num_rows;
+      int64_t const expanded_permuted_row = unpermuted_row_to_permuted_row[expanded_original_row];
+
       int64_t expanded_rows = num_rows * experts_per_token;
       if (expanded_permuted_row < 0 || expanded_permuted_row >= expanded_rows) {
         continue;
       }
 
+      float const row_scale = (SCALE_MODE == ScaleMode::NO_SCALE) ? 1.f : scales[k_offset];
+
       auto const* expanded_permuted_rows_row_ptr =
           expanded_permuted_rows_v + expanded_permuted_row * num_elems_in_col;
 
-      // ComputeElem expert_result =
-      //     arrayConvert<InputElem, ComputeElem>(expanded_permuted_rows_row_ptr[elem_index]);
-      static_assert(sizeof(expanded_permuted_rows_row_ptr[0]) == sizeof(int4));
-      input_val_buf[k_idx] = *reinterpret_cast<const int4*>(expanded_permuted_rows_row_ptr + elem_index);
-      enable_input_buf |= 1 << k_idx;
-    }
-
-#pragma unroll
-    for (int k_idx = 0; k_idx < experts_per_token; ++k_idx) {
-      if (not (enable_input_buf & (1 << k_idx))) continue;
-
-      int64_t const k_offset = original_row * experts_per_token + k_idx;
-      float const row_scale = (SCALE_MODE == ScaleMode::NO_SCALE) ? 1.f : scales[k_offset];
-
-      int4 input_val = input_val_buf[k_idx];
-      ComputeElem expert_result = arrayConvert<InputElem, ComputeElem>(*reinterpret_cast<const InputElem*>(&input_val));
-      // if (bias) {
-      //   auto const* bias_ptr = bias_v + expert_id * num_elems_in_col;
-      //   expert_result = expert_result + arrayConvert<BiasElem, ComputeElem>(bias_ptr[elem_index]);
-      // }
+      ComputeElem expert_result =
+          arrayConvert<InputElem, ComputeElem>(expanded_permuted_rows_row_ptr[elem_index]);
+      if (bias) {
+        auto const* bias_ptr = bias_v + expert_id * num_elems_in_col;
+        expert_result = expert_result + arrayConvert<BiasElem, ComputeElem>(bias_ptr[elem_index]);
+      }
 
       thread_output = thread_output + row_scale * expert_result;
     }
 
-    // OutputElem output_elem = arrayConvert<ComputeElem, OutputElem>(thread_output);
-    // reduced_row_ptr_v[elem_index] = output_elem;
-    // TODO alignment issue?
-    __align__(16) OutputElem output_elem_original = arrayConvert<ComputeElem, OutputElem>(thread_output);
-    int4 output_elem = *reinterpret_cast<int4*>(&output_elem_original);
-    static_assert(sizeof(reduced_row_ptr_v[0]) == sizeof(int4));
-    *reinterpret_cast<int4*>(reduced_row_ptr_v + elem_index) = output_elem;
+    OutputElem output_elem = arrayConvert<ComputeElem, OutputElem>(thread_output);
+    reduced_row_ptr_v[elem_index] = output_elem;
   }
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
   asm volatile("griddepcontrol.launch_dependents;");
 #endif
 }
-}
 
 // Final kernel to unpermute and scale
 // This kernel unpermutes the original data, does the k-way reduction and performs the final skip
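Note: the block removed from finalizeMoeRoutingKernel above staged each selected expert's row through int4 (16-byte) loads and wrote the reduced row back with a single 128-bit store. A standalone illustration of that packing trick, assuming 8 bf16 values per vector element and 16-byte alignment (the Bf16x8 and copyVec16 names are illustrative only):

#include <cuda_bf16.h>
#include <cuda_runtime.h>

// Eight bf16 values occupy 16 bytes, exactly the width of one int4.
struct alignas(16) Bf16x8 {
  __nv_bfloat16 v[8];
};
static_assert(sizeof(Bf16x8) == sizeof(int4), "vector element must be 16 bytes");

// Move one vector element with a single 128-bit load and store instead of eight
// scalar bf16 accesses. Both pointers must be 16-byte aligned, which alignas(16)
// on Bf16x8 guarantees for properly allocated buffers.
__device__ inline void copyVec16(Bf16x8 const* src, Bf16x8* dst) {
  int4 const packed = *reinterpret_cast<int4 const*>(src);
  *reinterpret_cast<int4*>(dst) = packed;
}

The restored code instead reads expanded_permuted_rows_row_ptr[elem_index] and converts it directly, without the explicit int4 staging buffer.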
@@ -2211,13 +2130,10 @@ template <class T, class GemmOutputType, class ScaleBiasType, class ActFn,
 __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
                                    float const* fp8_quant, ScaleBiasType const* bias_ptr,
                                    bool bias_is_broadcast, int64_t const* expert_first_token_offset,
-                                   int num_experts_per_node, int64_t inter_size_real_,
+                                   int num_experts_per_node, int64_t inter_size,
                                    float const* fc2_act_global_scale, bool use_per_expert_act_scale,
                                    TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_act_sf_flat,
                                    ActivationParams activation_params) {
-  constexpr int inter_size = 2048;
-  if (inter_size != inter_size_real_) { asm("trap;"); }
-
 #ifdef ENABLE_FP4
   constexpr bool IsNVFP4 =
       std::is_same_v<T, __nv_fp4_e2m1> &&
@@ -2302,18 +2218,16 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
   auto output_vec = reinterpret_cast<OutputElem*>(safe_inc_ptr(output, output_offset));
   auto bias_ptr_vec = reinterpret_cast<BiasElem const*>(bias_ptr + bias_offset);
   int64_t const start_offset = tid;
-  constexpr int64_t stride = ACTIVATION_THREADS_PER_BLOCK;
+  int64_t const stride = ACTIVATION_THREADS_PER_BLOCK;
   assert(inter_size % ACTIVATION_ELEM_PER_THREAD == 0);
-  constexpr int64_t num_elems_in_col = inter_size / ACTIVATION_ELEM_PER_THREAD;
+  int64_t const num_elems_in_col = inter_size / ACTIVATION_ELEM_PER_THREAD;
   assert(gated_off % ACTIVATION_ELEM_PER_THREAD == 0);
   int64_t const gated_off_vec = gated_off / ACTIVATION_ELEM_PER_THREAD;
 
   ActFn fn{};
   fn.alpha = gate_alpha;
   fn.beta = gate_beta;
   fn.limit = gate_limit;
-
-#pragma unroll
   for (int64_t elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) {
     auto fc1_value =
         arrayConvert<GemmResultElem, ComputeElem>(gemm_result_vec[elem_index + gated_off_vec]);
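Note: both expandInputRowsKernel and doActivationKernel walk a row with the strided column loop this commit restores to runtime bounds; with a runtime trip count there is no fixed bound for #pragma unroll to work against, which is why the pragma was dropped together with the constexpr bounds. A minimal sketch of the pattern (copyRowStrided and THREADS_PER_BLOCK are placeholder names standing in for the per-kernel constants):

#include <cstdint>

constexpr int THREADS_PER_BLOCK = 256;

// Each block owns one row; its threads cooperatively stride across the row's
// vectorized columns. `num_elems_in_col` is a runtime value here, matching the
// reverted kernels, so the loop trip count is only known at launch time.
template <class VecT>
__device__ inline void copyRowStrided(VecT const* src_row, VecT* dst_row,
                                      int64_t num_elems_in_col) {
  int64_t const start_offset = threadIdx.x;
  int64_t const stride = THREADS_PER_BLOCK;
  for (int64_t elem_index = start_offset; elem_index < num_elems_in_col;
       elem_index += stride) {
    dst_row[elem_index] = src_row[elem_index];
  }
}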
