 #include "tensorrt_llm/thop/thUtils.h"
 #include <ATen/cuda/EmptyTensor.h>
 #include <ATen/ops/index_select.h>
+#include <iostream>
 
 namespace torch_ext
 {
@@ -32,15 +33,17 @@ using tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::computeSelectedTileN;
 std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch::Tensor> const& routing_logits,
     torch::optional<torch::Tensor> const& routing_bias, torch::Tensor const& hidden_states,
     torch::optional<torch::Tensor> const& hidden_states_scale, torch::Tensor const& gemm1_weights,
-    torch::Tensor const& gemm1_weights_scale, torch::Tensor const& gemm2_weights,
-    torch::Tensor const& gemm2_weights_scale, torch::Tensor const& output1_scales_scalar,
-    torch::Tensor const& output1_scales_gate_scalar, torch::Tensor const& output2_scales_scalar,
-    int64_t const num_experts, int64_t const top_k, std::optional<int64_t> const n_group,
-    std::optional<int64_t> const topk_group, int64_t const intermediate_size, int64_t const local_expert_offset,
-    int64_t const local_num_experts, std::optional<double> const routed_scaling_factor, int64_t const tile_tokens_dim,
-    int64_t const routing_method_type, bool const do_finalize, btg::Dtype const dtype, MoeRunnerType& moe_runner,
-    int64_t const moeConfigIndex, torch::optional<torch::Tensor> const& topk_weights,
-    torch::optional<torch::Tensor> const& topk_ids)
+    torch::Tensor const& gemm1_weights_scale, std::optional<torch::Tensor> const& gemm1_bias,
+    std::optional<torch::Tensor> const& gemm1_alpha, std::optional<torch::Tensor> const& gemm1_beta,
+    std::optional<torch::Tensor> const& gemm1_clamp_limit, torch::Tensor const& gemm2_weights,
+    torch::Tensor const& gemm2_weights_scale, std::optional<torch::Tensor> const& gemm2_bias,
+    torch::Tensor const& output1_scales_scalar, torch::Tensor const& output1_scales_gate_scalar,
+    torch::Tensor const& output2_scales_scalar, int64_t const num_experts, int64_t const top_k,
+    std::optional<int64_t> const n_group, std::optional<int64_t> const topk_group, int64_t const intermediate_size,
+    int64_t const local_expert_offset, int64_t const local_num_experts,
+    std::optional<double> const routed_scaling_factor, int64_t const tile_tokens_dim, int64_t const routing_method_type,
+    bool const do_finalize, btg::Dtype const dtype, MoeRunnerType& moe_runner, int64_t const moeConfigIndex,
+    torch::optional<torch::Tensor> const& topk_weights, torch::optional<torch::Tensor> const& topk_ids)
 {
     TORCH_CHECK(dtype == btg::Dtype::E4m3 || dtype == btg::Dtype::E2m1, "dtype can only be e4m3 or e2m1.");
     TORCH_CHECK(tensorrt_llm::common::isSM100Family(), "Only SM100f is supported by FP4 block scale MOE");
@@ -159,8 +162,13 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch:
 
     args.gemm1_weights = gemm1_weights.data_ptr();
     args.gemm1_weights_scale = gemm1_weights_scale.data_ptr();
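+    // Optional GEMM epilogue tensors: forward the fp32 data pointers when provided, nullptr otherwise.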
+    args.gemm1_bias = gemm1_bias.has_value() ? gemm1_bias.value().data_ptr<float>() : nullptr;
+    args.gemm1_alpha = gemm1_alpha.has_value() ? gemm1_alpha.value().data_ptr<float>() : nullptr;
+    args.gemm1_beta = gemm1_beta.has_value() ? gemm1_beta.value().data_ptr<float>() : nullptr;
+    args.gemm1_clamp_limit = gemm1_clamp_limit.has_value() ? gemm1_clamp_limit.value().data_ptr<float>() : nullptr;
     args.gemm2_weights = gemm2_weights.data_ptr();
     args.gemm2_weights_scale = gemm2_weights_scale.data_ptr();
+    args.gemm2_bias = gemm2_bias.has_value() ? gemm2_bias.value().data_ptr<float>() : nullptr;
     args.num_tokens = hidden_states.sizes()[0];
     args.num_experts = num_experts;
     if (dtype == btg::Dtype::E4m3)
@@ -311,6 +319,38 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch:
     TORCH_CHECK(intermediate_size % 16 == 0, "the second dimension of weights must be a multiple of 16.");
     TORCH_CHECK(gemm1_weights_scale.sizes()[1] == 2 * intermediate_size, "gemm1_weights_scale has incorrect dim 1.");
 
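+    // The optional GEMM1 epilogue tensors, when given, must be fp32 and sized per local expert (dim 0 == local_num_experts).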
+    if (gemm1_bias.has_value())
+    {
+        TORCH_CHECK(gemm1_bias.value().scalar_type() == at::ScalarType::Float, "gemm1_bias must be float, got %s.",
+            c10::toString(gemm1_bias.value().scalar_type()));
+        TORCH_CHECK(gemm1_bias.value().dim() == 2, "gemm1_bias must be 2D.");
+        TORCH_CHECK(gemm1_bias.value().sizes()[0] == local_num_experts, "gemm1_bias has incorrect dim 0.");
+        TORCH_CHECK(gemm1_bias.value().sizes()[1] == 2 * intermediate_size, "gemm1_bias has incorrect dim 1.");
+    }
+
+    if (gemm1_alpha.has_value())
+    {
+        TORCH_CHECK(gemm1_alpha.value().scalar_type() == at::ScalarType::Float, "gemm1_alpha must be float, got %s.",
+            c10::toString(gemm1_alpha.value().scalar_type()));
+        TORCH_CHECK(gemm1_alpha.value().dim() == 1, "gemm1_alpha must be 1D.");
+        TORCH_CHECK(gemm1_alpha.value().sizes()[0] == local_num_experts, "gemm1_alpha has incorrect dim 0.");
+    }
+    if (gemm1_beta.has_value())
+    {
+        TORCH_CHECK(gemm1_beta.value().scalar_type() == at::ScalarType::Float, "gemm1_beta must be float, got %s.",
+            c10::toString(gemm1_beta.value().scalar_type()));
+        TORCH_CHECK(gemm1_beta.value().dim() == 1, "gemm1_beta must be 1D.");
+        TORCH_CHECK(gemm1_beta.value().sizes()[0] == local_num_experts, "gemm1_beta has incorrect dim 0.");
+    }
+    if (gemm1_clamp_limit.has_value())
+    {
+        TORCH_CHECK(gemm1_clamp_limit.value().scalar_type() == at::ScalarType::Float,
+            "gemm1_clamp_limit must be float, got %s.", c10::toString(gemm1_clamp_limit.value().scalar_type()));
+        TORCH_CHECK(gemm1_clamp_limit.value().dim() == 1, "gemm1_clamp_limit must be 1D.");
+        TORCH_CHECK(
+            gemm1_clamp_limit.value().sizes()[0] == local_num_experts, "gemm1_clamp_limit has incorrect dim 0.");
+    }
+
     TORCH_CHECK(gemm2_weights.scalar_type() == FLOAT4_E2M1X2, "gemm2_weights must be byte.");
 
     TORCH_CHECK(gemm2_weights.dim() == 3, "gemm2_weights must be 3D.");
@@ -320,6 +360,15 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch:
 
     TORCH_CHECK(gemm2_weights_scale.scalar_type() == at::ScalarType::Float8_e4m3fn, "gemm2_weights_scale must be fp8.");
 
+    if (gemm2_bias.has_value())
+    {
+        TORCH_CHECK(gemm2_bias.value().scalar_type() == at::ScalarType::Float, "gemm2_bias must be float, got %s.",
+            c10::toString(gemm2_bias.value().scalar_type()));
+        TORCH_CHECK(gemm2_bias.value().dim() == 2, "gemm2_bias must be 2D.");
+        TORCH_CHECK(gemm2_bias.value().sizes()[0] == local_num_experts, "gemm2_bias has incorrect dim 0.");
+        TORCH_CHECK(gemm2_bias.value().sizes()[1] == args.hidden_size, "gemm2_bias has incorrect dim 1.");
+    }
+
     TORCH_CHECK(gemm2_weights_scale.dim() == 3, "gemm2_weights_scale must be 3D.");
     TORCH_CHECK(gemm2_weights_scale.sizes()[0] == local_num_experts, "gemm2_weights_scale has incorrect dim 0.");
     TORCH_CHECK(gemm2_weights_scale.sizes()[1] == args.hidden_size, "gemm2_weights_scale has incorrect dim 1.");
@@ -405,7 +454,7 @@ class FP4BlockScaleMoeRunner : public torch::CustomClassHolder
 public:
     explicit FP4BlockScaleMoeRunner()
         // Update this as new cubins come in
-        : mSupportedTileN{8, 16, 32, 64, 128, 256}
+        : mSupportedTileN{8, 16, 32, 64, 128}
     {
         for (int tileN : mSupportedTileN)
         {
@@ -438,14 +487,17 @@ class FP4BlockScaleMoeRunner : public torch::CustomClassHolder
     [[nodiscard]] std::vector<torch::Tensor> run(torch::optional<torch::Tensor> const& routing_logits,
         torch::optional<torch::Tensor> const& routing_bias, torch::Tensor const& hidden_states,
         torch::Tensor const& hidden_states_scale, torch::Tensor const& gemm1_weights,
-        torch::Tensor const& gemm1_weights_scale, torch::Tensor const& gemm2_weights,
-        torch::Tensor const& gemm2_weights_scale, torch::Tensor const& output1_scales_scalar,
-        torch::Tensor const& output1_scales_gate_scalar, torch::Tensor const& output2_scales_scalar,
-        int64_t const num_experts, int64_t const top_k, std::optional<int64_t> const n_group,
-        std::optional<int64_t> const topk_group, int64_t const intermediate_size, int64_t const local_expert_offset,
-        int64_t const local_num_experts, std::optional<double> const routed_scaling_factor,
-        int64_t const routing_method_type, bool const do_finalize, std::vector<int64_t> moeConfigIndex,
-        torch::optional<torch::Tensor> const& topk_weights, torch::optional<torch::Tensor> const& topk_ids)
+        torch::Tensor const& gemm1_weights_scale, std::optional<torch::Tensor> const& gemm1_bias,
+        std::optional<torch::Tensor> const& gemm1_alpha, std::optional<torch::Tensor> const& gemm1_beta,
+        std::optional<torch::Tensor> const& gemm1_clamp_limit, torch::Tensor const& gemm2_weights,
+        torch::Tensor const& gemm2_weights_scale, std::optional<torch::Tensor> const& gemm2_bias,
+        torch::Tensor const& output1_scales_scalar, torch::Tensor const& output1_scales_gate_scalar,
+        torch::Tensor const& output2_scales_scalar, int64_t const num_experts, int64_t const top_k,
+        std::optional<int64_t> const n_group, std::optional<int64_t> const topk_group, int64_t const intermediate_size,
+        int64_t const local_expert_offset, int64_t const local_num_experts,
+        std::optional<double> const routed_scaling_factor, int64_t const routing_method_type, bool const do_finalize,
+        std::vector<int64_t> moeConfigIndex, torch::optional<torch::Tensor> const& topk_weights,
+        torch::optional<torch::Tensor> const& topk_ids)
     {
         // moeConfigIndex corresponds to pair (tileN, config)
         auto [tileN, config] = std::tie(moeConfigIndex[0], moeConfigIndex[1]);
@@ -466,10 +518,11 @@ class FP4BlockScaleMoeRunner : public torch::CustomClassHolder
         }
 
         return run_fp4_block_scale_moe_runner(routing_logits, routing_bias, hidden_states, hidden_states_scale,
-            gemm1_weights, gemm1_weights_scale, gemm2_weights, gemm2_weights_scale, output1_scales_scalar,
-            output1_scales_gate_scalar, output2_scales_scalar, num_experts, top_k, n_group, topk_group,
-            intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, tileN,
-            routing_method_type, do_finalize, mDtypeElt, *mRunners[tileN], config, topk_weights, topk_ids);
+            gemm1_weights, gemm1_weights_scale, gemm1_bias, gemm1_alpha, gemm1_beta, gemm1_clamp_limit, gemm2_weights,
+            gemm2_weights_scale, gemm2_bias, output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar,
+            num_experts, top_k, n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts,
+            routed_scaling_factor, tileN, routing_method_type, do_finalize, mDtypeElt, *mRunners[tileN], config,
+            topk_weights, topk_ids);
     }
 
 private:
@@ -551,11 +604,11 @@ class FP8FP4BlockScaleMoeRunner : public torch::CustomClassHolder
         }
 
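+        // The new optional gemm1 bias/alpha/beta/clamp-limit and gemm2 bias tensors are not supplied on this path, so pass std::nullopt.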
         return run_fp4_block_scale_moe_runner(routing_logits, routing_bias, hidden_states,
-            std::nullopt /* hidden_states_scale*/, gemm1_weights, gemm1_weights_scale, gemm2_weights,
-            gemm2_weights_scale, output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar, num_experts,
-            top_k, n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts,
-            routed_scaling_factor, tileN, routing_method_type, do_finalize, mDtypeAct, *mRunners[tileN], config,
-            topk_weights, topk_ids);
+            std::nullopt /* hidden_states_scale*/, gemm1_weights, gemm1_weights_scale, std::nullopt, std::nullopt,
+            std::nullopt, std::nullopt, gemm2_weights, gemm2_weights_scale, std::nullopt, output1_scales_scalar,
+            output1_scales_gate_scalar, output2_scales_scalar, num_experts, top_k, n_group, topk_group,
+            intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, tileN,
+            routing_method_type, do_finalize, mDtypeAct, *mRunners[tileN], config, topk_weights, topk_ids);
     }
 
 private: