@@ -34,15 +34,17 @@ using tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::computeSelectedTileN;
 std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch::Tensor> const& routing_logits,
     torch::optional<torch::Tensor> const& routing_bias, torch::Tensor const& hidden_states,
     torch::optional<torch::Tensor> const& hidden_states_scale, torch::Tensor const& gemm1_weights,
-    torch::Tensor const& gemm1_weights_scale, torch::Tensor const& gemm2_weights,
-    torch::Tensor const& gemm2_weights_scale, torch::Tensor const& output1_scales_scalar,
-    torch::Tensor const& output1_scales_gate_scalar, torch::Tensor const& output2_scales_scalar,
-    int64_t const num_experts, int64_t const top_k, std::optional<int64_t> const n_group,
-    std::optional<int64_t> const topk_group, int64_t const intermediate_size, int64_t const local_expert_offset,
-    int64_t const local_num_experts, std::optional<double> const routed_scaling_factor, int64_t const tile_tokens_dim,
-    int64_t const routing_method_type, bool const do_finalize, btg::Dtype const dtype, MoeRunnerType& moe_runner,
-    int64_t const moeConfigIndex, torch::optional<torch::Tensor> const& topk_weights,
-    torch::optional<torch::Tensor> const& topk_ids)
+    torch::Tensor const& gemm1_weights_scale, std::optional<torch::Tensor> const& gemm1_bias,
+    std::optional<torch::Tensor> const& gemm1_alpha, std::optional<torch::Tensor> const& gemm1_beta,
+    std::optional<torch::Tensor> const& gemm1_clamp_limit, torch::Tensor const& gemm2_weights,
+    torch::Tensor const& gemm2_weights_scale, std::optional<torch::Tensor> const& gemm2_bias,
+    torch::Tensor const& output1_scales_scalar, torch::Tensor const& output1_scales_gate_scalar,
+    torch::Tensor const& output2_scales_scalar, int64_t const num_experts, int64_t const top_k,
+    std::optional<int64_t> const n_group, std::optional<int64_t> const topk_group, int64_t const intermediate_size,
+    int64_t const local_expert_offset, int64_t const local_num_experts,
+    std::optional<double> const routed_scaling_factor, int64_t const tile_tokens_dim, int64_t const routing_method_type,
+    bool const do_finalize, btg::Dtype const dtype, MoeRunnerType& moe_runner, int64_t const moeConfigIndex,
+    torch::optional<torch::Tensor> const& topk_weights, torch::optional<torch::Tensor> const& topk_ids)
 {
     TORCH_CHECK(dtype == btg::Dtype::E4m3 || dtype == btg::Dtype::E2m1, "dtype can only be e4m3 or e2m1.");
     TORCH_CHECK(tensorrt_llm::common::isSM100Family(), "Only SM100f is supported by FP4 block scale MOE");
@@ -161,8 +163,13 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch:

     args.gemm1_weights = gemm1_weights.data_ptr();
     args.gemm1_weights_scale = gemm1_weights_scale.data_ptr();
+    args.gemm1_bias = gemm1_bias.has_value() ? gemm1_bias.value().data_ptr<float>() : nullptr;
+    args.gemm1_alpha = gemm1_alpha.has_value() ? gemm1_alpha.value().data_ptr<float>() : nullptr;
+    args.gemm1_beta = gemm1_beta.has_value() ? gemm1_beta.value().data_ptr<float>() : nullptr;
+    args.gemm1_clamp_limit = gemm1_clamp_limit.has_value() ? gemm1_clamp_limit.value().data_ptr<float>() : nullptr;
     args.gemm2_weights = gemm2_weights.data_ptr();
     args.gemm2_weights_scale = gemm2_weights_scale.data_ptr();
+    args.gemm2_bias = gemm2_bias.has_value() ? gemm2_bias.value().data_ptr<float>() : nullptr;
     args.num_tokens = hidden_states.sizes()[0];
     args.num_experts = num_experts;
     if (dtype == btg::Dtype::E4m3)
@@ -313,6 +320,38 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch:
     TORCH_CHECK(intermediate_size % 16 == 0, "the second dimension of weights must be a multiple of 16.");
     TORCH_CHECK(gemm1_weights_scale.sizes()[1] == 2 * intermediate_size, "gemm1_weights_scale has incorrect dim 1.");

+    if (gemm1_bias.has_value())
+    {
+        TORCH_CHECK(gemm1_bias.value().scalar_type() == at::ScalarType::Float, "gemm1_bias must be float, got %s.",
+            c10::toString(gemm1_bias.value().scalar_type()));
+        TORCH_CHECK(gemm1_bias.value().dim() == 2, "gemm1_bias must be 2D.");
+        TORCH_CHECK(gemm1_bias.value().sizes()[0] == local_num_experts, "gemm1_bias has incorrect dim 0.");
+        TORCH_CHECK(gemm1_bias.value().sizes()[1] == 2 * intermediate_size, "gemm1_bias has incorrect dim 1.");
+    }
+
+    if (gemm1_alpha.has_value())
+    {
+        TORCH_CHECK(gemm1_alpha.value().scalar_type() == at::ScalarType::Float, "gemm1_alpha must be float, got %s.",
+            c10::toString(gemm1_alpha.value().scalar_type()));
+        TORCH_CHECK(gemm1_alpha.value().dim() == 1, "gemm1_alpha must be 1D.");
+        TORCH_CHECK(gemm1_alpha.value().sizes()[0] == local_num_experts, "gemm1_alpha has incorrect dim 0.");
+    }
+    if (gemm1_beta.has_value())
+    {
+        TORCH_CHECK(gemm1_beta.value().scalar_type() == at::ScalarType::Float, "gemm1_beta must be float, got %s.",
+            c10::toString(gemm1_beta.value().scalar_type()));
+        TORCH_CHECK(gemm1_beta.value().dim() == 1, "gemm1_beta must be 1D.");
+        TORCH_CHECK(gemm1_beta.value().sizes()[0] == local_num_experts, "gemm1_beta has incorrect dim 0.");
+    }
+    if (gemm1_clamp_limit.has_value())
+    {
+        TORCH_CHECK(gemm1_clamp_limit.value().scalar_type() == at::ScalarType::Float,
+            "gemm1_clamp_limit must be float, got %s.", c10::toString(gemm1_clamp_limit.value().scalar_type()));
+        TORCH_CHECK(gemm1_clamp_limit.value().dim() == 1, "gemm1_clamp_limit must be 1D.");
+        TORCH_CHECK(
+            gemm1_clamp_limit.value().sizes()[0] == local_num_experts, "gemm1_clamp_limit has incorrect dim 0.");
+    }
+
     TORCH_CHECK(gemm2_weights.scalar_type() == FLOAT4_E2M1X2, "gemm2_weights must be byte.");

     TORCH_CHECK(gemm2_weights.dim() == 3, "gemm2_weights must be 3D.");
@@ -322,6 +361,15 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch:

     TORCH_CHECK(gemm2_weights_scale.scalar_type() == at::ScalarType::Float8_e4m3fn, "gemm2_weights_scale must be fp8.");

+    if (gemm2_bias.has_value())
+    {
+        TORCH_CHECK(gemm2_bias.value().scalar_type() == at::ScalarType::Float, "gemm2_bias must be float, got %s.",
+            c10::toString(gemm2_bias.value().scalar_type()));
+        TORCH_CHECK(gemm2_bias.value().dim() == 2, "gemm2_bias must be 2D.");
+        TORCH_CHECK(gemm2_bias.value().sizes()[0] == local_num_experts, "gemm2_bias has incorrect dim 0.");
+        TORCH_CHECK(gemm2_bias.value().sizes()[1] == args.hidden_size, "gemm2_bias has incorrect dim 1.");
+    }
+
     TORCH_CHECK(gemm2_weights_scale.dim() == 3, "gemm2_weights_scale must be 3D.");
     TORCH_CHECK(gemm2_weights_scale.sizes()[0] == local_num_experts, "gemm2_weights_scale has incorrect dim 0.");
     TORCH_CHECK(gemm2_weights_scale.sizes()[1] == args.hidden_size, "gemm2_weights_scale has incorrect dim 1.");
@@ -440,14 +488,17 @@ class FP4BlockScaleMoeRunner : public torch::CustomClassHolder
     [[nodiscard]] std::vector<torch::Tensor> run(torch::optional<torch::Tensor> const& routing_logits,
         torch::optional<torch::Tensor> const& routing_bias, torch::Tensor const& hidden_states,
         torch::Tensor const& hidden_states_scale, torch::Tensor const& gemm1_weights,
-        torch::Tensor const& gemm1_weights_scale, torch::Tensor const& gemm2_weights,
-        torch::Tensor const& gemm2_weights_scale, torch::Tensor const& output1_scales_scalar,
-        torch::Tensor const& output1_scales_gate_scalar, torch::Tensor const& output2_scales_scalar,
-        int64_t const num_experts, int64_t const top_k, std::optional<int64_t> const n_group,
-        std::optional<int64_t> const topk_group, int64_t const intermediate_size, int64_t const local_expert_offset,
-        int64_t const local_num_experts, std::optional<double> const routed_scaling_factor,
-        int64_t const routing_method_type, bool const do_finalize, std::vector<int64_t> moeConfigIndex,
-        torch::optional<torch::Tensor> const& topk_weights, torch::optional<torch::Tensor> const& topk_ids)
+        torch::Tensor const& gemm1_weights_scale, std::optional<torch::Tensor> const& gemm1_bias,
+        std::optional<torch::Tensor> const& gemm1_alpha, std::optional<torch::Tensor> const& gemm1_beta,
+        std::optional<torch::Tensor> const& gemm1_clamp_limit, torch::Tensor const& gemm2_weights,
+        torch::Tensor const& gemm2_weights_scale, std::optional<torch::Tensor> const& gemm2_bias,
+        torch::Tensor const& output1_scales_scalar, torch::Tensor const& output1_scales_gate_scalar,
+        torch::Tensor const& output2_scales_scalar, int64_t const num_experts, int64_t const top_k,
+        std::optional<int64_t> const n_group, std::optional<int64_t> const topk_group, int64_t const intermediate_size,
+        int64_t const local_expert_offset, int64_t const local_num_experts,
+        std::optional<double> const routed_scaling_factor, int64_t const routing_method_type, bool const do_finalize,
+        std::vector<int64_t> moeConfigIndex, torch::optional<torch::Tensor> const& topk_weights,
+        torch::optional<torch::Tensor> const& topk_ids)
     {
         // moeConfigIndex corresponds to pair (tileN, config)
         auto [tileN, config] = std::tie(moeConfigIndex[0], moeConfigIndex[1]);
@@ -468,10 +519,11 @@ class FP4BlockScaleMoeRunner : public torch::CustomClassHolder
         }

         return run_fp4_block_scale_moe_runner(routing_logits, routing_bias, hidden_states, hidden_states_scale,
-            gemm1_weights, gemm1_weights_scale, gemm2_weights, gemm2_weights_scale, output1_scales_scalar,
-            output1_scales_gate_scalar, output2_scales_scalar, num_experts, top_k, n_group, topk_group,
-            intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, tileN,
-            routing_method_type, do_finalize, mDtypeElt, *mRunners[tileN], config, topk_weights, topk_ids);
+            gemm1_weights, gemm1_weights_scale, gemm1_bias, gemm1_alpha, gemm1_beta, gemm1_clamp_limit, gemm2_weights,
+            gemm2_weights_scale, gemm2_bias, output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar,
+            num_experts, top_k, n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts,
+            routed_scaling_factor, tileN, routing_method_type, do_finalize, mDtypeElt, *mRunners[tileN], config,
+            topk_weights, topk_ids);
     }

 private:
@@ -553,11 +605,11 @@ class FP8FP4BlockScaleMoeRunner : public torch::CustomClassHolder
         }

         return run_fp4_block_scale_moe_runner(routing_logits, routing_bias, hidden_states,
-            std::nullopt /*hidden_states_scale*/, gemm1_weights, gemm1_weights_scale, gemm2_weights,
-            gemm2_weights_scale, output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar, num_experts,
-            top_k, n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts,
-            routed_scaling_factor, tileN, routing_method_type, do_finalize, mDtypeAct, *mRunners[tileN], config,
-            topk_weights, topk_ids);
+            std::nullopt /*hidden_states_scale*/, gemm1_weights, gemm1_weights_scale, std::nullopt, std::nullopt,
+            std::nullopt, std::nullopt, gemm2_weights, gemm2_weights_scale, std::nullopt, output1_scales_scalar,
+            output1_scales_gate_scalar, output2_scales_scalar, num_experts, top_k, n_group, topk_group,
+            intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, tileN,
+            routing_method_type, do_finalize, mDtypeAct, *mRunners[tileN], config, topk_weights, topk_ids);
     }

 private:
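
Note (illustrative, not part of the diff above): the new TORCH_CHECKs fix the dtype and shape contract for the optional tensors. The sketch below, with a hypothetical helper name and placeholder fill values, shows one way a caller could allocate tensors that satisfy those checks; any field may equally be left as std::nullopt, as the FP8FP4 runner does.

// Sketch only -- hypothetical helper; shapes and dtypes assumed from the checks in the diff.
#include <torch/torch.h>

#include <cstdint>
#include <optional>

struct OptionalMoeInputs
{
    std::optional<torch::Tensor> gemm1_bias;        // [local_num_experts, 2 * intermediate_size], float32
    std::optional<torch::Tensor> gemm1_alpha;       // [local_num_experts], float32
    std::optional<torch::Tensor> gemm1_beta;        // [local_num_experts], float32
    std::optional<torch::Tensor> gemm1_clamp_limit; // [local_num_experts], float32
    std::optional<torch::Tensor> gemm2_bias;        // [local_num_experts, hidden_size], float32
};

OptionalMoeInputs makeOptionalMoeInputs(int64_t local_num_experts, int64_t intermediate_size, int64_t hidden_size)
{
    auto opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
    OptionalMoeInputs in;
    in.gemm1_bias = torch::zeros({local_num_experts, 2 * intermediate_size}, opts);
    in.gemm1_alpha = torch::ones({local_num_experts}, opts);            // placeholder value
    in.gemm1_beta = torch::zeros({local_num_experts}, opts);            // placeholder value
    in.gemm1_clamp_limit = torch::full({local_num_experts}, 7.0, opts); // placeholder value
    in.gemm2_bias = torch::zeros({local_num_experts, hidden_size}, opts);
    return in;
}

When a field is left empty, the runner forwards nullptr into args through the has_value() ternaries added in the diff.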