@@ -1053,7 +1053,7 @@ std::vector<at::Tensor> trtllm_fp4_block_scale_moe(
1053
1053
std::optional<int64_t > n_group, std::optional<int64_t > topk_group, int64_t intermediate_size,
1054
1054
int64_t local_expert_offset, int64_t local_num_experts,
1055
1055
std::optional<double > routed_scaling_factor, int64_t tile_tokens_dim,
1056
- int64_t routing_method_type, bool do_finalize, at::Tensor& output) {
1056
+ int64_t routing_method_type, bool do_finalize, at::Tensor& output, int64_t config_index ) {
1057
1057
using RunnerType = tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner;
1058
1058
1059
1059
int const num_tokens = hidden_states.sizes ()[0 ];
@@ -1106,8 +1106,10 @@ std::vector<at::Tensor> trtllm_fp4_block_scale_moe(
1106
1106
mDtypeAct , mDtypeWeights , mUseDeepSeekFp8 , (int32_t )tile_tokens_dim,
1107
1107
tensorrt_llm::kernels::ActType::SwiGlu, /* useShuffledMatrixA*/ true );
1108
1108
1109
- auto const moeConfigIndex = mRunner ->getDefaultValidConfigIndex (
1110
- top_k, hidden_size, intermediate_size, local_num_experts, num_tokens);
1109
+ if (config_index == -1 ) {
1110
+ config_index = mRunner ->getDefaultValidConfigIndex (top_k, hidden_size, intermediate_size,
1111
+ local_num_experts, num_tokens);
1112
+ }
1111
1113
1112
1114
return trtllm_fp4_block_scale_moe_launcher (
1113
1115
routing_logits, topk_ids, expert_weights, routing_bias, hidden_states, hidden_states_scale,
@@ -1116,7 +1118,70 @@ std::vector<at::Tensor> trtllm_fp4_block_scale_moe(
1116
1118
output1_scales_gate_scalar, output2_scales_scalar, num_experts, top_k, n_group, topk_group,
1117
1119
intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor,
1118
1120
tile_tokens_dim, routing_method_type, do_finalize, *mRunner , mDtypeAct , mDtypeWeights ,
1119
- moeConfigIndex, output);
1121
+ config_index, output);
1122
+ }
1123
+
1124
+ inline btg::Dtype get_dtype (int64_t const dtype) {
1125
+ switch (dtype) {
1126
+ case 0 :
1127
+ return btg::Dtype::Bfloat16;
1128
+ case 1 :
1129
+ return btg::Dtype::Bool;
1130
+ case 2 :
1131
+ return btg::Dtype::E2m1;
1132
+ case 3 :
1133
+ return btg::Dtype::E2m3;
1134
+ case 4 :
1135
+ return btg::Dtype::E3m2;
1136
+ case 5 :
1137
+ return btg::Dtype::E4m3;
1138
+ case 6 :
1139
+ return btg::Dtype::E5m2;
1140
+ case 7 :
1141
+ return btg::Dtype::Fp16;
1142
+ case 8 :
1143
+ return btg::Dtype::Fp32;
1144
+ case 9 :
1145
+ return btg::Dtype::Int8;
1146
+ case 10 :
1147
+ return btg::Dtype::Int32;
1148
+ case 11 :
1149
+ return btg::Dtype::Int64;
1150
+ case 12 :
1151
+ return btg::Dtype::MxE2m1;
1152
+ case 13 :
1153
+ return btg::Dtype::MxE4m3;
1154
+ case 14 :
1155
+ return btg::Dtype::UE8m0;
1156
+ case 15 :
1157
+ return btg::Dtype::UInt8;
1158
+ case 16 :
1159
+ return btg::Dtype::UInt16;
1160
+ case 17 :
1161
+ return btg::Dtype::UInt32;
1162
+ case 18 :
1163
+ return btg::Dtype::UInt64;
1164
+ case 19 :
1165
+ return btg::Dtype::UInt128;
1166
+ case 20 :
1167
+ return btg::Dtype::Void;
1168
+ default :
1169
+ TORCH_CHECK (false , " Invalid trtllm-gen dtype" );
1170
+ }
1171
+ return btg::Dtype::E2m1;
1172
+ }
1173
+
1174
+ std::vector<int64_t > trtllm_get_valid_moe_configs (
1175
+ int64_t const tile_tokens_dim, int64_t const dtype_act_, int64_t const dtype_weights_,
1176
+ bool const useDeepSeekFp8, int64_t const top_k, int64_t const hidden_size,
1177
+ int64_t const intermediate_size, int64_t const num_local_experts, int64_t const num_tokens) {
1178
+ auto dtype_act = get_dtype (dtype_act_);
1179
+ auto dtype_weights = get_dtype (dtype_weights_);
1180
+ tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner moe_runner (
1181
+ dtype_act, dtype_weights, useDeepSeekFp8, (int32_t )tile_tokens_dim,
1182
+ tensorrt_llm::kernels::ActType::SwiGlu, /* useShuffledMatrixA*/ true );
1183
+ return moe_runner.getValidConfigIndices (top_k, hidden_size, intermediate_size, num_local_experts,
1184
+ num_tokens);
1120
1185
}
1121
1186
1122
1187
namespace trtllm_cubin_loader {
@@ -1127,6 +1192,7 @@ TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
1127
1192
m.def (" trtllm_fp8_per_tensor_scale_moe" , trtllm_fp8_per_tensor_scale_moe);
1128
1193
m.def (" trtllm_fp8_block_scale_moe" , trtllm_fp8_block_scale_moe);
1129
1194
m.def (" trtllm_fp4_block_scale_moe" , trtllm_fp4_block_scale_moe);
1195
+ m.def (" trtllm_get_valid_moe_configs" , trtllm_get_valid_moe_configs);
1130
1196
}
1131
1197
1132
1198
} // namespace flashinfer
0 commit comments