Skip to content

Commit 8277381

Browse files
authored
[None][feat] update TRT-LLM Gen DS FP8 MoE cubins and optimize finalize kernel (#11104)
Signed-off-by: Nikita Korobov <14355239+nekorobov@users.noreply.github.com>
1 parent 48206f3 commit 8277381

File tree

923 files changed

+15328
-8881
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

923 files changed

+15328
-8881
lines changed

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.cpp

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -240,14 +240,6 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(TrtllmGenBatchedGemmRunne
240240
}
241241
}
242242

243-
if (options.mUseDeepSeekFp8)
244-
{
245-
if (!acceptIf(options.mUseShuffledMatrixA == false, "useShuffledMatrixA should be false for DeepSeek Fp8"))
246-
{
247-
continue;
248-
}
249-
}
250-
251243
if (options.mFusedAct)
252244
{
253245
if (!acceptIf(options.mActType == static_cast<batchedGemm::gemmGatedAct::ActType>(mOptions.actType),
@@ -452,7 +444,7 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, int32_t va
452444
bmm.runInitBeforeWorldSync(config, gemmData, static_cast<void*>(stream));
453445

454446
auto const err = bmm.run(config, workspace, gemmData, static_cast<void*>(stream), multiProcessorCount,
455-
tensorrt_llm::common::getEnvEnablePDL(), globalTrtllmGenBatchedGemmModuleCache);
447+
tensorrt_llm::common::getEnvEnablePDL(), /* pinnedHostBuffer */ nullptr, globalTrtllmGenBatchedGemmModuleCache);
456448

457449
CUresult cuErr = static_cast<CUresult>(err);
458450
char const* cuErrStr = nullptr;

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/BatchedGemmEnums.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION &
2+
* SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION &
33
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
44
*
55
* Licensed under the Apache License, Version 2.0 (the "License");

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/BatchedGemmInterface.h

Lines changed: 181 additions & 45 deletions
Large diffs are not rendered by default.

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/BatchedGemmOptions.h

Lines changed: 127 additions & 46 deletions
Large diffs are not rendered by default.

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/Enums.h

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION &
2+
* SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION &
33
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
44
*
55
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -93,13 +93,38 @@ enum class BiasType : uint32_t
9393

9494
////////////////////////////////////////////////////////////////////////////////////////////////////
9595

96+
// Type of the element-wise activation to apply after the Gemm.
enum class EltwiseActType
{
    // No activation: the Gemm output is used as-is.
    None = 0,
    // Gelu is defined as:
    //   act = x0 * phi(x0)
    // where x0 is the output of the Gemm and phi is the CDF of the standard normal
    // distribution, approximated (tanh form) by:
    //   phi(x) = 0.5 * (1 + tanh(0.7978845608028654 * (x + 0.044715 * x * x * x)))
    Gelu,
    // Relu2 (a.k.a. squared Relu) is defined as:
    //   act = relu(x0) ^ 2
    // where x0 is the output of the Gemm.
    Relu2,
};
111+
112+
////////////////////////////////////////////////////////////////////////////////////////////////////
113+
96114
// Strategy used to distribute output tiles across CTAs.
enum class TileScheduler
{
    // Static scheduler (non-persistent): one CTA per tile.
    Static = 0,
    // Dynamic persistent scheduler for SM100+.
    Persistent,
    // Static persistent scheduler. Launches a fixed grid size based on the number of SMs and
    // relies on the underlying PersistentTileSchedulerSm90 for static work distribution. Each
    // CTA iterates through tiles and leaves the loop by setting is_valid_tile to false once
    // its work is exhausted.
    StaticPersistent,
    // Dynamic persistent scheduler for SM90+ using atomicAdd on a global counter.
    // Uses DynamicPersistentPipelinedTileSchedulerSm90, which enables work-stealing among
    // CTAs by atomically fetching work tile indices from a global counter.
    PersistentSm90,
};
104129

105130
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -154,6 +179,28 @@ BIAS_TYPE_FUNCTION(Mn)
154179

155180
////////////////////////////////////////////////////////////////////////////////////////////////////
156181

182+
// Helper function to check if a scheduler is persistent.
183+
inline bool isPersistentScheduler(TileScheduler scheduler)
184+
{
185+
return scheduler == TileScheduler::Persistent || scheduler == TileScheduler::StaticPersistent
186+
|| scheduler == TileScheduler::PersistentSm90;
187+
}
188+
189+
////////////////////////////////////////////////////////////////////////////////////////////////////
190+
191+
// Helper function to check if CTA rasterization order is compatible with clean early exit of the
192+
// kernel. Clean early exit requires CTA indices to increase monotonically along the batch
193+
// dimension, so when a CTA exits the kernel early, it exits with all valid tiles already done.
194+
// Zigzag or batch-major patterns are NOT compatible because they may cause valid tiles to be
195+
// skipped when exiting early.
196+
inline bool supportsCleanEarlyExit(CtaSwizzleType swizzleType, bool batchM, TileScheduler /* scheduler */)
197+
{
198+
return (
199+
batchM ? (swizzleType == CtaSwizzleType::RasterizeAlongN) : (swizzleType == CtaSwizzleType::RasterizeAlongM));
200+
}
201+
202+
////////////////////////////////////////////////////////////////////////////////////////////////////
203+
157204
} // namespace gemm
158205

159206
} // namespace batchedGemm

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/GemmGatedActOptions.h

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION &
2+
* SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION &
33
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
44
*
55
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -83,6 +83,8 @@ enum class ActType
8383
// where x0 and x1 are the raw numbers from Gemm, while scaleC and scaleGate are input scales,
8484
// beta' = beta / scaleAb, scaleC' = scaleC * scaleAb.
8585
GeGlu,
86+
// Placeholder for no activation; not implemented in codegen
87+
None,
8688
};
8789

8890
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -137,16 +139,26 @@ struct GemmGatedActOptions : public gemm::GemmOptions
137139
inline bool checkAndUpdateGemmGatedActOptions(
138140
gemmGatedAct::GemmGatedActOptions& options, tg::CudaArch cudaArch, bool updateOptions = true)
139141
{
142+
auto isValid = gemm::checkAndUpdateGemmOptions(options, cudaArch,
143+
/* tpGrpSize */ 1, updateOptions);
144+
if (!isValid)
145+
{
146+
return false;
147+
}
140148

149+
if (options.mActType == gemmGatedAct::ActType::None)
150+
{
151+
TLLM_CHECK_ERROR(false, "ActType None is not supported");
152+
}
141153
// tmpOut is already transposed at this stage
142154
auto const hiddenSizeStr = options.mTransposeMmaOutput ? "M" : "N";
143155
auto const hiddenSize = options.mTransposeMmaOutput ? options.mM : options.mN;
144156
auto const hiddenEpilogueTileSize = options.mTransposeMmaOutput ? options.mEpilogueTileM : options.mEpilogueTileN;
145157

146158
TLLM_CHECK_ERROR(hiddenSize % 2 == 0, hiddenSizeStr, " must be a multiple of 2.");
147159

148-
TLLM_CHECK_ERROR((options.mTransposeMmaOutput ^ options.mUseShuffledMatrixA) == 0,
149-
"Transpose mma output can only be used with shuffled A matrix. And vice versa.");
160+
TLLM_CHECK_ERROR((options.mTransposeMmaOutput && !options.mUseShuffledMatrix) == false,
161+
"Transpose mma output can only be used with shuffled matrix.");
150162

151163
if (options.mUseTmaStore)
152164
{
@@ -157,19 +169,11 @@ inline bool checkAndUpdateGemmGatedActOptions(
157169
if (options.mDtypeC == tg::Dtype::E2m1 || options.mDtypeC == tg::Dtype::MxE4m3)
158170
{
159171
int const outHiddenSize = (options.mTransposeMmaOutput ? options.mM : options.mN) / 2;
160-
int const hiddenGranularity = 4 * tg::dtypeNumEltsPerSf(options.mDtypeC);
172+
int const hiddenGranularity = 4 * options.mSfBlockSizeC;
161173
TLLM_CHECK_ERROR(outHiddenSize % hiddenGranularity == 0, "Output hidden size (", outHiddenSize,
162174
") must be a multiple of ", hiddenGranularity, " for block-scaled outputs.");
163175
}
164176

165-
auto isValid = gemm::checkAndUpdateGemmOptions(options, cudaArch,
166-
/* tpGrpSize */ 1, updateOptions);
167-
168-
if (!isValid)
169-
{
170-
return false;
171-
}
172-
173177
auto const validHiddenSize = options.mTransposeMmaOutput ? options.mValidM : options.mValidN;
174178
if (options.mUseDeepSeekFp8)
175179
{
@@ -178,12 +182,12 @@ inline bool checkAndUpdateGemmGatedActOptions(
178182
}
179183

180184
//
181-
if (options.mUseShuffledMatrixA)
185+
if (options.mUseShuffledMatrix)
182186
{
183187
auto const shuffleBlockSize = gemm::getShuffleBlockSize(options.mEpilogueTileM);
184188
TLLM_CHECK_ERROR(hiddenSize % (2 * shuffleBlockSize) == 0 && validHiddenSize % (2 * shuffleBlockSize) == 0,
185189
"M/validM must be a multiple of 2 * shuffle block size (", 2 * shuffleBlockSize,
186-
") when useShuffledMatrixA");
190+
") when useShuffledMatrix");
187191
}
188192
if (options.mNumSlicesForSplitK > 1)
189193
{

0 commit comments

Comments
 (0)