104 changes: 78 additions & 26 deletions cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp
@@ -34,15 +34,17 @@ using tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::computeSelectedTileN;
std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch::Tensor> const& routing_logits,
torch::optional<torch::Tensor> const& routing_bias, torch::Tensor const& hidden_states,
torch::optional<torch::Tensor> const& hidden_states_scale, torch::Tensor const& gemm1_weights,
-    torch::Tensor const& gemm1_weights_scale, torch::Tensor const& gemm2_weights,
-    torch::Tensor const& gemm2_weights_scale, torch::Tensor const& output1_scales_scalar,
-    torch::Tensor const& output1_scales_gate_scalar, torch::Tensor const& output2_scales_scalar,
-    int64_t const num_experts, int64_t const top_k, std::optional<int64_t> const n_group,
-    std::optional<int64_t> const topk_group, int64_t const intermediate_size, int64_t const local_expert_offset,
-    int64_t const local_num_experts, std::optional<double> const routed_scaling_factor, int64_t const tile_tokens_dim,
-    int64_t const routing_method_type, bool const do_finalize, btg::Dtype const dtype, MoeRunnerType& moe_runner,
-    int64_t const moeConfigIndex, torch::optional<torch::Tensor> const& topk_weights,
-    torch::optional<torch::Tensor> const& topk_ids)
+    torch::Tensor const& gemm1_weights_scale, std::optional<torch::Tensor> const& gemm1_bias,
+    std::optional<torch::Tensor> const& gemm1_alpha, std::optional<torch::Tensor> const& gemm1_beta,
+    std::optional<torch::Tensor> const& gemm1_clamp_limit, torch::Tensor const& gemm2_weights,
+    torch::Tensor const& gemm2_weights_scale, std::optional<torch::Tensor> const& gemm2_bias,
+    torch::Tensor const& output1_scales_scalar, torch::Tensor const& output1_scales_gate_scalar,
+    torch::Tensor const& output2_scales_scalar, int64_t const num_experts, int64_t const top_k,
+    std::optional<int64_t> const n_group, std::optional<int64_t> const topk_group, int64_t const intermediate_size,
+    int64_t const local_expert_offset, int64_t const local_num_experts,
+    std::optional<double> const routed_scaling_factor, int64_t const tile_tokens_dim, int64_t const routing_method_type,
+    bool const do_finalize, btg::Dtype const dtype, MoeRunnerType& moe_runner, int64_t const moeConfigIndex,
+    torch::optional<torch::Tensor> const& topk_weights, torch::optional<torch::Tensor> const& topk_ids)
{
TORCH_CHECK(dtype == btg::Dtype::E4m3 || dtype == btg::Dtype::E2m1, "dtype can only be e4m3 or e2m1.");
TORCH_CHECK(tensorrt_llm::common::isSM100Family(), "Only SM100f is supported by FP4 block scale MOE");
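
Note: the five new optional tensors extend GEMM1 with a per-expert bias and a parameterized, clamped gated activation, plus a per-expert bias on GEMM2. The kernels own the authoritative activation definition; as a hedged reference, assuming the clamped-SwiGLU form used by some recent MoE checkpoints, gemm1_alpha, gemm1_beta, and gemm1_clamp_limit would act roughly like this sketch (the gate/linear split layout is also an assumption):

import torch

def gemm1_activation_reference(gemm1_out: torch.Tensor, alpha: torch.Tensor,
                               beta: torch.Tensor, limit: torch.Tensor) -> torch.Tensor:
    # Assumed layout: the [tokens, 2 * intermediate_size] GEMM1 output splits
    # into a gate half and a linear half (the real kernel's interleaving may
    # differ). alpha/beta/limit are per-expert scalars already gathered for
    # each token's expert, shaped [tokens, 1].
    gate, linear = gemm1_out.chunk(2, dim=-1)
    gate = gate.clamp(max=limit)                  # one-sided clamp on the gate branch
    linear = linear.clamp(min=-limit, max=limit)  # symmetric clamp on the linear branch
    return (gate * torch.sigmoid(alpha * gate)) * (linear + beta)

With alpha = 1, beta = 0, and a very large limit this reduces to plain SwiGLU.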
@@ -161,8 +163,13 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch:

args.gemm1_weights = gemm1_weights.data_ptr();
args.gemm1_weights_scale = gemm1_weights_scale.data_ptr();
+    args.gemm1_bias = gemm1_bias.has_value() ? gemm1_bias.value().data_ptr<float>() : nullptr;
+    args.gemm1_alpha = gemm1_alpha.has_value() ? gemm1_alpha.value().data_ptr<float>() : nullptr;
+    args.gemm1_beta = gemm1_beta.has_value() ? gemm1_beta.value().data_ptr<float>() : nullptr;
+    args.gemm1_clamp_limit = gemm1_clamp_limit.has_value() ? gemm1_clamp_limit.value().data_ptr<float>() : nullptr;
args.gemm2_weights = gemm2_weights.data_ptr();
args.gemm2_weights_scale = gemm2_weights_scale.data_ptr();
+    args.gemm2_bias = gemm2_bias.has_value() ? gemm2_bias.value().data_ptr<float>() : nullptr;
args.num_tokens = hidden_states.sizes()[0];
args.num_experts = num_experts;
if (dtype == btg::Dtype::E4m3)
@@ -313,6 +320,38 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch:
TORCH_CHECK(intermediate_size % 16 == 0, "the second dimension of weights must be a multiple of 16.");
TORCH_CHECK(gemm1_weights_scale.sizes()[1] == 2 * intermediate_size, "gemm1_weights_scale has incorrect dim 1.");

+    if (gemm1_bias.has_value())
+    {
+        TORCH_CHECK(gemm1_bias.value().scalar_type() == at::ScalarType::Float, "gemm1_bias must be float, got %s.",
+            c10::toString(gemm1_bias.value().scalar_type()));
+        TORCH_CHECK(gemm1_bias.value().dim() == 2, "gemm1_bias must be 2D.");
+        TORCH_CHECK(gemm1_bias.value().sizes()[0] == local_num_experts, "gemm1_bias has incorrect dim 0.");
+        TORCH_CHECK(gemm1_bias.value().sizes()[1] == 2 * intermediate_size, "gemm1_bias has incorrect dim 1.");
+    }
+
+    if (gemm1_alpha.has_value())
+    {
+        TORCH_CHECK(gemm1_alpha.value().scalar_type() == at::ScalarType::Float, "gemm1_alpha must be float, got %s.",
+            c10::toString(gemm1_alpha.value().scalar_type()));
+        TORCH_CHECK(gemm1_alpha.value().dim() == 1, "gemm1_alpha must be 1D.");
+        TORCH_CHECK(gemm1_alpha.value().sizes()[0] == local_num_experts, "gemm1_alpha has incorrect dim 0.");
+    }
+    if (gemm1_beta.has_value())
+    {
+        TORCH_CHECK(gemm1_beta.value().scalar_type() == at::ScalarType::Float, "gemm1_beta must be float, got %s.",
+            c10::toString(gemm1_beta.value().scalar_type()));
+        TORCH_CHECK(gemm1_beta.value().dim() == 1, "gemm1_beta must be 1D.");
+        TORCH_CHECK(gemm1_beta.value().sizes()[0] == local_num_experts, "gemm1_beta has incorrect dim 0.");
+    }
+    if (gemm1_clamp_limit.has_value())
+    {
+        TORCH_CHECK(gemm1_clamp_limit.value().scalar_type() == at::ScalarType::Float,
+            "gemm1_clamp_limit must be float, got %s.", c10::toString(gemm1_clamp_limit.value().scalar_type()));
+        TORCH_CHECK(gemm1_clamp_limit.value().dim() == 1, "gemm1_clamp_limit must be 1D.");
+        TORCH_CHECK(
+            gemm1_clamp_limit.value().sizes()[0] == local_num_experts, "gemm1_clamp_limit has incorrect dim 0.");
+    }
+
TORCH_CHECK(gemm2_weights.scalar_type() == FLOAT4_E2M1X2, "gemm2_weights must be byte.");

TORCH_CHECK(gemm2_weights.dim() == 3, "gemm2_weights must be 3D.");
@@ -322,6 +361,15 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch:

TORCH_CHECK(gemm2_weights_scale.scalar_type() == at::ScalarType::Float8_e4m3fn, "gemm2_weights_scale must be fp8.");

+    if (gemm2_bias.has_value())
+    {
+        TORCH_CHECK(gemm2_bias.value().scalar_type() == at::ScalarType::Float, "gemm2_bias must be float, got %s.",
+            c10::toString(gemm2_bias.value().scalar_type()));
+        TORCH_CHECK(gemm2_bias.value().dim() == 2, "gemm2_bias must be 2D.");
+        TORCH_CHECK(gemm2_bias.value().sizes()[0] == local_num_experts, "gemm2_bias has incorrect dim 0.");
+        TORCH_CHECK(gemm2_bias.value().sizes()[1] == args.hidden_size, "gemm2_bias has incorrect dim 1.");
+    }
+
TORCH_CHECK(gemm2_weights_scale.dim() == 3, "gemm2_weights_scale must be 3D.");
TORCH_CHECK(gemm2_weights_scale.sizes()[0] == local_num_experts, "gemm2_weights_scale has incorrect dim 0.");
TORCH_CHECK(gemm2_weights_scale.sizes()[1] == args.hidden_size, "gemm2_weights_scale has incorrect dim 1.");
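
Note: together these checks pin down the contracts of the new tensors: all are float32; gemm1_bias is [local_num_experts, 2 * intermediate_size], gemm1_alpha / gemm1_beta / gemm1_clamp_limit are [local_num_experts], and gemm2_bias is [local_num_experts, hidden_size]. An absent tensor is lowered to nullptr on the args struct above, which the kernels read as "feature disabled". A hypothetical Python-side mirror of the same validation (not part of this diff):

import torch

def check_moe_bias_args(gemm1_bias, gemm1_alpha, gemm1_beta, gemm1_clamp_limit,
                        gemm2_bias, local_num_experts, intermediate_size, hidden_size):
    # Hypothetical helper mirroring the TORCH_CHECKs above; None means the
    # corresponding feature is disabled (the C++ lowers it to nullptr).
    expected = [
        ("gemm1_bias", gemm1_bias, (local_num_experts, 2 * intermediate_size)),
        ("gemm1_alpha", gemm1_alpha, (local_num_experts,)),
        ("gemm1_beta", gemm1_beta, (local_num_experts,)),
        ("gemm1_clamp_limit", gemm1_clamp_limit, (local_num_experts,)),
        ("gemm2_bias", gemm2_bias, (local_num_experts, hidden_size)),
    ]
    for name, t, shape in expected:
        if t is None:
            continue
        assert t.dtype == torch.float32, f"{name} must be float32, got {t.dtype}"
        assert tuple(t.shape) == shape, f"{name} expected shape {shape}, got {tuple(t.shape)}"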
@@ -440,14 +488,17 @@ class FP4BlockScaleMoeRunner : public torch::CustomClassHolder
[[nodiscard]] std::vector<torch::Tensor> run(torch::optional<torch::Tensor> const& routing_logits,
torch::optional<torch::Tensor> const& routing_bias, torch::Tensor const& hidden_states,
torch::Tensor const& hidden_states_scale, torch::Tensor const& gemm1_weights,
-        torch::Tensor const& gemm1_weights_scale, torch::Tensor const& gemm2_weights,
-        torch::Tensor const& gemm2_weights_scale, torch::Tensor const& output1_scales_scalar,
-        torch::Tensor const& output1_scales_gate_scalar, torch::Tensor const& output2_scales_scalar,
-        int64_t const num_experts, int64_t const top_k, std::optional<int64_t> const n_group,
-        std::optional<int64_t> const topk_group, int64_t const intermediate_size, int64_t const local_expert_offset,
-        int64_t const local_num_experts, std::optional<double> const routed_scaling_factor,
-        int64_t const routing_method_type, bool const do_finalize, std::vector<int64_t> moeConfigIndex,
-        torch::optional<torch::Tensor> const& topk_weights, torch::optional<torch::Tensor> const& topk_ids)
+        torch::Tensor const& gemm1_weights_scale, std::optional<torch::Tensor> const& gemm1_bias,
+        std::optional<torch::Tensor> const& gemm1_alpha, std::optional<torch::Tensor> const& gemm1_beta,
+        std::optional<torch::Tensor> const& gemm1_clamp_limit, torch::Tensor const& gemm2_weights,
+        torch::Tensor const& gemm2_weights_scale, std::optional<torch::Tensor> const& gemm2_bias,
+        torch::Tensor const& output1_scales_scalar, torch::Tensor const& output1_scales_gate_scalar,
+        torch::Tensor const& output2_scales_scalar, int64_t const num_experts, int64_t const top_k,
+        std::optional<int64_t> const n_group, std::optional<int64_t> const topk_group, int64_t const intermediate_size,
+        int64_t const local_expert_offset, int64_t const local_num_experts,
+        std::optional<double> const routed_scaling_factor, int64_t const routing_method_type, bool const do_finalize,
+        std::vector<int64_t> moeConfigIndex, torch::optional<torch::Tensor> const& topk_weights,
+        torch::optional<torch::Tensor> const& topk_ids)
{
// moeConfigIndex corresponds to pair (tileN, config)
auto [tileN, config] = std::tie(moeConfigIndex[0], moeConfigIndex[1]);
@@ -468,10 +519,11 @@ class FP4BlockScaleMoeRunner : public torch::CustomClassHolder
}

return run_fp4_block_scale_moe_runner(routing_logits, routing_bias, hidden_states, hidden_states_scale,
-            gemm1_weights, gemm1_weights_scale, gemm2_weights, gemm2_weights_scale, output1_scales_scalar,
-            output1_scales_gate_scalar, output2_scales_scalar, num_experts, top_k, n_group, topk_group,
-            intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, tileN,
-            routing_method_type, do_finalize, mDtypeElt, *mRunners[tileN], config, topk_weights, topk_ids);
+            gemm1_weights, gemm1_weights_scale, gemm1_bias, gemm1_alpha, gemm1_beta, gemm1_clamp_limit, gemm2_weights,
+            gemm2_weights_scale, gemm2_bias, output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar,
+            num_experts, top_k, n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts,
+            routed_scaling_factor, tileN, routing_method_type, do_finalize, mDtypeElt, *mRunners[tileN], config,
+            topk_weights, topk_ids);
}

private:
@@ -553,11 +605,11 @@ class FP8FP4BlockScaleMoeRunner : public torch::CustomClassHolder
}

return run_fp4_block_scale_moe_runner(routing_logits, routing_bias, hidden_states,
-            std::nullopt /*hidden_states_scale*/, gemm1_weights, gemm1_weights_scale, gemm2_weights,
-            gemm2_weights_scale, output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar, num_experts,
-            top_k, n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts,
-            routed_scaling_factor, tileN, routing_method_type, do_finalize, mDtypeAct, *mRunners[tileN], config,
-            topk_weights, topk_ids);
+            std::nullopt /*hidden_states_scale*/, gemm1_weights, gemm1_weights_scale, std::nullopt, std::nullopt,
+            std::nullopt, std::nullopt, gemm2_weights, gemm2_weights_scale, std::nullopt, output1_scales_scalar,
+            output1_scales_gate_scalar, output2_scales_scalar, num_experts, top_k, n_group, topk_group,
+            intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, tileN,
+            routing_method_type, do_finalize, mDtypeAct, *mRunners[tileN], config, topk_weights, topk_ids);
}

private:
37 changes: 29 additions & 8 deletions tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py
@@ -176,8 +176,13 @@ class FP4BlockScaleMoEInputs:
hidden_states_scale: torch.Tensor
gemm1_weights: torch.Tensor
gemm1_weights_scale: torch.Tensor
+    gemm1_bias: torch.Tensor
+    gemm1_alpha: torch.Tensor
+    gemm1_beta: torch.Tensor
+    gemm1_clamp_limit: torch.Tensor
gemm2_weights: torch.Tensor
gemm2_weights_scale: torch.Tensor
+    gemm2_bias: torch.Tensor
output1_scale_scalar: torch.Tensor
output1_scale_gate_scalar: torch.Tensor
output2_scale_scalar: torch.Tensor
Expand Down Expand Up @@ -235,14 +240,15 @@ def forward(
return kernel_runner.run_moe(
args.routing_logits, args.routing_bias, args.hidden_states,
args.hidden_states_scale, args.gemm1_weights,
-            args.gemm1_weights_scale, args.gemm2_weights,
-            args.gemm2_weights_scale, args.output1_scale_scalar,
-            args.output1_scale_gate_scalar, args.output2_scale_scalar,
-            self.num_experts, self.top_k, self.n_group, self.topk_group,
-            self.intermediate_size, self.local_expert_offset,
-            self.local_num_experts, self.routed_scaling_factor,
-            self.routing_method_type, self.do_finalize, tactic,
-            args.topk_weights, args.topk_ids)
+            args.gemm1_weights_scale, args.gemm1_bias, args.gemm1_alpha,
+            args.gemm1_beta, args.gemm1_clamp_limit, args.gemm2_weights,
+            args.gemm2_weights_scale, args.gemm2_bias,
+            args.output1_scale_scalar, args.output1_scale_gate_scalar,
+            args.output2_scale_scalar, self.num_experts, self.top_k,
+            self.n_group, self.topk_group, self.intermediate_size,
+            self.local_expert_offset, self.local_num_experts,
+            self.routed_scaling_factor, self.routing_method_type,
+            self.do_finalize, tactic, args.topk_weights, args.topk_ids)

def get_valid_tactics(self, inputs: List[torch.Tensor],
profile: OptimizationProfile,
Expand Down Expand Up @@ -359,8 +365,13 @@ def fp4_block_scale_moe_runner(
hidden_states_scale: torch.Tensor,
gemm1_weights: torch.Tensor,
gemm1_weights_scale: torch.Tensor,
+    gemm1_bias: torch.Tensor,
+    gemm1_alpha: torch.Tensor,
+    gemm1_beta: torch.Tensor,
+    gemm1_clamp_limit: torch.Tensor,
gemm2_weights: torch.Tensor,
gemm2_weights_scale: torch.Tensor,
+    gemm2_bias: torch.Tensor,
output1_scale_scalar: torch.Tensor,
output1_scale_gate_scalar: torch.Tensor,
output2_scale_scalar: torch.Tensor,
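
Note: call sites must thread the five new tensors positionally between gemm1_weights_scale and gemm2_weights / gemm2_weights_scale, in exactly this order, to match the C++ binding above. A hedged allocation sketch with hypothetical sizes; with zero biases, alpha = 1, beta = 0, and a very large clamp limit, the reference activation sketched earlier degenerates to plain SwiGLU, so these values approximate the pre-change behavior while exercising the new code path:

import torch

# Hypothetical sizes: local_num_experts, intermediate_size, hidden_size.
E, I, H = 32, 2048, 4096
dev = "cuda"

gemm1_bias = torch.zeros(E, 2 * I, dtype=torch.float32, device=dev)          # [E, 2*I]
gemm1_alpha = torch.ones(E, dtype=torch.float32, device=dev)                 # [E]
gemm1_beta = torch.zeros(E, dtype=torch.float32, device=dev)                 # [E]
gemm1_clamp_limit = torch.full((E,), 1e30, dtype=torch.float32, device=dev)  # [E]
gemm2_bias = torch.zeros(E, H, dtype=torch.float32, device=dev)              # [E, H]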
Expand Down Expand Up @@ -416,8 +427,13 @@ def fp4_block_scale_moe_runner(
hidden_states_scale,
gemm1_weights,
gemm1_weights_scale,
+        gemm1_bias,
+        gemm1_alpha,
+        gemm1_beta,
+        gemm1_clamp_limit,
gemm2_weights,
gemm2_weights_scale,
+        gemm2_bias,
output1_scale_scalar,
output1_scale_gate_scalar,
output2_scale_scalar,
Expand Down Expand Up @@ -474,8 +490,13 @@ def _(routing_logits,
hidden_states_scale,
gemm1_weights,
gemm1_weights_scale,
+      gemm1_bias,
+      gemm1_alpha,
+      gemm1_beta,
+      gemm1_clamp_limit,
gemm2_weights,
gemm2_weights_scale,
+      gemm2_bias,
output1_scale_scalar,
output1_scale_gate_scalar,
output2_scale_scalar,