Skip to content

Commit fb1009b

Browse files
committed
Fix bias issue
Signed-off-by: Dongfeng Yu <dongfengy@nvidia.com>
1 parent 19d6aa3 commit fb1009b

File tree

5 files changed

+51
-13
lines changed

5 files changed

+51
-13
lines changed

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -253,8 +253,8 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto
253253
gemmData.mProblemDimensions.mK = k;
254254
gemmData.mProblemDimensions.mRank = 0;
255255
gemmData.mProblemDimensions.mWorldSize = 1;
256-
gemmData.mProblemDimensions.mValidM = m;
257-
gemmData.mProblemDimensions.mValidN = n;
256+
gemmData.mProblemDimensions.mValidM = n;
257+
gemmData.mProblemDimensions.mValidN = m;
258258
gemmData.mProblemDimensions.mValidK = k;
259259

260260
// Inputs
@@ -377,8 +377,8 @@ std::vector<int64_t> TrtllmGenBatchedGemmRunner::getValidConfigIndices(int32_t m
377377
gemmData.mProblemDimensions.mRank = 0;
378378
gemmData.mProblemDimensions.mWorldSize = 1;
379379
gemmData.mProblemDimensions.mMaxNumCtasInTokenDim = maxNumCtasInBatchDim;
380-
gemmData.mProblemDimensions.mValidM = m;
381-
gemmData.mProblemDimensions.mValidN = n;
380+
gemmData.mProblemDimensions.mValidM = n;
381+
gemmData.mProblemDimensions.mValidN = m;
382382
gemmData.mProblemDimensions.mValidK = k;
383383
auto cmpFunc = [&configs, &gemmData, &bmm, &multiProcessorCount](int64_t idx0, int64_t idx1)
384384
{
@@ -450,6 +450,7 @@ std::vector<int64_t> TrtllmGenBatchedGemmRunner::getValidConfigIndices(int32_t m
450450
std::vector<int64_t> validConfigIndices;
451451
for (auto const& configIndex : prioritizedIndices)
452452
{
453+
std::cout << "checking config index " << configIndex << std::endl;
453454
auto const& config = configs[configIndex];
454455
auto isValidConfig = bmm.isValidConfig(config, gemmData);
455456
if (isValidConfig)
@@ -493,8 +494,8 @@ bool TrtllmGenBatchedGemmRunner::isValidConfigIndex(int32_t configIndex, int32_t
493494
gemmData.mProblemDimensions.mRank = 0;
494495
gemmData.mProblemDimensions.mWorldSize = 1;
495496
gemmData.mProblemDimensions.mMaxNumCtasInTokenDim = maxNumCtasInBatchDim;
496-
gemmData.mProblemDimensions.mValidM = m;
497-
gemmData.mProblemDimensions.mValidN = n;
497+
gemmData.mProblemDimensions.mValidM = n;
498+
gemmData.mProblemDimensions.mValidN = m;
498499
gemmData.mProblemDimensions.mValidK = k;
499500

500501
auto const& config = configs[configIndex];

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/BatchedGemmOptions.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,11 +209,13 @@ inline bool checkAndUpdateBatchedGemmOptions(
209209
}
210210
if (options.mFusedAct)
211211
{
212+
std::cout << "checking fused act options" << std::endl;
212213
// ensure that we check the fused options as well
213214
isValid = gemmGatedAct::checkAndUpdateGemmGatedActOptions(options, cudaArch, updateOptions);
214215
}
215216
else
216217
{
218+
std::cout << "checking gemm options" << std::endl;
217219
isValid = gemm::checkAndUpdateGemmOptions(options, cudaArch, 1 /* tpGrpSize */, updateOptions);
218220
}
219221

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/GemmGatedActOptions.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,11 @@ inline bool checkAndUpdateGemmGatedActOptions(
161161
") must be a multiple of ", hiddenGranularity, " for block-scaled outputs.");
162162
}
163163

164+
std::cout << "checking gemm options instead" << std::endl;
164165
auto isValid = gemm::checkAndUpdateGemmOptions(options, cudaArch,
165166
/* tpGrpSize */ 1, updateOptions);
167+
std::cout << "finished checking gemm options" << std::endl;
168+
std::cout << "the result is " << isValid << std::endl;
166169

167170
if (!isValid)
168171
{

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/GemmOptions.h

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -629,6 +629,7 @@ inline int32_t getShuffleBlockSize(int epilogueTileM)
629629
inline bool checkAndUpdateGemmOptions(
630630
GemmOptions& options, tg::CudaArch cudaArch, int tpGrpSize, bool updateOptions = true)
631631
{
632+
std::cout << "Checking GemmOptions..." << std::endl;
632633
options.mWorldSize = tpGrpSize;
633634

634635
bool isBlackwell = tg::isArchBlackwell(cudaArch);
@@ -641,9 +642,11 @@ inline bool checkAndUpdateGemmOptions(
641642
}
642643
else
643644
{
645+
std::cout << "failed at dtypeB" << std::endl;
644646
return false;
645647
}
646648
}
649+
std::cout << "ckpt 0" << std::endl;
647650

648651
// If not specified, used the input dtypes as MMA dtypes (no cast required).
649652
if (options.mDtypeMmaA == tg::Dtype::Void)
@@ -654,6 +657,7 @@ inline bool checkAndUpdateGemmOptions(
654657
}
655658
else
656659
{
660+
std::cout << "failed at dtypeMmaA" << std::endl;
657661
return false;
658662
}
659663
}
@@ -665,6 +669,7 @@ inline bool checkAndUpdateGemmOptions(
665669
}
666670
else
667671
{
672+
std::cout << "failed at dtypeMmaB" << std::endl;
668673
return false;
669674
}
670675
}
@@ -686,8 +691,13 @@ inline bool checkAndUpdateGemmOptions(
686691
// It must not exceed the padded dimensions.
687692
if (options.mValidM > options.mM || options.mValidN > options.mN || options.mValidK > options.mK)
688693
{
694+
std::cout << "test validM/N/K start" << std::endl;
695+
std::cout << "options.mValidM=" << options.mValidM << ", options.mM=" << options.mM << std::endl;
696+
std::cout << "options.mValidN=" << options.mValidN << ", options.mN=" << options.mN << std::endl;
697+
std::cout << "options.mValidK=" << options.mValidK << ", options.mK=" << options.mK << std::endl;
689698
TLLM_LOG_WARNING(options.mValidK <= options.mK,
690699
"ValidM, ValidN, and ValidK must be less than or equal to M, N, and K respectively.");
700+
std::cout << "test validM/N/K start2" << std::endl;
691701
if (updateOptions)
692702
{
693703
options.mValidM = std::min(options.mValidM, options.mM);
@@ -696,6 +706,7 @@ inline bool checkAndUpdateGemmOptions(
696706
}
697707
else
698708
{
709+
std::cout << "failed at validM/N/K" << std::endl;
699710
return false;
700711
}
701712
}
@@ -706,10 +717,12 @@ inline bool checkAndUpdateGemmOptions(
706717
bool hasValidParams = (options.mValidM != -1 && options.mValidM != options.mM)
707718
|| (options.mValidN != -1 && options.mValidN != options.mN)
708719
|| (options.mValidK != -1 && options.mValidK != options.mK);
720+
std::cout << "test BlockMajorK start" << std::endl;
709721
TLLM_CHECK_ERROR(!hasValidParams,
710722
"BlockMajorK layout does not support validM/validN/validK parameters due to swizzled layout. "
711723
"Found validM=",
712724
options.mValidM, " validN=", options.mValidN, " validK=", options.mValidK);
725+
std::cout << "test BlockMajorK start2" << std::endl;
713726
}
714727

715728
#ifdef TLLM_PUBLIC_RELEASE
@@ -718,7 +731,7 @@ inline bool checkAndUpdateGemmOptions(
718731
TLLM_CHECK_ERROR(false, "E2m1 x E4m3 is not supported for JIT compile. Use cubins instead.");
719732
}
720733
#endif // TLLM_PUBLIC_RELEASE
721-
734+
std::cout << "ckpt 1" << std::endl;
722735
// Check that the A cast is supported.
723736
// Currently, we only support {MxFp4, NvFp4} -> Bf16.
724737
TLLM_CHECK_ERROR((options.mDtypeA == options.mDtypeMmaA)
@@ -762,7 +775,7 @@ inline bool checkAndUpdateGemmOptions(
762775
TLLM_CHECK_ERROR(options.mDtypeMmaB == tg::Dtype::E4m3 || options.mDtypeMmaB == tg::Dtype::E2m1,
763776
"For dtypeMmaA = E4m3/E2m1 A, dtypeMmaB must also be E4m3/E2m1.");
764777
}
765-
778+
std::cout << "ckpt 2" << std::endl;
766779
// kind::mxf8f6f4
767780
if (options.mDtypeMmaA == tg::Dtype::MxE4m3 || options.mDtypeMmaA == tg::Dtype::MxE2m1)
768781
{
@@ -774,7 +787,7 @@ inline bool checkAndUpdateGemmOptions(
774787
TLLM_CHECK_ERROR(options.mDtypeMmaA == tg::Dtype::MxE4m3 || options.mDtypeMmaA == tg::Dtype::MxE2m1,
775788
"For dtypeMmaB = MxE4m3 or MxE2m1, dtypeMmaA must also be MxE4m3 or MxE2m1.");
776789
}
777-
790+
std::cout << "ckpt 3" << std::endl;
778791
// kind::f16
779792
if (options.mDtypeMmaA == tg::Dtype::Fp16 || options.mDtypeMmaA == tg::Dtype::Bfloat16)
780793
{
@@ -806,6 +819,7 @@ inline bool checkAndUpdateGemmOptions(
806819
}
807820
else
808821
{
822+
std::cout << "failed at mmaKind" << std::endl;
809823
return false;
810824
}
811825
}
@@ -822,6 +836,7 @@ inline bool checkAndUpdateGemmOptions(
822836
}
823837
else
824838
{
839+
std::cout << "failed at mmaK" << std::endl;
825840
return false;
826841
}
827842
}
@@ -852,7 +867,7 @@ inline bool checkAndUpdateGemmOptions(
852867
"Hopper does not use TMEM. The register layout corresponds to 16dp256bit. Got ", options.mEpilogueLdtmDps,
853868
"dp", options.mEpilogueLdtmBits, "bit.");
854869
}
855-
870+
std::cout << "ckpt 4" << std::endl;
856871
// Constraints for NvFp4 and MxFp8.
857872
if ((options.mMmaKind == tg::MmaKind::MxFp4NvFp4 || options.mMmaKind == tg::MmaKind::MxFp8Fp6Fp4
858873
|| options.mDtypeC == tg::Dtype::MxE4m3)
@@ -872,6 +887,7 @@ inline bool checkAndUpdateGemmOptions(
872887
}
873888
else
874889
{
890+
std::cout << "failed at mmaM" << std::endl;
875891
return false;
876892
}
877893
}
@@ -916,6 +932,7 @@ inline bool checkAndUpdateGemmOptions(
916932
}
917933
else
918934
{
935+
std::cout << "failed at mmaK" << std::endl;
919936
return false;
920937
}
921938
}
@@ -1022,6 +1039,7 @@ inline bool checkAndUpdateGemmOptions(
10221039
}
10231040
else
10241041
{
1042+
std::cout << "failed at dtypeC" << std::endl;
10251043
return false;
10261044
}
10271045
}
@@ -1037,6 +1055,7 @@ inline bool checkAndUpdateGemmOptions(
10371055
}
10381056
else
10391057
{
1058+
std::cout << "failed at epilogueTileM" << std::endl;
10401059
return false;
10411060
}
10421061
}
@@ -1051,6 +1070,7 @@ inline bool checkAndUpdateGemmOptions(
10511070
}
10521071
else
10531072
{
1073+
std::cout << "failed at epilogueTileN" << std::endl;
10541074
return false;
10551075
}
10561076
}
@@ -1066,6 +1086,7 @@ inline bool checkAndUpdateGemmOptions(
10661086
}
10671087
else
10681088
{
1089+
std::cout << "failed at epilogueTileM/N" << std::endl;
10691090
return false;
10701091
}
10711092
}
@@ -1080,6 +1101,7 @@ inline bool checkAndUpdateGemmOptions(
10801101
}
10811102
else
10821103
{
1104+
std::cout << "failed at epilogueTileM" << std::endl;
10831105
return false;
10841106
}
10851107
}
@@ -1200,6 +1222,7 @@ inline bool checkAndUpdateGemmOptions(
12001222
}
12011223
else
12021224
{
1225+
std::cout << "failed at epilogueTileM/N" << std::endl;
12031226
return false;
12041227
}
12051228
}
@@ -1223,6 +1246,7 @@ inline bool checkAndUpdateGemmOptions(
12231246
}
12241247
else
12251248
{
1249+
std::cout << "failed at mmaStages" << std::endl;
12261250
return false;
12271251
}
12281252
}
@@ -1234,6 +1258,7 @@ inline bool checkAndUpdateGemmOptions(
12341258
}
12351259
else
12361260
{
1261+
std::cout << "failed at mmaStages" << std::endl;
12371262
return false;
12381263
}
12391264
}
@@ -1245,6 +1270,7 @@ inline bool checkAndUpdateGemmOptions(
12451270
}
12461271
else
12471272
{
1273+
std::cout << "failed at mmaStages" << std::endl;
12481274
return false;
12491275
}
12501276
}
@@ -1341,6 +1367,7 @@ inline bool checkAndUpdateGemmOptions(
13411367
}
13421368
else
13431369
{
1370+
std::cout << "failed at tileM" << std::endl;
13441371
return false;
13451372
}
13461373
}
@@ -1355,6 +1382,7 @@ inline bool checkAndUpdateGemmOptions(
13551382
}
13561383
else
13571384
{
1385+
std::cout << "failed at numSlicesForSliceK" << std::endl;
13581386
return false;
13591387
}
13601388
}
@@ -1399,6 +1427,7 @@ inline bool checkAndUpdateGemmOptions(
13991427
}
14001428
else
14011429
{
1430+
std::cout << "failed at unrollLoop2xForMma" << std::endl;
14021431
return false;
14031432
}
14041433
}
@@ -1419,6 +1448,7 @@ inline bool checkAndUpdateGemmOptions(
14191448
}
14201449
else
14211450
{
1451+
std::cout << "failed at tileScheduler" << std::endl;
14221452
return false;
14231453
}
14241454
}
@@ -1434,6 +1464,7 @@ inline bool checkAndUpdateGemmOptions(
14341464
}
14351465
else
14361466
{
1467+
std::cout << "failed at earlyExit" << std::endl;
14371468
return false;
14381469
}
14391470
}
@@ -1521,6 +1552,7 @@ inline bool checkAndUpdateGemmOptions(
15211552
}
15221553
else
15231554
{
1555+
std::cout << "failed at blockK" << std::endl;
15241556
return false;
15251557
}
15261558
}

tests/unittest/_torch/thop/parallel/test_moe.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1550,8 +1550,8 @@ def run_moe_fp4_gptoss_test(self, num_tokens: int, hidden_size: int,
15501550
num_experts, hidden_size, device='cuda', dtype=torch.float)
15511551

15521552
# waived due to missing kernel support for bias in nvfp4
1553-
gemm1_bias[:] = 0
1554-
gemm2_bias[:] = 0
1553+
#gemm1_bias[:] = 0
1554+
#gemm2_bias[:] = 0
15551555

15561556
use_ue8m0 = False
15571557
# Quantize hidden states. Produces scales for activations in 128x4 layout for ref impl.
@@ -1793,7 +1793,7 @@ def run_moe_fp4_gptoss_test(self, num_tokens: int, hidden_size: int,
17931793
output_dequant_actual,
17941794
atol=0.2,
17951795
rtol=0.2,
1796-
percent=0.9)
1796+
percent=0.85)
17971797

17981798
def run_moe_fp8_fp4_test(self, num_tokens: int, hidden_size: int,
17991799
intermediate_size: int, routing_info: dict,

0 commit comments

Comments
 (0)