
Commit c4e53d3

Squash all commits

Signed-off-by: Dongfeng Yu <[email protected]>

1 parent 5845951, commit c4e53d3

9 files changed (+946, -247 lines)

cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp

Lines changed: 78 additions & 26 deletions
@@ -34,15 +34,17 @@ using tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::computeSelectedTileN;
 std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch::Tensor> const& routing_logits,
     torch::optional<torch::Tensor> const& routing_bias, torch::Tensor const& hidden_states,
     torch::optional<torch::Tensor> const& hidden_states_scale, torch::Tensor const& gemm1_weights,
-    torch::Tensor const& gemm1_weights_scale, torch::Tensor const& gemm2_weights,
-    torch::Tensor const& gemm2_weights_scale, torch::Tensor const& output1_scales_scalar,
-    torch::Tensor const& output1_scales_gate_scalar, torch::Tensor const& output2_scales_scalar,
-    int64_t const num_experts, int64_t const top_k, std::optional<int64_t> const n_group,
-    std::optional<int64_t> const topk_group, int64_t const intermediate_size, int64_t const local_expert_offset,
-    int64_t const local_num_experts, std::optional<double> const routed_scaling_factor, int64_t const tile_tokens_dim,
-    int64_t const routing_method_type, bool const do_finalize, btg::Dtype const dtype, MoeRunnerType& moe_runner,
-    int64_t const moeConfigIndex, torch::optional<torch::Tensor> const& topk_weights,
-    torch::optional<torch::Tensor> const& topk_ids)
+    torch::Tensor const& gemm1_weights_scale, std::optional<torch::Tensor> const& gemm1_bias,
+    std::optional<torch::Tensor> const& gemm1_alpha, std::optional<torch::Tensor> const& gemm1_beta,
+    std::optional<torch::Tensor> const& gemm1_clamp_limit, torch::Tensor const& gemm2_weights,
+    torch::Tensor const& gemm2_weights_scale, std::optional<torch::Tensor> const& gemm2_bias,
+    torch::Tensor const& output1_scales_scalar, torch::Tensor const& output1_scales_gate_scalar,
+    torch::Tensor const& output2_scales_scalar, int64_t const num_experts, int64_t const top_k,
+    std::optional<int64_t> const n_group, std::optional<int64_t> const topk_group, int64_t const intermediate_size,
+    int64_t const local_expert_offset, int64_t const local_num_experts,
+    std::optional<double> const routed_scaling_factor, int64_t const tile_tokens_dim, int64_t const routing_method_type,
+    bool const do_finalize, btg::Dtype const dtype, MoeRunnerType& moe_runner, int64_t const moeConfigIndex,
+    torch::optional<torch::Tensor> const& topk_weights, torch::optional<torch::Tensor> const& topk_ids)
 {
     TORCH_CHECK(dtype == btg::Dtype::E4m3 || dtype == btg::Dtype::E2m1, "dtype can only be e4m3 or e2m1.");
     TORCH_CHECK(tensorrt_llm::common::isSM100Family(), "Only SM100f is supported by FP4 block scale MOE");
@@ -161,8 +163,13 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch:
 
     args.gemm1_weights = gemm1_weights.data_ptr();
     args.gemm1_weights_scale = gemm1_weights_scale.data_ptr();
+    args.gemm1_bias = gemm1_bias.has_value() ? gemm1_bias.value().data_ptr<float>() : nullptr;
+    args.gemm1_alpha = gemm1_alpha.has_value() ? gemm1_alpha.value().data_ptr<float>() : nullptr;
+    args.gemm1_beta = gemm1_beta.has_value() ? gemm1_beta.value().data_ptr<float>() : nullptr;
+    args.gemm1_clamp_limit = gemm1_clamp_limit.has_value() ? gemm1_clamp_limit.value().data_ptr<float>() : nullptr;
     args.gemm2_weights = gemm2_weights.data_ptr();
     args.gemm2_weights_scale = gemm2_weights_scale.data_ptr();
+    args.gemm2_bias = gemm2_bias.has_value() ? gemm2_bias.value().data_ptr<float>() : nullptr;
     args.num_tokens = hidden_states.sizes()[0];
     args.num_experts = num_experts;
     if (dtype == btg::Dtype::E4m3)
@@ -313,6 +320,38 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch:
     TORCH_CHECK(intermediate_size % 16 == 0, "the second dimension of weights must be a multiple of 16.");
     TORCH_CHECK(gemm1_weights_scale.sizes()[1] == 2 * intermediate_size, "gemm1_weights_scale has incorrect dim 1.");
 
+    if (gemm1_bias.has_value())
+    {
+        TORCH_CHECK(gemm1_bias.value().scalar_type() == at::ScalarType::Float, "gemm1_bias must be float, got %s.",
+            c10::toString(gemm1_bias.value().scalar_type()));
+        TORCH_CHECK(gemm1_bias.value().dim() == 2, "gemm1_bias must be 2D.");
+        TORCH_CHECK(gemm1_bias.value().sizes()[0] == local_num_experts, "gemm1_bias has incorrect dim 0.");
+        TORCH_CHECK(gemm1_bias.value().sizes()[1] == 2 * intermediate_size, "gemm1_bias has incorrect dim 1.");
+    }
+
+    if (gemm1_alpha.has_value())
+    {
+        TORCH_CHECK(gemm1_alpha.value().scalar_type() == at::ScalarType::Float, "gemm1_alpha must be float, got %s.",
+            c10::toString(gemm1_alpha.value().scalar_type()));
+        TORCH_CHECK(gemm1_alpha.value().dim() == 1, "gemm1_alpha must be 1D.");
+        TORCH_CHECK(gemm1_alpha.value().sizes()[0] == local_num_experts, "gemm1_alpha has incorrect dim 0.");
+    }
+    if (gemm1_beta.has_value())
+    {
+        TORCH_CHECK(gemm1_beta.value().scalar_type() == at::ScalarType::Float, "gemm1_beta must be float, got %s.",
+            c10::toString(gemm1_beta.value().scalar_type()));
+        TORCH_CHECK(gemm1_beta.value().dim() == 1, "gemm1_beta must be 1D.");
+        TORCH_CHECK(gemm1_beta.value().sizes()[0] == local_num_experts, "gemm1_beta has incorrect dim 0.");
+    }
+    if (gemm1_clamp_limit.has_value())
+    {
+        TORCH_CHECK(gemm1_clamp_limit.value().scalar_type() == at::ScalarType::Float,
+            "gemm1_clamp_limit must be float, got %s.", c10::toString(gemm1_clamp_limit.value().scalar_type()));
+        TORCH_CHECK(gemm1_clamp_limit.value().dim() == 1, "gemm1_clamp_limit must be 1D.");
+        TORCH_CHECK(
+            gemm1_clamp_limit.value().sizes()[0] == local_num_experts, "gemm1_clamp_limit has incorrect dim 0.");
+    }
+
     TORCH_CHECK(gemm2_weights.scalar_type() == FLOAT4_E2M1X2, "gemm2_weights must be byte.");
 
     TORCH_CHECK(gemm2_weights.dim() == 3, "gemm2_weights must be 3D.");
@@ -322,6 +361,15 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch:
 
     TORCH_CHECK(gemm2_weights_scale.scalar_type() == at::ScalarType::Float8_e4m3fn, "gemm2_weights_scale must be fp8.");
 
+    if (gemm2_bias.has_value())
+    {
+        TORCH_CHECK(gemm2_bias.value().scalar_type() == at::ScalarType::Float, "gemm2_bias must be float, got %s.",
+            c10::toString(gemm2_bias.value().scalar_type()));
+        TORCH_CHECK(gemm2_bias.value().dim() == 2, "gemm2_bias must be 2D.");
+        TORCH_CHECK(gemm2_bias.value().sizes()[0] == local_num_experts, "gemm2_bias has incorrect dim 0.");
+        TORCH_CHECK(gemm2_bias.value().sizes()[1] == args.hidden_size, "gemm2_bias has incorrect dim 1.");
+    }
+
     TORCH_CHECK(gemm2_weights_scale.dim() == 3, "gemm2_weights_scale must be 3D.");
     TORCH_CHECK(gemm2_weights_scale.sizes()[0] == local_num_experts, "gemm2_weights_scale has incorrect dim 0.");
     TORCH_CHECK(gemm2_weights_scale.sizes()[1] == args.hidden_size, "gemm2_weights_scale has incorrect dim 1.");
@@ -440,14 +488,17 @@ class FP4BlockScaleMoeRunner : public torch::CustomClassHolder
     [[nodiscard]] std::vector<torch::Tensor> run(torch::optional<torch::Tensor> const& routing_logits,
         torch::optional<torch::Tensor> const& routing_bias, torch::Tensor const& hidden_states,
         torch::Tensor const& hidden_states_scale, torch::Tensor const& gemm1_weights,
-        torch::Tensor const& gemm1_weights_scale, torch::Tensor const& gemm2_weights,
-        torch::Tensor const& gemm2_weights_scale, torch::Tensor const& output1_scales_scalar,
-        torch::Tensor const& output1_scales_gate_scalar, torch::Tensor const& output2_scales_scalar,
-        int64_t const num_experts, int64_t const top_k, std::optional<int64_t> const n_group,
-        std::optional<int64_t> const topk_group, int64_t const intermediate_size, int64_t const local_expert_offset,
-        int64_t const local_num_experts, std::optional<double> const routed_scaling_factor,
-        int64_t const routing_method_type, bool const do_finalize, std::vector<int64_t> moeConfigIndex,
-        torch::optional<torch::Tensor> const& topk_weights, torch::optional<torch::Tensor> const& topk_ids)
+        torch::Tensor const& gemm1_weights_scale, std::optional<torch::Tensor> const& gemm1_bias,
+        std::optional<torch::Tensor> const& gemm1_alpha, std::optional<torch::Tensor> const& gemm1_beta,
+        std::optional<torch::Tensor> const& gemm1_clamp_limit, torch::Tensor const& gemm2_weights,
+        torch::Tensor const& gemm2_weights_scale, std::optional<torch::Tensor> const& gemm2_bias,
+        torch::Tensor const& output1_scales_scalar, torch::Tensor const& output1_scales_gate_scalar,
+        torch::Tensor const& output2_scales_scalar, int64_t const num_experts, int64_t const top_k,
+        std::optional<int64_t> const n_group, std::optional<int64_t> const topk_group, int64_t const intermediate_size,
+        int64_t const local_expert_offset, int64_t const local_num_experts,
+        std::optional<double> const routed_scaling_factor, int64_t const routing_method_type, bool const do_finalize,
+        std::vector<int64_t> moeConfigIndex, torch::optional<torch::Tensor> const& topk_weights,
+        torch::optional<torch::Tensor> const& topk_ids)
     {
         // moeConfigIndex corresponds to pair (tileN, config)
         auto [tileN, config] = std::tie(moeConfigIndex[0], moeConfigIndex[1]);
@@ -468,10 +519,11 @@ class FP4BlockScaleMoeRunner : public torch::CustomClassHolder
         }
 
         return run_fp4_block_scale_moe_runner(routing_logits, routing_bias, hidden_states, hidden_states_scale,
-            gemm1_weights, gemm1_weights_scale, gemm2_weights, gemm2_weights_scale, output1_scales_scalar,
-            output1_scales_gate_scalar, output2_scales_scalar, num_experts, top_k, n_group, topk_group,
-            intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, tileN,
-            routing_method_type, do_finalize, mDtypeElt, *mRunners[tileN], config, topk_weights, topk_ids);
+            gemm1_weights, gemm1_weights_scale, gemm1_bias, gemm1_alpha, gemm1_beta, gemm1_clamp_limit, gemm2_weights,
+            gemm2_weights_scale, gemm2_bias, output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar,
+            num_experts, top_k, n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts,
+            routed_scaling_factor, tileN, routing_method_type, do_finalize, mDtypeElt, *mRunners[tileN], config,
+            topk_weights, topk_ids);
     }
 
 private:
@@ -553,11 +605,11 @@ class FP8FP4BlockScaleMoeRunner : public torch::CustomClassHolder
         }
 
         return run_fp4_block_scale_moe_runner(routing_logits, routing_bias, hidden_states,
-            std::nullopt /*hidden_states_scale*/, gemm1_weights, gemm1_weights_scale, gemm2_weights,
-            gemm2_weights_scale, output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar, num_experts,
-            top_k, n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts,
-            routed_scaling_factor, tileN, routing_method_type, do_finalize, mDtypeAct, *mRunners[tileN], config,
-            topk_weights, topk_ids);
+            std::nullopt /*hidden_states_scale*/, gemm1_weights, gemm1_weights_scale, std::nullopt, std::nullopt,
+            std::nullopt, std::nullopt, gemm2_weights, gemm2_weights_scale, std::nullopt, output1_scales_scalar,
+            output1_scales_gate_scalar, output2_scales_scalar, num_experts, top_k, n_group, topk_group,
+            intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, tileN,
+            routing_method_type, do_finalize, mDtypeAct, *mRunners[tileN], config, topk_weights, topk_ids);
     }
 
 private:
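
For orientation, every new optional tensor accepted by the runner is a per-local-expert float32 tensor, and the TORCH_CHECKs above pin down the expected shapes. The sketch below only restates those constraints by materializing placeholder tensors; the helper name make_optional_moe_tensors and the fill values are hypothetical and are not part of this commit.

import torch

def make_optional_moe_tensors(local_num_experts: int, intermediate_size: int,
                              hidden_size: int, device: str = "cuda"):
    """Hypothetical helper: build the new optional tensors with the shapes and
    dtype the TORCH_CHECKs above enforce. Values are placeholders, not a recipe."""
    # gemm1_bias: 2D [local_num_experts, 2 * intermediate_size], float32
    gemm1_bias = torch.zeros(local_num_experts, 2 * intermediate_size,
                             dtype=torch.float32, device=device)
    # gemm1_alpha / gemm1_beta / gemm1_clamp_limit: 1D [local_num_experts], float32
    gemm1_alpha = torch.ones(local_num_experts, dtype=torch.float32, device=device)
    gemm1_beta = torch.ones(local_num_experts, dtype=torch.float32, device=device)
    gemm1_clamp_limit = torch.full((local_num_experts,), 7.0,  # arbitrary placeholder clamp value
                                   dtype=torch.float32, device=device)
    # gemm2_bias: 2D [local_num_experts, hidden_size], float32
    gemm2_bias = torch.zeros(local_num_experts, hidden_size,
                             dtype=torch.float32, device=device)
    return gemm1_bias, gemm1_alpha, gemm1_beta, gemm1_clamp_limit, gemm2_bias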

tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py

Lines changed: 29 additions & 8 deletions
@@ -176,8 +176,13 @@ class FP4BlockScaleMoEInputs:
     hidden_states_scale: torch.Tensor
     gemm1_weights: torch.Tensor
     gemm1_weights_scale: torch.Tensor
+    gemm1_bias: torch.Tensor
+    gemm1_alpha: torch.Tensor
+    gemm1_beta: torch.Tensor
+    gemm1_clamp_limit: torch.Tensor
     gemm2_weights: torch.Tensor
     gemm2_weights_scale: torch.Tensor
+    gemm2_bias: torch.Tensor
     output1_scale_scalar: torch.Tensor
     output1_scale_gate_scalar: torch.Tensor
     output2_scale_scalar: torch.Tensor
@@ -235,14 +240,15 @@ def forward(
         return kernel_runner.run_moe(
             args.routing_logits, args.routing_bias, args.hidden_states,
             args.hidden_states_scale, args.gemm1_weights,
-            args.gemm1_weights_scale, args.gemm2_weights,
-            args.gemm2_weights_scale, args.output1_scale_scalar,
-            args.output1_scale_gate_scalar, args.output2_scale_scalar,
-            self.num_experts, self.top_k, self.n_group, self.topk_group,
-            self.intermediate_size, self.local_expert_offset,
-            self.local_num_experts, self.routed_scaling_factor,
-            self.routing_method_type, self.do_finalize, tactic,
-            args.topk_weights, args.topk_ids)
+            args.gemm1_weights_scale, args.gemm1_bias, args.gemm1_alpha,
+            args.gemm1_beta, args.gemm1_clamp_limit, args.gemm2_weights,
+            args.gemm2_weights_scale, args.gemm2_bias,
+            args.output1_scale_scalar, args.output1_scale_gate_scalar,
+            args.output2_scale_scalar, self.num_experts, self.top_k,
+            self.n_group, self.topk_group, self.intermediate_size,
+            self.local_expert_offset, self.local_num_experts,
+            self.routed_scaling_factor, self.routing_method_type,
+            self.do_finalize, tactic, args.topk_weights, args.topk_ids)
 
     def get_valid_tactics(self, inputs: List[torch.Tensor],
                           profile: OptimizationProfile,
@@ -359,8 +365,13 @@ def fp4_block_scale_moe_runner(
         hidden_states_scale: torch.Tensor,
         gemm1_weights: torch.Tensor,
         gemm1_weights_scale: torch.Tensor,
+        gemm1_bias: torch.Tensor,
+        gemm1_alpha: torch.Tensor,
+        gemm1_beta: torch.Tensor,
+        gemm1_clamp_limit: torch.Tensor,
         gemm2_weights: torch.Tensor,
         gemm2_weights_scale: torch.Tensor,
+        gemm2_bias: torch.Tensor,
         output1_scale_scalar: torch.Tensor,
         output1_scale_gate_scalar: torch.Tensor,
         output2_scale_scalar: torch.Tensor,
@@ -416,8 +427,13 @@ def fp4_block_scale_moe_runner(
         hidden_states_scale,
         gemm1_weights,
         gemm1_weights_scale,
+        gemm1_bias,
+        gemm1_alpha,
+        gemm1_beta,
+        gemm1_clamp_limit,
         gemm2_weights,
         gemm2_weights_scale,
+        gemm2_bias,
         output1_scale_scalar,
         output1_scale_gate_scalar,
         output2_scale_scalar,
@@ -474,8 +490,13 @@ def _(routing_logits,
       hidden_states_scale,
       gemm1_weights,
       gemm1_weights_scale,
+      gemm1_bias,
+      gemm1_alpha,
+      gemm1_beta,
+      gemm1_clamp_limit,
      gemm2_weights,
       gemm2_weights_scale,
+      gemm2_bias,
       output1_scale_scalar,
       output1_scale_gate_scalar,
       output2_scale_scalar,
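
The Python wrappers pass the new tensors through positionally; the shape and dtype constraints are enforced by the C++ runner. A caller that wants to fail fast before dispatch could mirror those checks in Python. The sketch below is a hypothetical pre-validation helper (check_moe_extra_tensors is not part of this commit) that restates the TORCH_CHECKs from fp4BlockScaleMoe.cpp, allowing None for each tensor to match the std::nullopt path used by the FP8FP4 runner.

from typing import Optional, Tuple
import torch

def check_moe_extra_tensors(gemm1_bias: Optional[torch.Tensor],
                            gemm1_alpha: Optional[torch.Tensor],
                            gemm1_beta: Optional[torch.Tensor],
                            gemm1_clamp_limit: Optional[torch.Tensor],
                            gemm2_bias: Optional[torch.Tensor],
                            local_num_experts: int, intermediate_size: int,
                            hidden_size: int) -> None:
    """Hypothetical pre-dispatch validation mirroring the C++ TORCH_CHECKs;
    raises AssertionError on a shape or dtype mismatch."""

    def _check(t: Optional[torch.Tensor], name: str, shape: Tuple[int, ...]) -> None:
        if t is None:
            return  # optional tensors may be omitted (the FP8 path passes nullopt)
        assert t.dtype == torch.float32, f"{name} must be float32, got {t.dtype}"
        assert tuple(t.shape) == shape, f"{name} must have shape {shape}, got {tuple(t.shape)}"

    _check(gemm1_bias, "gemm1_bias", (local_num_experts, 2 * intermediate_size))
    _check(gemm1_alpha, "gemm1_alpha", (local_num_experts,))
    _check(gemm1_beta, "gemm1_beta", (local_num_experts,))
    _check(gemm1_clamp_limit, "gemm1_clamp_limit", (local_num_experts,))
    _check(gemm2_bias, "gemm2_bias", (local_num_experts, hidden_size))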
