@@ -34,14 +34,14 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch:
     torch::Tensor const& gemm1_weights_scale, std::optional<torch::Tensor> const& gemm1_bias,
     std::optional<torch::Tensor> const& gemm1_alpha, std::optional<torch::Tensor> const& gemm1_beta,
     std::optional<torch::Tensor> const& gemm1_clamp_limit, torch::Tensor const& gemm2_weights,
-    torch::Tensor const& gemm2_weights_scale, torch::Tensor const& output1_scales_scalar,
-    torch::Tensor const& output1_scales_gate_scalar, torch::Tensor const& output2_scales_scalar,
-    int64_t const num_experts, int64_t const top_k, std::optional<int64_t> const n_group,
-    std::optional<int64_t> const topk_group, int64_t const intermediate_size, int64_t const local_expert_offset,
-    int64_t const local_num_experts, std::optional<double> const routed_scaling_factor, int64_t const tile_tokens_dim,
-    int64_t const routing_method_type, bool const do_finalize, btg::Dtype const dtype, MoeRunnerType& moe_runner,
-    int64_t const moeConfigIndex, torch::optional<torch::Tensor> const& topk_weights,
-    torch::optional<torch::Tensor> const& topk_ids)
+    torch::Tensor const& gemm2_weights_scale, std::optional<torch::Tensor> const& gemm2_bias,
+    torch::Tensor const& output1_scales_scalar, torch::Tensor const& output1_scales_gate_scalar,
+    torch::Tensor const& output2_scales_scalar, int64_t const num_experts, int64_t const top_k,
+    std::optional<int64_t> const n_group, std::optional<int64_t> const topk_group, int64_t const intermediate_size,
+    int64_t const local_expert_offset, int64_t const local_num_experts,
+    std::optional<double> const routed_scaling_factor, int64_t const tile_tokens_dim, int64_t const routing_method_type,
+    bool const do_finalize, btg::Dtype const dtype, MoeRunnerType& moe_runner, int64_t const moeConfigIndex,
+    torch::optional<torch::Tensor> const& topk_weights, torch::optional<torch::Tensor> const& topk_ids)
 {
     TORCH_CHECK(dtype == btg::Dtype::E4m3 || dtype == btg::Dtype::E2m1, "dtype can only be e4m3 or e2m1.");
     TORCH_CHECK(tensorrt_llm::common::isSM100Family(), "Only SM100f is supported by FP4 block scale MOE");
@@ -166,6 +166,7 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch:
     args.gemm1_clamp_limit = gemm1_clamp_limit.has_value() ? gemm1_clamp_limit.value().data_ptr<float>() : nullptr;
     args.gemm2_weights = gemm2_weights.data_ptr();
     args.gemm2_weights_scale = gemm2_weights_scale.data_ptr();
+    args.gemm2_bias = gemm2_bias.has_value() ? gemm2_bias.value().data_ptr<float>() : nullptr;
     args.num_tokens = hidden_states.sizes()[0];
     args.num_experts = num_experts;
     if (dtype == btg::Dtype::E4m3)
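The added `args.gemm2_bias` line follows the same optional-tensor-to-raw-pointer pattern the function already uses for `gemm1_clamp_limit`. A minimal sketch of a helper that could factor out this repetition — hypothetical, not part of this PR:

    #include <optional>
    #include <torch/torch.h>

    // Hypothetical helper (not in the PR): map an optional fp32 tensor to a raw
    // pointer for the C-style args struct, or nullptr when the input is absent.
    static float* optionalFloatPtr(std::optional<torch::Tensor> const& t)
    {
        // data_ptr<float>() asserts the tensor holds fp32, matching the TORCH_CHECKs below.
        return t.has_value() ? t.value().data_ptr<float>() : nullptr;
    }

With it, the new wiring would read `args.gemm2_bias = optionalFloatPtr(gemm2_bias);`.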
@@ -357,6 +358,15 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch:

     TORCH_CHECK(gemm2_weights_scale.scalar_type() == at::ScalarType::Float8_e4m3fn, "gemm2_weights_scale must be fp8.");

+    if (gemm2_bias.has_value())
+    {
+        TORCH_CHECK(gemm2_bias.value().scalar_type() == at::ScalarType::Float, "gemm2_bias must be float, got ",
+            c10::toString(gemm2_bias.value().scalar_type()), ".");
+        TORCH_CHECK(gemm2_bias.value().dim() == 2, "gemm2_bias must be 2D.");
+        TORCH_CHECK(gemm2_bias.value().sizes()[0] == local_num_experts, "gemm2_bias has incorrect dim 0.");
+        TORCH_CHECK(gemm2_bias.value().sizes()[1] == args.hidden_size, "gemm2_bias has incorrect dim 1.");
+    }
+
     TORCH_CHECK(gemm2_weights_scale.dim() == 3, "gemm2_weights_scale must be 3D.");
     TORCH_CHECK(gemm2_weights_scale.sizes()[0] == local_num_experts, "gemm2_weights_scale has incorrect dim 0.");
     TORCH_CHECK(gemm2_weights_scale.sizes()[1] == args.hidden_size, "gemm2_weights_scale has incorrect dim 1.");
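Taken together, these checks mean a valid `gemm2_bias` is a dense fp32 tensor of shape `[local_num_experts, hidden_size]`. A minimal construction sketch (shapes and device are illustrative assumptions):

    #include <torch/torch.h>

    // Sketch only: a bias tensor shaped to pass the checks above.
    // Zeros make the GEMM2 bias a no-op, which is handy for smoke tests.
    torch::Tensor makeGemm2Bias(int64_t local_num_experts, int64_t hidden_size)
    {
        return torch::zeros({local_num_experts, hidden_size},
            torch::dtype(torch::kFloat32).device(torch::kCUDA));
    }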
@@ -461,6 +471,29 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch:
         }
         std::cout << std::endl;
     }
+    std::vector<float> gemm1_bias_vals;
+    if (gemm1_bias.has_value())
+    {
+        auto bias_cpu = gemm1_bias.value().cpu().contiguous();
+        float* bias_ptr = bias_cpu.data_ptr<float>();
+        std::cout << "[FP4BlockScaleMoe] gemm1 bias: ";
+        for (int i = 0; i < std::min(local_num_experts * intermediate_size * 2, int64_t(30)); ++i)
+        {
+            std::cout << bias_ptr[i] << " ";
+        }
+        std::cout << std::endl;
+    }
+    if (gemm2_bias.has_value())
+    {
+        auto bias_cpu = gemm2_bias.value().cpu().contiguous();
+        float* bias_ptr = bias_cpu.data_ptr<float>();
+        std::cout << "[FP4BlockScaleMoe] gemm2 bias: ";
+        for (int i = 0; i < std::min(local_num_experts * args.hidden_size, int64_t(30)); ++i)
+        {
+            std::cout << bias_ptr[i] << " ";
+        }
+        std::cout << std::endl;
+    }

     moe_runner.run(args, workspace, hidden_states.get_device(), moe_stream, moeConfigIndex);

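The two new dump blocks are identical except for the tag and the element count, and each copies its bias to host and prints up to 30 leading values on every call. A hedged sketch of a single debug helper that would cover both — hypothetical, not what the PR does:

    #include <algorithm>
    #include <iostream>
    #include <optional>
    #include <torch/torch.h>

    // Hypothetical debug helper: print up to 30 leading values of an optional fp32 bias.
    static void dumpBias(char const* tag, std::optional<torch::Tensor> const& bias, int64_t numElems)
    {
        if (!bias.has_value())
        {
            return;
        }
        auto biasCpu = bias.value().cpu().contiguous(); // device-to-host copy, as in the inline code
        float const* ptr = biasCpu.data_ptr<float>();
        std::cout << "[FP4BlockScaleMoe] " << tag << " bias: ";
        for (int64_t i = 0; i < std::min(numElems, int64_t(30)); ++i)
        {
            std::cout << ptr[i] << " ";
        }
        std::cout << std::endl;
    }

Usage would mirror the inline blocks: `dumpBias("gemm1", gemm1_bias, local_num_experts * intermediate_size * 2);` and `dumpBias("gemm2", gemm2_bias, local_num_experts * args.hidden_size);`.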
@@ -510,13 +543,14 @@ class FP4BlockScaleMoeRunner : public torch::CustomClassHolder
         torch::Tensor const& gemm1_weights_scale, std::optional<torch::Tensor> const& gemm1_bias,
         std::optional<torch::Tensor> const& gemm1_alpha, std::optional<torch::Tensor> const& gemm1_beta,
         std::optional<torch::Tensor> const& gemm1_clamp_limit, torch::Tensor const& gemm2_weights,
-        torch::Tensor const& gemm2_weights_scale, torch::Tensor const& output1_scales_scalar,
-        torch::Tensor const& output1_scales_gate_scalar, torch::Tensor const& output2_scales_scalar,
-        int64_t const num_experts, int64_t const top_k, std::optional<int64_t> const n_group,
-        std::optional<int64_t> const topk_group, int64_t const intermediate_size, int64_t const local_expert_offset,
-        int64_t const local_num_experts, std::optional<double> const routed_scaling_factor,
-        int64_t const routing_method_type, bool const do_finalize, std::vector<int64_t> moeConfigIndex,
-        torch::optional<torch::Tensor> const& topk_weights, torch::optional<torch::Tensor> const& topk_ids)
+        torch::Tensor const& gemm2_weights_scale, std::optional<torch::Tensor> const& gemm2_bias,
+        torch::Tensor const& output1_scales_scalar, torch::Tensor const& output1_scales_gate_scalar,
+        torch::Tensor const& output2_scales_scalar, int64_t const num_experts, int64_t const top_k,
+        std::optional<int64_t> const n_group, std::optional<int64_t> const topk_group, int64_t const intermediate_size,
+        int64_t const local_expert_offset, int64_t const local_num_experts,
+        std::optional<double> const routed_scaling_factor, int64_t const routing_method_type, bool const do_finalize,
+        std::vector<int64_t> moeConfigIndex, torch::optional<torch::Tensor> const& topk_weights,
+        torch::optional<torch::Tensor> const& topk_ids)
     {
         // moeConfigIndex corresponds to pair (tileN, config)
         auto [tileN, config] = std::tie(moeConfigIndex[0], moeConfigIndex[1]);
@@ -538,8 +572,8 @@ class FP4BlockScaleMoeRunner : public torch::CustomClassHolder

         return run_fp4_block_scale_moe_runner(routing_logits, routing_bias, hidden_states, hidden_states_scale,
             gemm1_weights, gemm1_weights_scale, gemm1_bias, gemm1_alpha, gemm1_beta, gemm1_clamp_limit, gemm2_weights,
-            gemm2_weights_scale, output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar, num_experts,
-            top_k, n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts,
+            gemm2_weights_scale, gemm2_bias, output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar,
+            num_experts, top_k, n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts,
             routed_scaling_factor, tileN, routing_method_type, do_finalize, mDtypeElt, *mRunners[tileN], config,
             topk_weights, topk_ids);
     }
@@ -619,7 +653,7 @@ class FP8FP4BlockScaleMoeRunner : public torch::CustomClassHolder

         return run_fp4_block_scale_moe_runner(routing_logits, routing_bias, hidden_states,
             std::nullopt /*hidden_states_scale*/, gemm1_weights, gemm1_weights_scale, std::nullopt, std::nullopt,
-            std::nullopt, std::nullopt, gemm2_weights, gemm2_weights_scale, output1_scales_scalar,
+            std::nullopt, std::nullopt, gemm2_weights, gemm2_weights_scale, std::nullopt, output1_scales_scalar,
             output1_scales_gate_scalar, output2_scales_scalar, num_experts, top_k, n_group, topk_group,
             intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, tileN,
             routing_method_type, do_finalize, mDtypeAct, *mRunners[tileN], config, topk_weights, topk_ids);
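Note that this FP8-activation path forwards `std::nullopt` in the new `gemm2_bias` slot, so a GEMM2 bias is currently only reachable through the `FP4BlockScaleMoeRunner` entry point above.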