Skip to content

Commit d2d086a

Browse files
committed
clean up
Signed-off-by: Dongfeng Yu <[email protected]>
1 parent fb1009b commit d2d086a

File tree

6 files changed

+8
-120
lines changed

6 files changed

+8
-120
lines changed

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.cpp

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -211,8 +211,6 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto
211211
int32_t const* ctaIdxXyToMnLimit, int32_t const* numNonExitingCtas, void* workspace, CUstream stream, int device,
212212
int32_t configIndex)
213213
{
214-
std::cout << "run 1 fixed" << std::endl;
215-
std::cout << ptrBias << std::endl;
216214
auto bmm = BatchedGemmInterface();
217215

218216
BatchedGemmData gemmData;
@@ -253,8 +251,8 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto
253251
gemmData.mProblemDimensions.mK = k;
254252
gemmData.mProblemDimensions.mRank = 0;
255253
gemmData.mProblemDimensions.mWorldSize = 1;
256-
gemmData.mProblemDimensions.mValidM = n;
257-
gemmData.mProblemDimensions.mValidN = m;
254+
gemmData.mProblemDimensions.mValidM = mOptions.transposeMmaOutput ? n : m;
255+
gemmData.mProblemDimensions.mValidN = mOptions.transposeMmaOutput ? m : n;
258256
gemmData.mProblemDimensions.mValidK = k;
259257

260258
// Inputs
@@ -310,8 +308,6 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto
310308
void const* a, void const* sfA, void const* b, void const* sfB, void* c, void* outSfC, void* workspace,
311309
CUstream stream, int device, int32_t configIndex)
312310
{
313-
std::cout << "run 2" << std::endl;
314-
std::cout << "no bias" << std::endl;
315311
// Dispatch with block scaling factors and with static batching.
316312
run(m, n, k, batchedTokens, /* numTokens */ 0, batchedTokens.size(), /* maxNumCtasInBatchDim */ 0, a, sfA, b, sfB,
317313
/* perTokensSfA */ nullptr, /* perTokensSfB */ nullptr,
@@ -327,8 +323,6 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto
327323
float const* ptrBeta, float const* ptrClampLimit, void* c, void* outSfC, void* workspace, CUstream stream,
328324
int device, int32_t configIndex)
329325
{
330-
std::cout << "run 3" << std::endl;
331-
std::cout << ptrBias << std::endl;
332326
// Dispatch with block scaling factors and with static batching.
333327
run(m, n, k, batchedTokens, /* numTokens */ 0, batchedTokens.size(), /* maxNumCtasInBatchDim */ 0, a, sfA, b, sfB,
334328
/* perTokensSfA */ nullptr, /* perTokensSfB */ nullptr,
@@ -342,8 +336,6 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto
342336
void const* a, void const* b, float const* scaleC, float const* scaleGateC, void* c, void* workspace,
343337
CUstream stream, int device, int32_t configIndex)
344338
{
345-
std::cout << "run 4" << std::endl;
346-
std::cout << "no bias" << std::endl;
347339
// Dispatch without block scaling factors (sfA/sfB are null) and with static batching.
348340
run(m, n, k, batchedTokens, /* numTokens */ 0, batchedTokens.size(), /* maxNumCtasInBatchDim */ 0, a,
349341
/* sfA */ nullptr, b, /* sfB */ nullptr, /* perTokensSfA */ nullptr, /* perTokensSfB */ nullptr, scaleC,
@@ -377,8 +369,8 @@ std::vector<int64_t> TrtllmGenBatchedGemmRunner::getValidConfigIndices(int32_t m
377369
gemmData.mProblemDimensions.mRank = 0;
378370
gemmData.mProblemDimensions.mWorldSize = 1;
379371
gemmData.mProblemDimensions.mMaxNumCtasInTokenDim = maxNumCtasInBatchDim;
380-
gemmData.mProblemDimensions.mValidM = n;
381-
gemmData.mProblemDimensions.mValidN = m;
372+
gemmData.mProblemDimensions.mValidM = mOptions.transposeMmaOutput ? n : m;
373+
gemmData.mProblemDimensions.mValidN = mOptions.transposeMmaOutput ? m : n;
382374
gemmData.mProblemDimensions.mValidK = k;
383375
auto cmpFunc = [&configs, &gemmData, &bmm, &multiProcessorCount](int64_t idx0, int64_t idx1)
384376
{
@@ -450,7 +442,6 @@ std::vector<int64_t> TrtllmGenBatchedGemmRunner::getValidConfigIndices(int32_t m
450442
std::vector<int64_t> validConfigIndices;
451443
for (auto const& configIndex : prioritizedIndices)
452444
{
453-
std::cout << "checking config index " << configIndex << std::endl;
454445
auto const& config = configs[configIndex];
455446
auto isValidConfig = bmm.isValidConfig(config, gemmData);
456447
if (isValidConfig)
@@ -494,8 +485,8 @@ bool TrtllmGenBatchedGemmRunner::isValidConfigIndex(int32_t configIndex, int32_t
494485
gemmData.mProblemDimensions.mRank = 0;
495486
gemmData.mProblemDimensions.mWorldSize = 1;
496487
gemmData.mProblemDimensions.mMaxNumCtasInTokenDim = maxNumCtasInBatchDim;
497-
gemmData.mProblemDimensions.mValidM = n;
498-
gemmData.mProblemDimensions.mValidN = m;
488+
gemmData.mProblemDimensions.mValidM = mOptions.transposeMmaOutput ? n : m;
489+
gemmData.mProblemDimensions.mValidN = mOptions.transposeMmaOutput ? m : n;
499490
gemmData.mProblemDimensions.mValidK = k;
500491

501492
auto const& config = configs[configIndex];

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/BatchedGemmOptions.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,13 +209,11 @@ inline bool checkAndUpdateBatchedGemmOptions(
209209
}
210210
if (options.mFusedAct)
211211
{
212-
std::cout << "checking fused act options" << std::endl;
213212
// ensure that we check the fused options as well
214213
isValid = gemmGatedAct::checkAndUpdateGemmGatedActOptions(options, cudaArch, updateOptions);
215214
}
216215
else
217216
{
218-
std::cout << "checking gemm options" << std::endl;
219217
isValid = gemm::checkAndUpdateGemmOptions(options, cudaArch, 1 /* tpGrpSize */, updateOptions);
220218
}
221219

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/GemmGatedActOptions.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,11 +161,8 @@ inline bool checkAndUpdateGemmGatedActOptions(
161161
") must be a multiple of ", hiddenGranularity, " for block-scaled outputs.");
162162
}
163163

164-
std::cout << "checking gemm options instead" << std::endl;
165164
auto isValid = gemm::checkAndUpdateGemmOptions(options, cudaArch,
166165
/* tpGrpSize */ 1, updateOptions);
167-
std::cout << "finished checking gemm options" << std::endl;
168-
std::cout << "the result is " << isValid << std::endl;
169166

170167
if (!isValid)
171168
{

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/GemmOptions.h

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -629,7 +629,6 @@ inline int32_t getShuffleBlockSize(int epilogueTileM)
629629
inline bool checkAndUpdateGemmOptions(
630630
GemmOptions& options, tg::CudaArch cudaArch, int tpGrpSize, bool updateOptions = true)
631631
{
632-
std::cout << "Checking GemmOptions..." << std::endl;
633632
options.mWorldSize = tpGrpSize;
634633

635634
bool isBlackwell = tg::isArchBlackwell(cudaArch);
@@ -642,11 +641,9 @@ inline bool checkAndUpdateGemmOptions(
642641
}
643642
else
644643
{
645-
std::cout << "failed at dtypeB" << std::endl;
646644
return false;
647645
}
648646
}
649-
std::cout << "ckpt 0" << std::endl;
650647

651648
// If not specified, use the input dtypes as MMA dtypes (no cast required).
652649
if (options.mDtypeMmaA == tg::Dtype::Void)
@@ -657,7 +654,6 @@ inline bool checkAndUpdateGemmOptions(
657654
}
658655
else
659656
{
660-
std::cout << "failed at dtypeMmaA" << std::endl;
661657
return false;
662658
}
663659
}
@@ -669,7 +665,6 @@ inline bool checkAndUpdateGemmOptions(
669665
}
670666
else
671667
{
672-
std::cout << "failed at dtypeMmaB" << std::endl;
673668
return false;
674669
}
675670
}
@@ -691,13 +686,8 @@ inline bool checkAndUpdateGemmOptions(
691686
// It must not exceed the padded dimensions.
692687
if (options.mValidM > options.mM || options.mValidN > options.mN || options.mValidK > options.mK)
693688
{
694-
std::cout << "test validM/N/K start" << std::endl;
695-
std::cout << "options.mValidM=" << options.mValidM << ", options.mM=" << options.mM << std::endl;
696-
std::cout << "options.mValidN=" << options.mValidN << ", options.mN=" << options.mN << std::endl;
697-
std::cout << "options.mValidK=" << options.mValidK << ", options.mK=" << options.mK << std::endl;
698689
TLLM_LOG_WARNING(options.mValidK <= options.mK,
699690
"ValidM, ValidN, and ValidK must be less than or equal to M, N, and K respectively.");
700-
std::cout << "test validM/N/K start2" << std::endl;
701691
if (updateOptions)
702692
{
703693
options.mValidM = std::min(options.mValidM, options.mM);
@@ -706,7 +696,6 @@ inline bool checkAndUpdateGemmOptions(
706696
}
707697
else
708698
{
709-
std::cout << "failed at validM/N/K" << std::endl;
710699
return false;
711700
}
712701
}
@@ -717,12 +706,10 @@ inline bool checkAndUpdateGemmOptions(
717706
bool hasValidParams = (options.mValidM != -1 && options.mValidM != options.mM)
718707
|| (options.mValidN != -1 && options.mValidN != options.mN)
719708
|| (options.mValidK != -1 && options.mValidK != options.mK);
720-
std::cout << "test BlockMajorK start" << std::endl;
721709
TLLM_CHECK_ERROR(!hasValidParams,
722710
"BlockMajorK layout does not support validM/validN/validK parameters due to swizzled layout. "
723711
"Found validM=",
724712
options.mValidM, " validN=", options.mValidN, " validK=", options.mValidK);
725-
std::cout << "test BlockMajorK start2" << std::endl;
726713
}
727714

728715
#ifdef TLLM_PUBLIC_RELEASE
@@ -731,7 +718,6 @@ inline bool checkAndUpdateGemmOptions(
731718
TLLM_CHECK_ERROR(false, "E2m1 x E4m3 is not supported for JIT compile. Use cubins instead.");
732719
}
733720
#endif // TLLM_PUBLIC_RELEASE
734-
std::cout << "ckpt 1" << std::endl;
735721
// Check that the A cast is supported.
736722
// Currently, we only support {MxFp4, NvFp4} -> Bf16.
737723
TLLM_CHECK_ERROR((options.mDtypeA == options.mDtypeMmaA)
@@ -775,7 +761,6 @@ inline bool checkAndUpdateGemmOptions(
775761
TLLM_CHECK_ERROR(options.mDtypeMmaB == tg::Dtype::E4m3 || options.mDtypeMmaB == tg::Dtype::E2m1,
776762
"For dtypeMmaA = E4m3/E2m1 A, dtypeMmaB must also be E4m3/E2m1.");
777763
}
778-
std::cout << "ckpt 2" << std::endl;
779764
// kind::mxf8f6f4
780765
if (options.mDtypeMmaA == tg::Dtype::MxE4m3 || options.mDtypeMmaA == tg::Dtype::MxE2m1)
781766
{
@@ -787,7 +772,6 @@ inline bool checkAndUpdateGemmOptions(
787772
TLLM_CHECK_ERROR(options.mDtypeMmaA == tg::Dtype::MxE4m3 || options.mDtypeMmaA == tg::Dtype::MxE2m1,
788773
"For dtypeMmaB = MxE4m3 or MxE2m1, dtypeMmaA must also be MxE4m3 or MxE2m1.");
789774
}
790-
std::cout << "ckpt 3" << std::endl;
791775
// kind::f16
792776
if (options.mDtypeMmaA == tg::Dtype::Fp16 || options.mDtypeMmaA == tg::Dtype::Bfloat16)
793777
{
@@ -819,7 +803,6 @@ inline bool checkAndUpdateGemmOptions(
819803
}
820804
else
821805
{
822-
std::cout << "failed at mmaKind" << std::endl;
823806
return false;
824807
}
825808
}
@@ -836,7 +819,6 @@ inline bool checkAndUpdateGemmOptions(
836819
}
837820
else
838821
{
839-
std::cout << "failed at mmaK" << std::endl;
840822
return false;
841823
}
842824
}
@@ -867,7 +849,6 @@ inline bool checkAndUpdateGemmOptions(
867849
"Hopper does not use TMEM. The register layout corresponds to 16dp256bit. Got ", options.mEpilogueLdtmDps,
868850
"dp", options.mEpilogueLdtmBits, "bit.");
869851
}
870-
std::cout << "ckpt 4" << std::endl;
871852
// Constraints for NvFp4 and MxFp8.
872853
if ((options.mMmaKind == tg::MmaKind::MxFp4NvFp4 || options.mMmaKind == tg::MmaKind::MxFp8Fp6Fp4
873854
|| options.mDtypeC == tg::Dtype::MxE4m3)
@@ -887,7 +868,6 @@ inline bool checkAndUpdateGemmOptions(
887868
}
888869
else
889870
{
890-
std::cout << "failed at mmaM" << std::endl;
891871
return false;
892872
}
893873
}
@@ -932,7 +912,6 @@ inline bool checkAndUpdateGemmOptions(
932912
}
933913
else
934914
{
935-
std::cout << "failed at mmaK" << std::endl;
936915
return false;
937916
}
938917
}
@@ -1039,7 +1018,6 @@ inline bool checkAndUpdateGemmOptions(
10391018
}
10401019
else
10411020
{
1042-
std::cout << "failed at dtypeC" << std::endl;
10431021
return false;
10441022
}
10451023
}
@@ -1055,7 +1033,6 @@ inline bool checkAndUpdateGemmOptions(
10551033
}
10561034
else
10571035
{
1058-
std::cout << "failed at epilogueTileM" << std::endl;
10591036
return false;
10601037
}
10611038
}
@@ -1070,7 +1047,6 @@ inline bool checkAndUpdateGemmOptions(
10701047
}
10711048
else
10721049
{
1073-
std::cout << "failed at epilogueTileN" << std::endl;
10741050
return false;
10751051
}
10761052
}
@@ -1086,7 +1062,6 @@ inline bool checkAndUpdateGemmOptions(
10861062
}
10871063
else
10881064
{
1089-
std::cout << "failed at epilogueTileM/N" << std::endl;
10901065
return false;
10911066
}
10921067
}
@@ -1101,7 +1076,6 @@ inline bool checkAndUpdateGemmOptions(
11011076
}
11021077
else
11031078
{
1104-
std::cout << "failed at epilogueTileM" << std::endl;
11051079
return false;
11061080
}
11071081
}
@@ -1222,7 +1196,6 @@ inline bool checkAndUpdateGemmOptions(
12221196
}
12231197
else
12241198
{
1225-
std::cout << "failed at epilogueTileM/N" << std::endl;
12261199
return false;
12271200
}
12281201
}
@@ -1246,7 +1219,6 @@ inline bool checkAndUpdateGemmOptions(
12461219
}
12471220
else
12481221
{
1249-
std::cout << "failed at mmaStages" << std::endl;
12501222
return false;
12511223
}
12521224
}
@@ -1258,7 +1230,6 @@ inline bool checkAndUpdateGemmOptions(
12581230
}
12591231
else
12601232
{
1261-
std::cout << "failed at mmaStages" << std::endl;
12621233
return false;
12631234
}
12641235
}
@@ -1270,7 +1241,6 @@ inline bool checkAndUpdateGemmOptions(
12701241
}
12711242
else
12721243
{
1273-
std::cout << "failed at mmaStages" << std::endl;
12741244
return false;
12751245
}
12761246
}
@@ -1367,7 +1337,6 @@ inline bool checkAndUpdateGemmOptions(
13671337
}
13681338
else
13691339
{
1370-
std::cout << "failed at tileM" << std::endl;
13711340
return false;
13721341
}
13731342
}
@@ -1382,7 +1351,6 @@ inline bool checkAndUpdateGemmOptions(
13821351
}
13831352
else
13841353
{
1385-
std::cout << "failed at numSlicesForSliceK" << std::endl;
13861354
return false;
13871355
}
13881356
}
@@ -1427,7 +1395,6 @@ inline bool checkAndUpdateGemmOptions(
14271395
}
14281396
else
14291397
{
1430-
std::cout << "failed at unrollLoop2xForMma" << std::endl;
14311398
return false;
14321399
}
14331400
}
@@ -1448,7 +1415,6 @@ inline bool checkAndUpdateGemmOptions(
14481415
}
14491416
else
14501417
{
1451-
std::cout << "failed at tileScheduler" << std::endl;
14521418
return false;
14531419
}
14541420
}
@@ -1464,7 +1430,6 @@ inline bool checkAndUpdateGemmOptions(
14641430
}
14651431
else
14661432
{
1467-
std::cout << "failed at earlyExit" << std::endl;
14681433
return false;
14691434
}
14701435
}
@@ -1552,7 +1517,6 @@ inline bool checkAndUpdateGemmOptions(
15521517
}
15531518
else
15541519
{
1555-
std::cout << "failed at blockK" << std::endl;
15561520
return false;
15571521
}
15581522
}

0 commit comments

Comments
 (0)