
Commit 8acda3c

Fix bugs
Signed-off-by: Dongfeng Yu <[email protected]>
1 parent a4d029f commit 8acda3c

File tree: 3 files changed, +81 / -9 lines changed

cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp

Lines changed: 73 additions & 0 deletions
@@ -21,6 +21,7 @@
 #include "tensorrt_llm/thop/thUtils.h"
 #include <ATen/cuda/EmptyTensor.h>
 #include <ATen/ops/index_select.h>
+#include <iostream>

 namespace torch_ext
 {
@@ -44,6 +45,78 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch:
     bool const do_finalize, btg::Dtype const dtype, MoeRunnerType& moe_runner, int64_t const moeConfigIndex,
     torch::optional<torch::Tensor> const& topk_weights, torch::optional<torch::Tensor> const& topk_ids)
 {
+    std::cout << "Function: run_fp4_block_scale_moe_runner" << std::endl;
+
+    auto print_tensor = [](std::string name, torch::Tensor const& t)
+    {
+        std::cout << name << ": shape=[";
+        for (auto s : t.sizes())
+        {
+            std::cout << s << ",";
+        }
+        std::cout << "], dtype=" << t.scalar_type() << std::endl;
+    };
+
+    auto print_opt_tensor = [&](std::string name, auto const& t)
+    {
+        if (t.has_value())
+        {
+            print_tensor(name, t.value());
+        }
+        else
+        {
+            std::cout << name << ": None" << std::endl;
+        }
+    };
+
+    auto print_val = [](std::string name, auto const& v) { std::cout << name << ": " << v << std::endl; };
+
+    auto print_opt_val = [&](std::string name, auto const& v)
+    {
+        if (v.has_value())
+        {
+            std::cout << name << ": " << v.value() << std::endl;
+        }
+        else
+        {
+            std::cout << name << ": None" << std::endl;
+        }
+    };
+
+    print_opt_tensor("routing_logits", routing_logits);
+    print_opt_tensor("routing_bias", routing_bias);
+    print_tensor("hidden_states", hidden_states);
+    print_opt_tensor("hidden_states_scale", hidden_states_scale);
+    print_tensor("gemm1_weights", gemm1_weights);
+    print_tensor("gemm1_weights_scale", gemm1_weights_scale);
+    print_opt_tensor("gemm1_bias", gemm1_bias);
+    print_opt_tensor("gemm1_alpha", gemm1_alpha);
+    print_opt_tensor("gemm1_beta", gemm1_beta);
+    print_opt_tensor("gemm1_clamp_limit", gemm1_clamp_limit);
+    print_tensor("gemm2_weights", gemm2_weights);
+    print_tensor("gemm2_weights_scale", gemm2_weights_scale);
+    print_opt_tensor("gemm2_bias", gemm2_bias);
+    print_tensor("output1_scales_scalar", output1_scales_scalar);
+    print_tensor("output1_scales_gate_scalar", output1_scales_gate_scalar);
+    print_tensor("output2_scales_scalar", output2_scales_scalar);
+
+    print_val("num_experts", num_experts);
+    print_val("top_k", top_k);
+    print_opt_val("n_group", n_group);
+    print_opt_val("topk_group", topk_group);
+    print_val("intermediate_size", intermediate_size);
+    print_val("local_expert_offset", local_expert_offset);
+    print_val("local_num_experts", local_num_experts);
+    print_opt_val("routed_scaling_factor", routed_scaling_factor);
+    print_val("tile_tokens_dim", tile_tokens_dim);
+    print_val("routing_method_type", routing_method_type);
+    print_val("do_finalize", do_finalize);
+    print_val("dtype", static_cast<int>(dtype));
+    print_val("moeConfigIndex", moeConfigIndex);
+    print_opt_tensor("topk_weights", topk_weights);
+    print_opt_tensor("topk_ids", topk_ids);
+    std::cout << "--------------------------------" << std::endl;
     TORCH_CHECK(dtype == btg::Dtype::E4m3 || dtype == btg::Dtype::E2m1, "dtype can only be e4m3 or e2m1.");
     TORCH_CHECK(tensorrt_llm::common::isSM100Family(), "Only SM100f is supported by FP4 block scale MOE");
     TORCH_CHECK(tile_tokens_dim == 8 || tile_tokens_dim == 16 || tile_tokens_dim == 32 || tile_tokens_dim == 64

tensorrt_llm/_torch/modules/fused_moe/quantization.py

Lines changed: 3 additions & 4 deletions
@@ -201,12 +201,11 @@ def create_weights(

         # bias
         if module.bias:
+            # The shape might be padded so we use weight shape[:2]
             if w3_w1_bias_shape is None:
-                w3_w1_bias_shape = (module.expert_size_per_partition,
-                                    module.intermediate_size_per_partition * 2)
+                w3_w1_bias_shape = w3_w1_weight_shape[:2]
             if w2_bias_shape is None:
-                w2_bias_shape = (module.expert_size_per_partition,
-                                 module.hidden_size)
+                w2_bias_shape = w2_weight_shape[:2]
             bias_dtype = bias_dtype or module.dtype
             w3_w1_bias = nn.Parameter(torch.empty(w3_w1_bias_shape,
                                                   dtype=bias_dtype),
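
For context on the quantization.py change (not part of the commit): the bias shapes now come from the first two dimensions of the weight shapes, so any padding applied to the weights upstream is reflected in the biases automatically instead of the shapes being rebuilt from module attributes. A minimal standalone sketch follows, with purely illustrative dimensions that are assumptions for this sketch only:

import torch

# Illustrative per-partition sizes; not taken from the repository.
expert_size_per_partition = 8
intermediate_size_per_partition = 1536   # may already include padding
hidden_size = 4096

# Weight shapes as create_weights would see them; the trailing dimension is
# the packed/quantized input dim and does not matter for the bias.
w3_w1_weight_shape = (expert_size_per_partition,
                      intermediate_size_per_partition * 2,
                      hidden_size // 2)
w2_weight_shape = (expert_size_per_partition,
                   hidden_size,
                   intermediate_size_per_partition // 2)

# As in the commit, each bias shape simply follows its weight shape[:2].
w3_w1_bias = torch.empty(w3_w1_weight_shape[:2], dtype=torch.bfloat16)
w2_bias = torch.empty(w2_weight_shape[:2], dtype=torch.bfloat16)
print(w3_w1_bias.shape)  # torch.Size([8, 3072])
print(w2_bias.shape)     # torch.Size([8, 4096])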

tests/unittest/_torch/modules/test_fused_moe.py

Lines changed: 5 additions & 5 deletions
@@ -1381,11 +1381,11 @@ def test_fused_moe_nvfp4(dtype, moe_backend, hidden_size, intermediate_size):
     with torch.device(f"cuda:{mapping.rank}"):
         SCALING_VECTOR_SIZE = 16

-        SEQ_LEN = 4
+        SEQ_LEN = 1024
         HIDDEN_SIZE = hidden_size
         INTERMEDIATE_SIZE = intermediate_size
-        NUM_EXPERTS = 4
-        TOP_K = 2
+        NUM_EXPERTS = 32
+        TOP_K = 4
         routing_method = RenormalizeMoeRoutingMethod(top_k=TOP_K)
         torch.manual_seed(0)
         torch.cuda.manual_seed(0)
@@ -1487,7 +1487,7 @@ def test_fused_moe_nvfp4(dtype, moe_backend, hidden_size, intermediate_size):
         fused_moe.forward(x, router_logits)

         output = fused_moe.forward(x, router_logits)
-        torch.testing.assert_close(output, ref_output, rtol=0.1, atol=0.4)
+        torch.testing.assert_close(output, ref_output, rtol=0.1, atol=0.5)

         if not test_all_kernels:
             return
@@ -1503,7 +1503,7 @@ def test_fused_moe_nvfp4(dtype, moe_backend, hidden_size, intermediate_size):
         torch.testing.assert_close(output,
                                    ref_output,
                                    rtol=0.1,
-                                   atol=0.4)
+                                   atol=0.5)


 @skip_pre_blackwell
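
A quick note on the loosened tolerance (a sketch, not part of the commit): torch.testing.assert_close allows a per-element mismatch of up to atol + rtol * |expected|, so raising atol from 0.4 to 0.5 grants an extra 0.1 of absolute headroom, which matters most where the reference value is near zero. A small self-contained illustration with made-up numbers:

import torch

# Hypothetical values chosen to land between the old and new tolerances.
expected = torch.tensor([1.0, -2.0, 0.0])
actual = expected + 0.45          # 0.45 absolute error per element

# New tolerance: every element satisfies |diff| <= 0.5 + 0.1 * |expected|.
torch.testing.assert_close(actual, expected, rtol=0.1, atol=0.5)

# Old tolerance: the element with expected == 0.0 only gets atol headroom
# (0.4 < 0.45), so the check fails there.
try:
    torch.testing.assert_close(actual, expected, rtol=0.1, atol=0.4)
except AssertionError:
    print("old tolerance (atol=0.4) fails for the near-zero element")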
