Commit 0887e98

Add bias 2
Signed-off-by: Dongfeng Yu <dongfengy@nvidia.com>
1 parent 7016197 commit 0887e98

File tree

4 files changed, +92 -27 lines changed

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.cpp

Lines changed: 8 additions & 0 deletions
@@ -211,6 +211,8 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto
     int32_t const* ctaIdxXyToMnLimit, int32_t const* numNonExitingCtas, void* workspace, CUstream stream, int device,
     int32_t configIndex)
 {
+    std::cout << "run 1" << std::endl;
+    std::cout << ptrBias << std::endl;
     auto bmm = BatchedGemmInterface();

     BatchedGemmData gemmData;
@@ -305,6 +307,8 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto
     void const* a, void const* sfA, void const* b, void const* sfB, void* c, void* outSfC, void* workspace,
     CUstream stream, int device, int32_t configIndex)
 {
+    std::cout << "run 2" << std::endl;
+    std::cout << "no bias" << std::endl;
     // Dispatch with block scaling factors and with static batching.
     run(m, n, k, batchedTokens, /* numTokens */ 0, batchedTokens.size(), /* maxNumCtasInBatchDim */ 0, a, sfA, b, sfB,
         /* perTokensSfA */ nullptr, /* perTokensSfB */ nullptr,
@@ -320,6 +324,8 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto
     float const* ptrBeta, float const* ptrClampLimit, void* c, void* outSfC, void* workspace, CUstream stream,
     int device, int32_t configIndex)
 {
+    std::cout << "run 3" << std::endl;
+    std::cout << ptrBias << std::endl;
     // Dispatch with block scaling factors and with static batching.
     run(m, n, k, batchedTokens, /* numTokens */ 0, batchedTokens.size(), /* maxNumCtasInBatchDim */ 0, a, sfA, b, sfB,
         /* perTokensSfA */ nullptr, /* perTokensSfB */ nullptr,
@@ -333,6 +339,8 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto
     void const* a, void const* b, float const* scaleC, float const* scaleGateC, void* c, void* workspace,
     CUstream stream, int device, int32_t configIndex)
 {
+    std::cout << "run 4" << std::endl;
+    std::cout << "no bias" << std::endl;
     // Dispatch with block scaling factors and with static batching.
     run(m, n, k, batchedTokens, /* numTokens */ 0, batchedTokens.size(), /* maxNumCtasInBatchDim */ 0, a,
         /* sfA */ nullptr, b, /* sfB */ nullptr, /* perTokensSfA */ nullptr, /* perTokensSfB */ nullptr, scaleC,

cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp

Lines changed: 52 additions & 18 deletions
@@ -34,14 +34,14 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch:
     torch::Tensor const& gemm1_weights_scale, std::optional<torch::Tensor> const& gemm1_bias,
     std::optional<torch::Tensor> const& gemm1_alpha, std::optional<torch::Tensor> const& gemm1_beta,
     std::optional<torch::Tensor> const& gemm1_clamp_limit, torch::Tensor const& gemm2_weights,
-    torch::Tensor const& gemm2_weights_scale, torch::Tensor const& output1_scales_scalar,
-    torch::Tensor const& output1_scales_gate_scalar, torch::Tensor const& output2_scales_scalar,
-    int64_t const num_experts, int64_t const top_k, std::optional<int64_t> const n_group,
-    std::optional<int64_t> const topk_group, int64_t const intermediate_size, int64_t const local_expert_offset,
-    int64_t const local_num_experts, std::optional<double> const routed_scaling_factor, int64_t const tile_tokens_dim,
-    int64_t const routing_method_type, bool const do_finalize, btg::Dtype const dtype, MoeRunnerType& moe_runner,
-    int64_t const moeConfigIndex, torch::optional<torch::Tensor> const& topk_weights,
-    torch::optional<torch::Tensor> const& topk_ids)
+    torch::Tensor const& gemm2_weights_scale, std::optional<torch::Tensor> const& gemm2_bias,
+    torch::Tensor const& output1_scales_scalar, torch::Tensor const& output1_scales_gate_scalar,
+    torch::Tensor const& output2_scales_scalar, int64_t const num_experts, int64_t const top_k,
+    std::optional<int64_t> const n_group, std::optional<int64_t> const topk_group, int64_t const intermediate_size,
+    int64_t const local_expert_offset, int64_t const local_num_experts,
+    std::optional<double> const routed_scaling_factor, int64_t const tile_tokens_dim, int64_t const routing_method_type,
+    bool const do_finalize, btg::Dtype const dtype, MoeRunnerType& moe_runner, int64_t const moeConfigIndex,
+    torch::optional<torch::Tensor> const& topk_weights, torch::optional<torch::Tensor> const& topk_ids)
 {
     TORCH_CHECK(dtype == btg::Dtype::E4m3 || dtype == btg::Dtype::E2m1, "dtype can only be e4m3 or e2m1.");
     TORCH_CHECK(tensorrt_llm::common::isSM100Family(), "Only SM100f is supported by FP4 block scale MOE");
@@ -166,6 +166,7 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch:
     args.gemm1_clamp_limit = gemm1_clamp_limit.has_value() ? gemm1_clamp_limit.value().data_ptr<float>() : nullptr;
     args.gemm2_weights = gemm2_weights.data_ptr();
     args.gemm2_weights_scale = gemm2_weights_scale.data_ptr();
+    args.gemm2_bias = gemm2_bias.has_value() ? gemm2_bias.value().data_ptr<float>() : nullptr;
     args.num_tokens = hidden_states.sizes()[0];
     args.num_experts = num_experts;
     if (dtype == btg::Dtype::E4m3)
@@ -357,6 +358,15 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch:

     TORCH_CHECK(gemm2_weights_scale.scalar_type() == at::ScalarType::Float8_e4m3fn, "gemm2_weights_scale must be fp8.");

+    if (gemm2_bias.has_value())
+    {
+        TORCH_CHECK(gemm2_bias.value().scalar_type() == at::ScalarType::Float, "gemm2_bias must be float, got %s.",
+            c10::toString(gemm2_bias.value().scalar_type()));
+        TORCH_CHECK(gemm2_bias.value().dim() == 2, "gemm2_bias must be 2D.");
+        TORCH_CHECK(gemm2_bias.value().sizes()[0] == local_num_experts, "gemm2_bias has incorrect dim 0.");
+        TORCH_CHECK(gemm2_bias.value().sizes()[1] == args.hidden_size, "gemm2_bias has incorrect dim 1.");
+    }
+
     TORCH_CHECK(gemm2_weights_scale.dim() == 3, "gemm2_weights_scale must be 3D.");
     TORCH_CHECK(gemm2_weights_scale.sizes()[0] == local_num_experts, "gemm2_weights_scale has incorrect dim 0.");
     TORCH_CHECK(gemm2_weights_scale.sizes()[1] == args.hidden_size, "gemm2_weights_scale has incorrect dim 1.");
@@ -461,6 +471,29 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch:
         }
         std::cout << std::endl;
     }
+    std::vector<float> gemm1_bias_vals;
+    if (gemm1_bias.has_value())
+    {
+        auto bias_cpu = gemm1_bias.value().cpu().contiguous();
+        float* bias_ptr = bias_cpu.data_ptr<float>();
+        std::cout << "[FP4BlockScaleMoe] gemm1 bias: ";
+        for (int i = 0; i < std::min(local_num_experts * intermediate_size * 2, int64_t(30)); ++i)
+        {
+            std::cout << bias_ptr[i] << " ";
+        }
+        std::cout << std::endl;
+    }
+    if (gemm2_bias.has_value())
+    {
+        auto bias_cpu = gemm2_bias.value().cpu().contiguous();
+        float* bias_ptr = bias_cpu.data_ptr<float>();
+        std::cout << "[FP4BlockScaleMoe] gemm2 bias: ";
+        for (int i = 0; i < std::min(local_num_experts * args.hidden_size, int64_t(30)); ++i)
+        {
+            std::cout << bias_ptr[i] << " ";
+        }
+        std::cout << std::endl;
+    }

     moe_runner.run(args, workspace, hidden_states.get_device(), moe_stream, moeConfigIndex);

@@ -510,13 +543,14 @@ class FP4BlockScaleMoeRunner : public torch::CustomClassHolder
         torch::Tensor const& gemm1_weights_scale, std::optional<torch::Tensor> const& gemm1_bias,
         std::optional<torch::Tensor> const& gemm1_alpha, std::optional<torch::Tensor> const& gemm1_beta,
         std::optional<torch::Tensor> const& gemm1_clamp_limit, torch::Tensor const& gemm2_weights,
-        torch::Tensor const& gemm2_weights_scale, torch::Tensor const& output1_scales_scalar,
-        torch::Tensor const& output1_scales_gate_scalar, torch::Tensor const& output2_scales_scalar,
-        int64_t const num_experts, int64_t const top_k, std::optional<int64_t> const n_group,
-        std::optional<int64_t> const topk_group, int64_t const intermediate_size, int64_t const local_expert_offset,
-        int64_t const local_num_experts, std::optional<double> const routed_scaling_factor,
-        int64_t const routing_method_type, bool const do_finalize, std::vector<int64_t> moeConfigIndex,
-        torch::optional<torch::Tensor> const& topk_weights, torch::optional<torch::Tensor> const& topk_ids)
+        torch::Tensor const& gemm2_weights_scale, std::optional<torch::Tensor> const& gemm2_bias,
+        torch::Tensor const& output1_scales_scalar, torch::Tensor const& output1_scales_gate_scalar,
+        torch::Tensor const& output2_scales_scalar, int64_t const num_experts, int64_t const top_k,
+        std::optional<int64_t> const n_group, std::optional<int64_t> const topk_group, int64_t const intermediate_size,
+        int64_t const local_expert_offset, int64_t const local_num_experts,
+        std::optional<double> const routed_scaling_factor, int64_t const routing_method_type, bool const do_finalize,
+        std::vector<int64_t> moeConfigIndex, torch::optional<torch::Tensor> const& topk_weights,
+        torch::optional<torch::Tensor> const& topk_ids)
     {
         // moeConfigIndex corresponds to pair (tileN, config)
         auto [tileN, config] = std::tie(moeConfigIndex[0], moeConfigIndex[1]);
@@ -538,8 +572,8 @@ class FP4BlockScaleMoeRunner : public torch::CustomClassHolder

         return run_fp4_block_scale_moe_runner(routing_logits, routing_bias, hidden_states, hidden_states_scale,
             gemm1_weights, gemm1_weights_scale, gemm1_bias, gemm1_alpha, gemm1_beta, gemm1_clamp_limit, gemm2_weights,
-            gemm2_weights_scale, output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar, num_experts,
-            top_k, n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts,
+            gemm2_weights_scale, gemm2_bias, output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar,
+            num_experts, top_k, n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts,
             routed_scaling_factor, tileN, routing_method_type, do_finalize, mDtypeElt, *mRunners[tileN], config,
             topk_weights, topk_ids);
     }
@@ -619,7 +653,7 @@ class FP8FP4BlockScaleMoeRunner : public torch::CustomClassHolder

         return run_fp4_block_scale_moe_runner(routing_logits, routing_bias, hidden_states,
             std::nullopt /*hidden_states_scale*/, gemm1_weights, gemm1_weights_scale, std::nullopt, std::nullopt,
-            std::nullopt, std::nullopt, gemm2_weights, gemm2_weights_scale, output1_scales_scalar,
+            std::nullopt, std::nullopt, gemm2_weights, gemm2_weights_scale, std::nullopt, output1_scales_scalar,
             output1_scales_gate_scalar, output2_scales_scalar, num_experts, top_k, n_group, topk_group,
             intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, tileN,
             routing_method_type, do_finalize, mDtypeAct, *mRunners[tileN], config, topk_weights, topk_ids);
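
Note: the new checks amount to a simple host-side contract on gemm2_bias: it is optional, and when present it must be a float32 tensor of shape [local_num_experts, hidden_size]. A minimal sketch of that contract from the caller's side (the sizes below are made-up examples, not values from this commit):

import torch

# Hypothetical example sizes, only for illustration.
local_num_experts, hidden_size = 8, 1024

gemm2_bias = torch.zeros(local_num_experts, hidden_size, dtype=torch.float)

assert gemm2_bias.dtype == torch.float32        # "gemm2_bias must be float"
assert gemm2_bias.dim() == 2                    # "gemm2_bias must be 2D"
assert gemm2_bias.size(0) == local_num_experts  # dim 0 check
assert gemm2_bias.size(1) == hidden_size        # dim 1 check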

tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py

Lines changed: 11 additions & 7 deletions
@@ -54,6 +54,7 @@ class FP4BlockScaleMoEInputs:
     gemm1_clamp_limit: torch.Tensor
     gemm2_weights: torch.Tensor
     gemm2_weights_scale: torch.Tensor
+    gemm2_bias: torch.Tensor
     output1_scale_scalar: torch.Tensor
     output1_scale_gate_scalar: torch.Tensor
     output2_scale_scalar: torch.Tensor
@@ -127,13 +128,13 @@ def forward(
            args.hidden_states_scale, args.gemm1_weights,
            args.gemm1_weights_scale, args.gemm1_bias, args.gemm1_alpha,
            args.gemm1_beta, args.gemm1_clamp_limit, args.gemm2_weights,
-            args.gemm2_weights_scale, args.output1_scale_scalar,
-            args.output1_scale_gate_scalar, args.output2_scale_scalar,
-            self.num_experts, self.top_k, self.n_group, self.topk_group,
-            self.intermediate_size, self.local_expert_offset,
-            self.local_num_experts, self.routed_scaling_factor,
-            self.routing_method_type, self.do_finalize, tactic,
-            args.topk_weights, args.topk_ids)
+            args.gemm2_weights_scale, args.gemm2_bias,
+            args.output1_scale_scalar, args.output1_scale_gate_scalar,
+            args.output2_scale_scalar, self.num_experts, self.top_k,
+            self.n_group, self.topk_group, self.intermediate_size,
+            self.local_expert_offset, self.local_num_experts,
+            self.routed_scaling_factor, self.routing_method_type,
+            self.do_finalize, tactic, args.topk_weights, args.topk_ids)

    def get_valid_tactics(self, inputs: List[torch.Tensor],
                          profile: OptimizationProfile,
@@ -247,6 +248,7 @@ def fp4_block_scale_moe_runner(
    gemm1_clamp_limit: torch.Tensor,
    gemm2_weights: torch.Tensor,
    gemm2_weights_scale: torch.Tensor,
+    gemm2_bias: torch.Tensor,
    output1_scale_scalar: torch.Tensor,
    output1_scale_gate_scalar: torch.Tensor,
    output2_scale_scalar: torch.Tensor,
@@ -299,6 +301,7 @@ def fp4_block_scale_moe_runner(
        gemm1_clamp_limit,
        gemm2_weights,
        gemm2_weights_scale,
+        gemm2_bias,
        output1_scale_scalar,
        output1_scale_gate_scalar,
        output2_scale_scalar,
@@ -357,6 +360,7 @@ def _(routing_logits,
      gemm1_clamp_limit,
      gemm2_weights,
      gemm2_weights_scale,
+      gemm2_bias,
      output1_scale_scalar,
      output1_scale_gate_scalar,
      output2_scale_scalar,
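
The Python-side change follows the usual pattern for threading a new tensor through a torch custom op: the real op call, the fake (shape-only) registration, and the inputs dataclass must all agree on where gemm2_bias sits in the argument list. Below is a minimal, self-contained sketch of that pattern using a toy op (PyTorch 2.4+ registration style; the op name and shapes are assumptions, not TensorRT-LLM code):

import torch

# Toy custom op: real and fake implementations share the same signature,
# including the bias argument, just as the updated runner and its fake do.
@torch.library.custom_op("toy::biased_gemm", mutates_args=())
def biased_gemm(a: torch.Tensor, w: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    return a @ w.t() + bias

@biased_gemm.register_fake
def _(a, w, bias):
    # Shape-only version used for tracing/torch.compile.
    return a.new_empty(a.shape[0], w.shape[0])

out = biased_gemm(torch.randn(4, 8), torch.randn(16, 8), torch.randn(16))
print(out.shape)  # torch.Size([4, 16])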

tests/unittest/_torch/thop/parallel/test_moe.py

Lines changed: 21 additions & 2 deletions
@@ -564,7 +564,7 @@ def run_moe_reference_fp4(args):
        gemm1_alpha=args.gemm1_alpha,
        gemm1_beta=args.gemm1_beta,
        gemm1_clamp_limit=args.gemm1_clamp_limit,
-        gemm2_bias=None)
+        gemm2_bias=args.gemm2_bias)

    return run_moe_dequant(args_dequant, "fp4"), args_dequant

@@ -1557,6 +1557,12 @@ def run_moe_fp4_gptoss_test(self, num_tokens: int, hidden_size: int,
            (num_experts, hidden_size, intermediate_size),
            device='cuda',
            dtype=torch.bfloat16)
+        gemm2_bias = 50 * torch.randn(
+            num_experts, hidden_size, device='cuda', dtype=torch.float)
+
+        # waived due to missing kernel support for bias in nvfp4
+        gemm1_bias[:] = 0
+        gemm2_bias[:] = 0

        use_ue8m0 = False
        # Quantize hidden states. Produces scales for activations in 128x4 layout for ref impl.
@@ -1650,7 +1656,7 @@ def run_moe_fp4_gptoss_test(self, num_tokens: int, hidden_size: int,
            gemm1_alpha=swiglu_alpha_tensor,
            gemm1_beta=swiglu_beta_tensor,
            gemm1_clamp_limit=swiglu_limit_tensor,
-            gemm2_bias=None)
+            gemm2_bias=gemm2_bias)
        #
        # Run the reference implementations
        #
@@ -1691,6 +1697,7 @@ def run_moe_fp4_gptoss_test(self, num_tokens: int, hidden_size: int,
        gemm1_bias_shuffled = []
        gemm2_weights_fp4_shuffled = []
        gemm2_scales_fp4_shuffled = []
+        gemm2_bias_shuffled = []
        for i in range(num_experts):
            gemm1_weights_fp4_shuffled.append(
                shuffle_matrix_a(
@@ -1711,6 +1718,10 @@ def run_moe_fp4_gptoss_test(self, num_tokens: int, hidden_size: int,
                    gemm2_scales_linear_fp4[i].view(torch.uint8),
                    epilogue_tile_m))

+            gemm2_bias_shuffled.append(
+                shuffle_matrix_a(gemm2_bias[i].clone().reshape(-1, 1),
+                                 epilogue_tile_m))
+
        # Stack weights for all experts
        gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
        gemm1_scales_fp4_shuffled = torch.stack(gemm1_scales_fp4_shuffled).view(
@@ -1725,13 +1736,20 @@ def run_moe_fp4_gptoss_test(self, num_tokens: int, hidden_size: int,
        gemm1_bias_shuffled = torch.stack(gemm1_bias_shuffled).reshape(
            num_experts, -1)

+        gemm2_bias_shuffled = torch.stack(gemm2_bias_shuffled).reshape(
+            num_experts, -1)
+
        # NOTE: correct the beta and clamp to account for the global scale factor
        # Check cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/GemmGatedActOptions.h
        # for more details
        swiglu_beta_tensor = swiglu_beta_tensor * args.gemm1_scales_global * args.hidden_states_scale_global
        swiglu_limit_tensor = swiglu_limit_tensor * args.gemm1_scales_global * args.hidden_states_scale_global
+        # Check cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/BatchedGemmInterface.h
+        # for more details
        gemm1_bias_shuffled = gemm1_bias_shuffled * args.gemm1_scales_global[:,
                                                                             None] * args.hidden_states_scale_global
+        gemm2_bias_shuffled = gemm2_bias_shuffled * args_dequant.c_global_sf * args.gemm2_scales_global[:,
+                                                                                                        None]

@@ -1765,6 +1783,7 @@ def run_moe_fp4_gptoss_test(self, num_tokens: int, hidden_size: int,
            swiglu_limit_tensor,
            gemm2_weights_fp4_shuffled,
            gemm2_scales_fp4_shuffled,
+            gemm2_bias_shuffled,
            scale_c_fc1,
            scale_gate_fc1,
            scale_c_fc2,
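
The gemm2_bias_shuffled rescaling mirrors what the test already does for the gemm1 bias and the swiglu beta/limit: the kernel applies the bias in the globally scaled domain, so a bias given in plain float is pre-multiplied by the relevant global scale factors before being handed to the op (my reading of the diff; the tensor values below are dummy stand-ins, not data from the test):

import torch

# Dummy stand-ins for the per-expert bias and global scale factors.
num_experts, hidden_size = 4, 16
gemm2_bias_shuffled = torch.randn(num_experts, hidden_size)
c_global_sf = torch.tensor(0.5)                # plays the role of args_dequant.c_global_sf
gemm2_scales_global = torch.rand(num_experts)  # plays the role of args.gemm2_scales_global

# Same expression as in the diff: broadcast the per-expert scale over hidden_size.
gemm2_bias_scaled = gemm2_bias_shuffled * c_global_sf * gemm2_scales_global[:, None]
print(gemm2_bias_scaled.shape)  # torch.Size([4, 16])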
