
Commit c6edf1a (parent: cc46992)

WIP

3 files changed (+453, -5 lines)


csrc/trtllm_fused_moe_kernel_launcher.cu

Lines changed: 70 additions & 4 deletions
@@ -1053,7 +1053,7 @@ std::vector<at::Tensor> trtllm_fp4_block_scale_moe(
     std::optional<int64_t> n_group, std::optional<int64_t> topk_group, int64_t intermediate_size,
     int64_t local_expert_offset, int64_t local_num_experts,
     std::optional<double> routed_scaling_factor, int64_t tile_tokens_dim,
-    int64_t routing_method_type, bool do_finalize, at::Tensor& output) {
+    int64_t routing_method_type, bool do_finalize, at::Tensor& output, int64_t config_index) {
   using RunnerType = tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner;
 
   int const num_tokens = hidden_states.sizes()[0];
@@ -1106,8 +1106,10 @@ std::vector<at::Tensor> trtllm_fp4_block_scale_moe(
       mDtypeAct, mDtypeWeights, mUseDeepSeekFp8, (int32_t)tile_tokens_dim,
       tensorrt_llm::kernels::ActType::SwiGlu, /*useShuffledMatrixA*/ true);
 
-  auto const moeConfigIndex = mRunner->getDefaultValidConfigIndex(
-      top_k, hidden_size, intermediate_size, local_num_experts, num_tokens);
+  if (config_index == -1) {
+    config_index = mRunner->getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size,
+                                                       local_num_experts, num_tokens);
+  }
 
   return trtllm_fp4_block_scale_moe_launcher(
       routing_logits, topk_ids, expert_weights, routing_bias, hidden_states, hidden_states_scale,
@@ -1116,7 +1118,70 @@ std::vector<at::Tensor> trtllm_fp4_block_scale_moe(
       output1_scales_gate_scalar, output2_scales_scalar, num_experts, top_k, n_group, topk_group,
       intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor,
       tile_tokens_dim, routing_method_type, do_finalize, *mRunner, mDtypeAct, mDtypeWeights,
-      moeConfigIndex, output);
+      config_index, output);
+}
+
+inline btg::Dtype get_dtype(int64_t const dtype) {
+  switch (dtype) {
+    case 0:
+      return btg::Dtype::Bfloat16;
+    case 1:
+      return btg::Dtype::Bool;
+    case 2:
+      return btg::Dtype::E2m1;
+    case 3:
+      return btg::Dtype::E2m3;
+    case 4:
+      return btg::Dtype::E3m2;
+    case 5:
+      return btg::Dtype::E4m3;
+    case 6:
+      return btg::Dtype::E5m2;
+    case 7:
+      return btg::Dtype::Fp16;
+    case 8:
+      return btg::Dtype::Fp32;
+    case 9:
+      return btg::Dtype::Int8;
+    case 10:
+      return btg::Dtype::Int32;
+    case 11:
+      return btg::Dtype::Int64;
+    case 12:
+      return btg::Dtype::MxE2m1;
+    case 13:
+      return btg::Dtype::MxE4m3;
+    case 14:
+      return btg::Dtype::UE8m0;
+    case 15:
+      return btg::Dtype::UInt8;
+    case 16:
+      return btg::Dtype::UInt16;
+    case 17:
+      return btg::Dtype::UInt32;
+    case 18:
+      return btg::Dtype::UInt64;
+    case 19:
+      return btg::Dtype::UInt128;
+    case 20:
+      return btg::Dtype::Void;
+    default:
+      TORCH_CHECK(false, "Invalid trtllm-gen dtype");
+  }
+  return btg::Dtype::E2m1;
+}
+
+std::vector<int64_t> trtllm_get_valid_moe_configs(
+    int64_t const tile_tokens_dim, int64_t const dtype_act_, int64_t const dtype_weights_,
+    bool const useDeepSeekFp8, int64_t const top_k, int64_t const hidden_size,
+    int64_t const intermediate_size, int64_t const num_local_experts, int64_t const num_tokens) {
+  auto dtype_act = get_dtype(dtype_act_);
+  auto dtype_weights = get_dtype(dtype_weights_);
+  tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner moe_runner(
+      dtype_act, dtype_weights, useDeepSeekFp8, (int32_t)tile_tokens_dim,
+      tensorrt_llm::kernels::ActType::SwiGlu, /*useShuffledMatrixA*/ true);
+  return moe_runner.getValidConfigIndices(top_k, hidden_size, intermediate_size, num_local_experts,
+                                          num_tokens);
 }
 
 namespace trtllm_cubin_loader {
@@ -1127,6 +1192,7 @@ TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
   m.def("trtllm_fp8_per_tensor_scale_moe", trtllm_fp8_per_tensor_scale_moe);
   m.def("trtllm_fp8_block_scale_moe", trtllm_fp8_block_scale_moe);
   m.def("trtllm_fp4_block_scale_moe", trtllm_fp4_block_scale_moe);
+  m.def("trtllm_get_valid_moe_configs", trtllm_get_valid_moe_configs);
 }
 
 } // namespace flashinfer
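
Taken together, the new config_index parameter and the trtllm_get_valid_moe_configs binding let a caller enumerate valid kernel configurations and pick one explicitly, instead of always relying on getDefaultValidConfigIndex. Below is a minimal Python sketch of such a sweep; the helper name, the two injected callables, and the timing loop are illustrative assumptions, not part of this commit. The dtype codes accepted by trtllm_get_valid_moe_configs follow the get_dtype mapping above (for example, 2 for E2m1, 5 for E4m3).

import torch

def pick_best_moe_config(get_valid_configs, run_moe, *shape_args):
    """Time every valid kernel config and return the fastest index.

    `get_valid_configs(*shape_args)` stands in for the newly exposed
    trtllm_get_valid_moe_configs op, and `run_moe(config_index)` for a call
    to trtllm_fp4_block_scale_moe with that config_index (-1 selects the
    default heuristic, matching the C++ fallback above). Both callables
    are assumptions for illustration.
    """
    best_idx, best_ms = -1, float("inf")
    for idx in get_valid_configs(*shape_args):
        run_moe(idx)  # warm-up run; also triggers cubin loading
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        run_moe(idx)
        end.record()
        torch.cuda.synchronize()
        ms = start.elapsed_time(end)
        if ms < best_ms:
            best_idx, best_ms = idx, ms
    return best_idx  # the caller passes this back as config_index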

flashinfer/autotuner.py

Lines changed: 9 additions & 1 deletion
@@ -707,7 +707,15 @@ def _create_tensor_like(
         # TODO: FIXME, sometimes the content of the tensor can affect the performance, like MOE
         # One solution is to manituplate the tensor content to make it more like the real data
         # during the tuning process. This can by controlled in the preparation phase by the runner.
-        return torch.zeros(shapes, dtype=dtype, device=device)
+        # return torch.zeros(shapes, dtype=dtype, device=device)
+        if dtype == torch.int8:
+            return torch.randint(0, 127, shapes, dtype=dtype, device=device)
+        elif dtype == torch.uint8:
+            return torch.randint(0, 255, shapes, dtype=dtype, device=device)
+        elif dtype == torch.int32:
+            return torch.randint(0, 1000000, shapes, dtype=dtype, device=device)
+        else:
+            return torch.randn(shapes, dtype=dtype, device=device)
 
     def _prepare_input_tensors(
         self, profile: OptimizationProfile, inputs: List[torch.Tensor]
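
The switch from torch.zeros to dtype-dependent random fill follows the TODO above: all-zero inputs can make tuning measurements unrepresentative of real data, and torch.randn is only implemented for floating-point dtypes, so integer tensors have to be drawn with torch.randint instead. A small standalone demonstration of that constraint (illustrative only, not part of the diff):

import torch

# torch.randn rejects integer dtypes, so the autotuner falls back to
# torch.randint for them; the ranges mirror the branch added above.
try:
    torch.randn((4,), dtype=torch.int8)
except RuntimeError as err:
    print("randn rejects int8:", err)

ints = torch.randint(0, 127, (4,), dtype=torch.int8)    # integer path
floats = torch.randn((4,), dtype=torch.bfloat16)        # floating-point path
print(ints, floats)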
