
Commit 56d5824

Squash all commits
Add unit test. Signed-off-by: Dongfeng Yu <[email protected]>
Make swiglu alpha work. Signed-off-by: Dongfeng Yu <[email protected]>
Add beta. Signed-off-by: Dongfeng Yu <[email protected]>
Add limit. Signed-off-by: Dongfeng Yu <[email protected]>
Add beta correction. Signed-off-by: Dongfeng Yu <[email protected]>
Add correction. Signed-off-by: Dongfeng Yu <[email protected]>
Add bias1. Signed-off-by: Dongfeng Yu <[email protected]>
Add bias 2. Signed-off-by: Dongfeng Yu <[email protected]>
Add moe module level changes. Signed-off-by: Dongfeng Yu <[email protected]>
Weights loading part. Signed-off-by: Dongfeng Yu <[email protected]>
update test config. Signed-off-by: Dongfeng Yu <[email protected]>
Add validM/N/K WAR. Signed-off-by: Dongfeng Yu <[email protected]>
Fix bias issue. Signed-off-by: Dongfeng Yu <[email protected]>
clean up. Signed-off-by: Dongfeng Yu <[email protected]>
Fix non-gptoss tests failing. Signed-off-by: Dongfeng Yu <[email protected]>
Fix rebase issue. Signed-off-by: Dongfeng Yu <[email protected]>
refactor model loading. Signed-off-by: Dongfeng Yu <[email protected]>
Add a WAR for loading moe only quantizations. Signed-off-by: Dongfeng Yu <[email protected]>
Fix weights transpose and bias dtype. Signed-off-by: Dongfeng Yu <[email protected]>
fix scale transpose. Signed-off-by: Dongfeng Yu <[email protected]>
Fix rebase. Signed-off-by: Dongfeng Yu <[email protected]>
Add WAR for block scaling dim. Signed-off-by: Dongfeng Yu <[email protected]>
Revert "Add WAR for block scaling dim". This reverts commit 2829568.
Refine war. Signed-off-by: Dongfeng Yu <[email protected]>
Remove ckpt WAR after getting new ckpt and wip changes to module level unit tests. Signed-off-by: Dongfeng Yu <[email protected]>
Add paddings. Signed-off-by: Dongfeng Yu <[email protected]>
fix padding and make unit test pass. Signed-off-by: Dongfeng Yu <[email protected]>
remove debug prints. Signed-off-by: Dongfeng Yu <[email protected]>
Fix bugs. Signed-off-by: Dongfeng Yu <[email protected]>
update length. Signed-off-by: Dongfeng Yu <[email protected]>
Disable kernels with issue. Signed-off-by: Dongfeng Yu <[email protected]>
Fix bias scaling. Signed-off-by: Dongfeng Yu <[email protected]>
fix accuracy val. Signed-off-by: Dongfeng Yu <[email protected]>
fix limit and beta. Signed-off-by: Dongfeng Yu <[email protected]>
Add additional parameters to module test. Signed-off-by: Dongfeng Yu <[email protected]>
remove wrong interleave. Signed-off-by: Dongfeng Yu <[email protected]>
Revert "remove wrong interleave". This reverts commit b143db2.
Use the correct layout. Signed-off-by: Dongfeng Yu <[email protected]>
1 parent 899fda9 · commit 56d5824

8 files changed, +931 -135 lines changed


cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp

Lines changed: 80 additions & 27 deletions
@@ -21,6 +21,7 @@
 #include "tensorrt_llm/thop/thUtils.h"
 #include <ATen/cuda/EmptyTensor.h>
 #include <ATen/ops/index_select.h>
+#include <iostream>
 
 namespace torch_ext
 {
@@ -32,15 +33,17 @@ using tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::computeSelectedTileN;
 std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch::Tensor> const& routing_logits,
     torch::optional<torch::Tensor> const& routing_bias, torch::Tensor const& hidden_states,
     torch::optional<torch::Tensor> const& hidden_states_scale, torch::Tensor const& gemm1_weights,
-    torch::Tensor const& gemm1_weights_scale, torch::Tensor const& gemm2_weights,
-    torch::Tensor const& gemm2_weights_scale, torch::Tensor const& output1_scales_scalar,
-    torch::Tensor const& output1_scales_gate_scalar, torch::Tensor const& output2_scales_scalar,
-    int64_t const num_experts, int64_t const top_k, std::optional<int64_t> const n_group,
-    std::optional<int64_t> const topk_group, int64_t const intermediate_size, int64_t const local_expert_offset,
-    int64_t const local_num_experts, std::optional<double> const routed_scaling_factor, int64_t const tile_tokens_dim,
-    int64_t const routing_method_type, bool const do_finalize, btg::Dtype const dtype, MoeRunnerType& moe_runner,
-    int64_t const moeConfigIndex, torch::optional<torch::Tensor> const& topk_weights,
-    torch::optional<torch::Tensor> const& topk_ids)
+    torch::Tensor const& gemm1_weights_scale, std::optional<torch::Tensor> const& gemm1_bias,
+    std::optional<torch::Tensor> const& gemm1_alpha, std::optional<torch::Tensor> const& gemm1_beta,
+    std::optional<torch::Tensor> const& gemm1_clamp_limit, torch::Tensor const& gemm2_weights,
+    torch::Tensor const& gemm2_weights_scale, std::optional<torch::Tensor> const& gemm2_bias,
+    torch::Tensor const& output1_scales_scalar, torch::Tensor const& output1_scales_gate_scalar,
+    torch::Tensor const& output2_scales_scalar, int64_t const num_experts, int64_t const top_k,
+    std::optional<int64_t> const n_group, std::optional<int64_t> const topk_group, int64_t const intermediate_size,
+    int64_t const local_expert_offset, int64_t const local_num_experts,
+    std::optional<double> const routed_scaling_factor, int64_t const tile_tokens_dim, int64_t const routing_method_type,
+    bool const do_finalize, btg::Dtype const dtype, MoeRunnerType& moe_runner, int64_t const moeConfigIndex,
+    torch::optional<torch::Tensor> const& topk_weights, torch::optional<torch::Tensor> const& topk_ids)
 {
     TORCH_CHECK(dtype == btg::Dtype::E4m3 || dtype == btg::Dtype::E2m1, "dtype can only be e4m3 or e2m1.");
     TORCH_CHECK(tensorrt_llm::common::isSM100Family(), "Only SM100f is supported by FP4 block scale MOE");
@@ -159,8 +162,13 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch:
 
     args.gemm1_weights = gemm1_weights.data_ptr();
     args.gemm1_weights_scale = gemm1_weights_scale.data_ptr();
+    args.gemm1_bias = gemm1_bias.has_value() ? gemm1_bias.value().data_ptr<float>() : nullptr;
+    args.gemm1_alpha = gemm1_alpha.has_value() ? gemm1_alpha.value().data_ptr<float>() : nullptr;
+    args.gemm1_beta = gemm1_beta.has_value() ? gemm1_beta.value().data_ptr<float>() : nullptr;
+    args.gemm1_clamp_limit = gemm1_clamp_limit.has_value() ? gemm1_clamp_limit.value().data_ptr<float>() : nullptr;
     args.gemm2_weights = gemm2_weights.data_ptr();
     args.gemm2_weights_scale = gemm2_weights_scale.data_ptr();
+    args.gemm2_bias = gemm2_bias.has_value() ? gemm2_bias.value().data_ptr<float>() : nullptr;
     args.num_tokens = hidden_states.sizes()[0];
     args.num_experts = num_experts;
     if (dtype == btg::Dtype::E4m3)
@@ -311,6 +319,38 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch:
     TORCH_CHECK(intermediate_size % 16 == 0, "the second dimension of weights must be a multiple of 16.");
     TORCH_CHECK(gemm1_weights_scale.sizes()[1] == 2 * intermediate_size, "gemm1_weights_scale has incorrect dim 1.");
 
+    if (gemm1_bias.has_value())
+    {
+        TORCH_CHECK(gemm1_bias.value().scalar_type() == at::ScalarType::Float, "gemm1_bias must be float, got %s.",
+            c10::toString(gemm1_bias.value().scalar_type()));
+        TORCH_CHECK(gemm1_bias.value().dim() == 2, "gemm1_bias must be 2D.");
+        TORCH_CHECK(gemm1_bias.value().sizes()[0] == local_num_experts, "gemm1_bias has incorrect dim 0.");
+        TORCH_CHECK(gemm1_bias.value().sizes()[1] == 2 * intermediate_size, "gemm1_bias has incorrect dim 1.");
+    }
+
+    if (gemm1_alpha.has_value())
+    {
+        TORCH_CHECK(gemm1_alpha.value().scalar_type() == at::ScalarType::Float, "gemm1_alpha must be float, got %s.",
+            c10::toString(gemm1_alpha.value().scalar_type()));
+        TORCH_CHECK(gemm1_alpha.value().dim() == 1, "gemm1_alpha must be 1D.");
+        TORCH_CHECK(gemm1_alpha.value().sizes()[0] == local_num_experts, "gemm1_alpha has incorrect dim 0.");
+    }
+    if (gemm1_beta.has_value())
+    {
+        TORCH_CHECK(gemm1_beta.value().scalar_type() == at::ScalarType::Float, "gemm1_beta must be float, got %s.",
+            c10::toString(gemm1_beta.value().scalar_type()));
+        TORCH_CHECK(gemm1_beta.value().dim() == 1, "gemm1_beta must be 1D.");
+        TORCH_CHECK(gemm1_beta.value().sizes()[0] == local_num_experts, "gemm1_beta has incorrect dim 0.");
+    }
+    if (gemm1_clamp_limit.has_value())
+    {
+        TORCH_CHECK(gemm1_clamp_limit.value().scalar_type() == at::ScalarType::Float,
+            "gemm1_clamp_limit must be float, got %s.", c10::toString(gemm1_clamp_limit.value().scalar_type()));
+        TORCH_CHECK(gemm1_clamp_limit.value().dim() == 1, "gemm1_clamp_limit must be 1D.");
+        TORCH_CHECK(
+            gemm1_clamp_limit.value().sizes()[0] == local_num_experts, "gemm1_clamp_limit has incorrect dim 0.");
+    }
+
     TORCH_CHECK(gemm2_weights.scalar_type() == FLOAT4_E2M1X2, "gemm2_weights must be byte.");
 
     TORCH_CHECK(gemm2_weights.dim() == 3, "gemm2_weights must be 3D.");
@@ -320,6 +360,15 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch:
 
     TORCH_CHECK(gemm2_weights_scale.scalar_type() == at::ScalarType::Float8_e4m3fn, "gemm2_weights_scale must be fp8.");
 
+    if (gemm2_bias.has_value())
+    {
+        TORCH_CHECK(gemm2_bias.value().scalar_type() == at::ScalarType::Float, "gemm2_bias must be float, got %s.",
+            c10::toString(gemm2_bias.value().scalar_type()));
+        TORCH_CHECK(gemm2_bias.value().dim() == 2, "gemm2_bias must be 2D.");
+        TORCH_CHECK(gemm2_bias.value().sizes()[0] == local_num_experts, "gemm2_bias has incorrect dim 0.");
+        TORCH_CHECK(gemm2_bias.value().sizes()[1] == args.hidden_size, "gemm2_bias has incorrect dim 1.");
+    }
+
     TORCH_CHECK(gemm2_weights_scale.dim() == 3, "gemm2_weights_scale must be 3D.");
     TORCH_CHECK(gemm2_weights_scale.sizes()[0] == local_num_experts, "gemm2_weights_scale has incorrect dim 0.");
     TORCH_CHECK(gemm2_weights_scale.sizes()[1] == args.hidden_size, "gemm2_weights_scale has incorrect dim 1.");
@@ -405,7 +454,7 @@ class FP4BlockScaleMoeRunner : public torch::CustomClassHolder
 public:
     explicit FP4BlockScaleMoeRunner()
         // Update this as new cubins come in
-        : mSupportedTileN{8, 16, 32, 64, 128, 256}
+        : mSupportedTileN{8, 16, 32, 64, 128}
     {
         for (int tileN : mSupportedTileN)
         {
@@ -438,14 +487,17 @@ class FP4BlockScaleMoeRunner : public torch::CustomClassHolder
     [[nodiscard]] std::vector<torch::Tensor> run(torch::optional<torch::Tensor> const& routing_logits,
         torch::optional<torch::Tensor> const& routing_bias, torch::Tensor const& hidden_states,
         torch::Tensor const& hidden_states_scale, torch::Tensor const& gemm1_weights,
-        torch::Tensor const& gemm1_weights_scale, torch::Tensor const& gemm2_weights,
-        torch::Tensor const& gemm2_weights_scale, torch::Tensor const& output1_scales_scalar,
-        torch::Tensor const& output1_scales_gate_scalar, torch::Tensor const& output2_scales_scalar,
-        int64_t const num_experts, int64_t const top_k, std::optional<int64_t> const n_group,
-        std::optional<int64_t> const topk_group, int64_t const intermediate_size, int64_t const local_expert_offset,
-        int64_t const local_num_experts, std::optional<double> const routed_scaling_factor,
-        int64_t const routing_method_type, bool const do_finalize, std::vector<int64_t> moeConfigIndex,
-        torch::optional<torch::Tensor> const& topk_weights, torch::optional<torch::Tensor> const& topk_ids)
+        torch::Tensor const& gemm1_weights_scale, std::optional<torch::Tensor> const& gemm1_bias,
+        std::optional<torch::Tensor> const& gemm1_alpha, std::optional<torch::Tensor> const& gemm1_beta,
+        std::optional<torch::Tensor> const& gemm1_clamp_limit, torch::Tensor const& gemm2_weights,
+        torch::Tensor const& gemm2_weights_scale, std::optional<torch::Tensor> const& gemm2_bias,
+        torch::Tensor const& output1_scales_scalar, torch::Tensor const& output1_scales_gate_scalar,
+        torch::Tensor const& output2_scales_scalar, int64_t const num_experts, int64_t const top_k,
+        std::optional<int64_t> const n_group, std::optional<int64_t> const topk_group, int64_t const intermediate_size,
+        int64_t const local_expert_offset, int64_t const local_num_experts,
+        std::optional<double> const routed_scaling_factor, int64_t const routing_method_type, bool const do_finalize,
+        std::vector<int64_t> moeConfigIndex, torch::optional<torch::Tensor> const& topk_weights,
+        torch::optional<torch::Tensor> const& topk_ids)
     {
         // moeConfigIndex corresponds to pair (tileN, config)
        auto [tileN, config] = std::tie(moeConfigIndex[0], moeConfigIndex[1]);
@@ -466,10 +518,11 @@ class FP4BlockScaleMoeRunner : public torch::CustomClassHolder
         }
 
         return run_fp4_block_scale_moe_runner(routing_logits, routing_bias, hidden_states, hidden_states_scale,
-            gemm1_weights, gemm1_weights_scale, gemm2_weights, gemm2_weights_scale, output1_scales_scalar,
-            output1_scales_gate_scalar, output2_scales_scalar, num_experts, top_k, n_group, topk_group,
-            intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, tileN,
-            routing_method_type, do_finalize, mDtypeElt, *mRunners[tileN], config, topk_weights, topk_ids);
+            gemm1_weights, gemm1_weights_scale, gemm1_bias, gemm1_alpha, gemm1_beta, gemm1_clamp_limit, gemm2_weights,
+            gemm2_weights_scale, gemm2_bias, output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar,
+            num_experts, top_k, n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts,
+            routed_scaling_factor, tileN, routing_method_type, do_finalize, mDtypeElt, *mRunners[tileN], config,
+            topk_weights, topk_ids);
     }
 
 private:
@@ -551,11 +604,11 @@ class FP8FP4BlockScaleMoeRunner : public torch::CustomClassHolder
         }
 
         return run_fp4_block_scale_moe_runner(routing_logits, routing_bias, hidden_states,
-            std::nullopt /*hidden_states_scale*/, gemm1_weights, gemm1_weights_scale, gemm2_weights,
-            gemm2_weights_scale, output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar, num_experts,
-            top_k, n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts,
-            routed_scaling_factor, tileN, routing_method_type, do_finalize, mDtypeAct, *mRunners[tileN], config,
-            topk_weights, topk_ids);
+            std::nullopt /*hidden_states_scale*/, gemm1_weights, gemm1_weights_scale, std::nullopt, std::nullopt,
+            std::nullopt, std::nullopt, gemm2_weights, gemm2_weights_scale, std::nullopt, output1_scales_scalar,
+            output1_scales_gate_scalar, output2_scales_scalar, num_experts, top_k, n_group, topk_group,
+            intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, tileN,
+            routing_method_type, do_finalize, mDtypeAct, *mRunners[tileN], config, topk_weights, topk_ids);
     }
 
 private:
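
The five new arguments are optional per-expert tensors: gemm1_bias, gemm1_alpha, gemm1_beta, and gemm1_clamp_limit feed the first GEMM and (per the commit messages) its SwiGLU activation, and gemm2_bias feeds the second GEMM; when any of them is absent, the runner wires a nullptr into the kernel arguments. A minimal sketch of tensors that would satisfy the shape and dtype checks above, with illustrative sizes only (the expert count and dimensions below are assumptions, not values from the commit):

    import torch

    # Illustrative sizes only; real values come from the model configuration.
    local_num_experts = 32
    intermediate_size = 2880
    hidden_size = 2880

    # All five optional tensors must be float32 and sized per local expert,
    # matching the TORCH_CHECKs in run_fp4_block_scale_moe_runner.
    gemm1_bias = torch.zeros(local_num_experts, 2 * intermediate_size)  # 2D: [experts, 2 * intermediate_size]
    gemm1_alpha = torch.ones(local_num_experts)                         # 1D: one value per expert
    gemm1_beta = torch.zeros(local_num_experts)                         # 1D: one value per expert
    gemm1_clamp_limit = torch.full((local_num_experts,), 7.0)           # 1D: clamp limit per expert (7.0 is arbitrary)
    gemm2_bias = torch.zeros(local_num_experts, hidden_size)            # 2D: [experts, hidden_size]

Passing std::nullopt (None on the Python side) for any of them keeps the previous behavior, which is exactly what the FP8FP4BlockScaleMoeRunner path above does.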

tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py

Lines changed: 34 additions & 8 deletions
@@ -157,8 +157,13 @@ class FP4BlockScaleMoEInputs:
     hidden_states_scale: torch.Tensor
     gemm1_weights: torch.Tensor
     gemm1_weights_scale: torch.Tensor
+    gemm1_bias: torch.Tensor
+    gemm1_alpha: torch.Tensor
+    gemm1_beta: torch.Tensor
+    gemm1_clamp_limit: torch.Tensor
     gemm2_weights: torch.Tensor
     gemm2_weights_scale: torch.Tensor
+    gemm2_bias: torch.Tensor
     output1_scale_scalar: torch.Tensor
     output1_scale_gate_scalar: torch.Tensor
     output2_scale_scalar: torch.Tensor
@@ -230,14 +235,15 @@ def forward(
         return kernel_runner.run_moe(
             args.routing_logits, args.routing_bias, args.hidden_states,
             args.hidden_states_scale, args.gemm1_weights,
-            args.gemm1_weights_scale, args.gemm2_weights,
-            args.gemm2_weights_scale, args.output1_scale_scalar,
-            args.output1_scale_gate_scalar, args.output2_scale_scalar,
-            self.num_experts, self.top_k, self.n_group, self.topk_group,
-            self.intermediate_size, self.local_expert_offset,
-            self.local_num_experts, self.routed_scaling_factor,
-            self.routing_method_type, self.do_finalize, tactic,
-            args.topk_weights, args.topk_ids)
+            args.gemm1_weights_scale, args.gemm1_bias, args.gemm1_alpha,
+            args.gemm1_beta, args.gemm1_clamp_limit, args.gemm2_weights,
+            args.gemm2_weights_scale, args.gemm2_bias,
+            args.output1_scale_scalar, args.output1_scale_gate_scalar,
+            args.output2_scale_scalar, self.num_experts, self.top_k,
+            self.n_group, self.topk_group, self.intermediate_size,
+            self.local_expert_offset, self.local_num_experts,
+            self.routed_scaling_factor, self.routing_method_type,
+            self.do_finalize, tactic, args.topk_weights, args.topk_ids)
 
     def get_valid_tactics(self, inputs: List[torch.Tensor],
                           profile: OptimizationProfile,
@@ -354,8 +360,13 @@ def fp4_block_scale_moe_runner(
         hidden_states_scale: torch.Tensor,
         gemm1_weights: torch.Tensor,
         gemm1_weights_scale: torch.Tensor,
+        gemm1_bias: torch.Tensor,
+        gemm1_alpha: torch.Tensor,
+        gemm1_beta: torch.Tensor,
+        gemm1_clamp_limit: torch.Tensor,
         gemm2_weights: torch.Tensor,
         gemm2_weights_scale: torch.Tensor,
+        gemm2_bias: torch.Tensor,
         output1_scale_scalar: torch.Tensor,
         output1_scale_gate_scalar: torch.Tensor,
         output2_scale_scalar: torch.Tensor,
@@ -408,8 +419,13 @@ def fp4_block_scale_moe_runner(
         hidden_states_scale,
         gemm1_weights,
         gemm1_weights_scale,
+        gemm1_bias,
+        gemm1_alpha,
+        gemm1_beta,
+        gemm1_clamp_limit,
         gemm2_weights,
         gemm2_weights_scale,
+        gemm2_bias,
         output1_scale_scalar,
         output1_scale_gate_scalar,
         output2_scale_scalar,
@@ -433,8 +449,13 @@ def fp4_block_scale_moe_runner(
         hidden_states_scale,
         gemm1_weights,
         gemm1_weights_scale,
+        gemm1_bias,
+        gemm1_alpha,
+        gemm1_beta,
+        gemm1_clamp_limit,
         gemm2_weights,
         gemm2_weights_scale,
+        gemm2_bias,
         output1_scale_scalar,
         output1_scale_gate_scalar,
         output2_scale_scalar,
@@ -479,8 +500,13 @@ def _(routing_logits,
       hidden_states_scale,
       gemm1_weights,
       gemm1_weights_scale,
+      gemm1_bias,
+      gemm1_alpha,
+      gemm1_beta,
+      gemm1_clamp_limit,
      gemm2_weights,
      gemm2_weights_scale,
+      gemm2_bias,
      output1_scale_scalar,
      output1_scale_gate_scalar,
      output2_scale_scalar,
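
Every call path in this wrapper threads the five new tensors through in the same positions, immediately after gemm1_weights_scale and gemm2_weights_scale. As a quick cross-check of that ordering, here is a local, illustration-only mirror of the fields visible in the FP4BlockScaleMoEInputs hunk above (it is not the library class, and fields preceding hidden_states_scale are omitted because the diff does not show them):

    from dataclasses import dataclass, fields

    import torch

    @dataclass
    class Fp4MoeInputsMirror:
        # Field order as it appears in the diff; "new" marks fields added by this commit.
        hidden_states_scale: torch.Tensor
        gemm1_weights: torch.Tensor
        gemm1_weights_scale: torch.Tensor
        gemm1_bias: torch.Tensor         # new
        gemm1_alpha: torch.Tensor        # new
        gemm1_beta: torch.Tensor         # new
        gemm1_clamp_limit: torch.Tensor  # new
        gemm2_weights: torch.Tensor
        gemm2_weights_scale: torch.Tensor
        gemm2_bias: torch.Tensor         # new
        output1_scale_scalar: torch.Tensor
        output1_scale_gate_scalar: torch.Tensor
        output2_scale_scalar: torch.Tensor

    # Printing the field names gives the relative order these inputs keep in
    # the dataclass, in run_moe, and in the fake-tensor registration.
    print([f.name for f in fields(Fp4MoeInputsMirror)])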

tensorrt_llm/_torch/model_config.py

Lines changed: 13 additions & 0 deletions
@@ -285,6 +285,19 @@ def load_modelopt_quant_config(quant_config_file, checkpoint_dir,
             quant_config.exclude_modules = [
                 "*kv_b_proj*", "*k_b_proj*", "*eh_proj"
             ]
+
+        # --- NVFP4 GPT-OSS WAR: only MoE is NVFP4 ---
+        if quant_config.quant_algo == "NVFP4":
+            # Exclude all non-MoE linears; adjust patterns as needed
+            quant_config.exclude_modules = [
+                'block.*.attn.qkv',
+                'block.*.attn.out',
+                'block.*.mlp.gate',
+                'embedding',
+                'unembedding',
+            ]
+        # --------------------------------------------
+
         return quant_config, layer_quant_config
 
     @staticmethod
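
The model_config WAR above keeps NVFP4 quantization on the MoE experts only by excluding the attention projections, the MoE router gate, and the (un)embedding from quantization. A small sketch of how glob-style exclude patterns like these can be matched against module names; the fnmatch-based matching and the sample module names are assumptions for illustration, not the actual exclusion logic in TensorRT-LLM:

    from fnmatch import fnmatch

    # Patterns taken from the WAR above.
    exclude_modules = [
        'block.*.attn.qkv',
        'block.*.attn.out',
        'block.*.mlp.gate',
        'embedding',
        'unembedding',
    ]

    def is_excluded(module_name: str) -> bool:
        # A module stays unquantized if any exclude pattern matches its name.
        return any(fnmatch(module_name, pattern) for pattern in exclude_modules)

    # Hypothetical module names following the naming used by the patterns.
    for name in ['block.0.attn.qkv', 'block.0.mlp.experts', 'embedding']:
        print(name, '->', 'excluded' if is_excluded(name) else 'NVFP4-quantized')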
