Skip to content

Commit 3ddc9d2

Browse files
authored
[https://nvbugs/5729697][fix] MNNVL Allreduce: use CUDA runtime instead of Macro to get SM version. (#10062)
Signed-off-by: Shiyu Li <shili@nvidia.com>
1 parent 48c875f commit 3ddc9d2

File tree

2 files changed

+63
-63
lines changed

2 files changed

+63
-63
lines changed

cpp/tensorrt_llm/kernels/communicationKernels/mnnvlAllreduceKernels.cu

Lines changed: 62 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -230,59 +230,62 @@ inline __device__ __host__ T divUp(T m, T n)
230230
// Return (block_size, cluster_size, loads_per_thread)
231231
std::tuple<int, int, int> adjustGridConfig(int numTokens, int dim, int eltsPerThread)
232232
{
233-
// Start with preferred block_size and cluster_size
234-
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
235-
int clusterSize = 8;
236-
#else
237-
int clusterSize = 1;
238-
#endif
233+
static int SM = tensorrt_llm::common::getSMVersion();
234+
235+
int clusterSize = SM >= 90 ? 8 : 1;
239236
int blockSize = 128;
240237
// ========================== Adjust the grid configuration ==========================
241238
int threadsNeeded = divUp(dim, eltsPerThread);
242239
int loadsPerThread = 1;
243240

244241
blockSize = divUp(threadsNeeded, clusterSize);
245-
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
246-
while (threadsNeeded % clusterSize != 0 && clusterSize > 1)
247-
{
248-
clusterSize /= 2;
249-
}
250-
blockSize = divUp(threadsNeeded, clusterSize);
251-
while (blockSize < 128 && clusterSize >= 2)
252-
{
253-
blockSize *= 2;
254-
clusterSize /= 2;
255-
}
256-
int smCount = getMultiProcessorCount();
257-
while (numTokens * clusterSize > smCount && clusterSize > 1 && blockSize <= 512)
242+
if (clusterSize > 1)
258243
{
259-
blockSize *= 2;
260-
clusterSize /= 2;
244+
while (threadsNeeded % clusterSize != 0 && clusterSize > 1)
245+
{
246+
clusterSize /= 2;
247+
}
248+
blockSize = divUp(threadsNeeded, clusterSize);
249+
while (blockSize < 128 && clusterSize >= 2)
250+
{
251+
blockSize *= 2;
252+
clusterSize /= 2;
253+
}
254+
int smCount = getMultiProcessorCount();
255+
while (numTokens * clusterSize > smCount && clusterSize > 1 && blockSize <= 512)
256+
{
257+
blockSize *= 2;
258+
clusterSize /= 2;
259+
}
261260
}
262-
#endif
263261

264262
// Trying to scale up use multiple loads or CGA
265263
while (blockSize > 1024)
266264
{
267-
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
268-
if (clusterSize < 8)
265+
// Scale up with CGA if supported
266+
if (SM >= 90)
269267
{
270-
clusterSize = clusterSize << 1;
271-
}
272-
else
273-
{
274-
break;
275-
}
276-
#else
277-
if (loadsPerThread < 8)
278-
{
279-
loadsPerThread += 1;
268+
if (clusterSize < 8)
269+
{
270+
clusterSize = clusterSize << 1;
271+
}
272+
else
273+
{
274+
break;
275+
}
280276
}
281277
else
282278
{
283-
break;
279+
280+
if (loadsPerThread < 8)
281+
{
282+
loadsPerThread += 1;
283+
}
284+
else
285+
{
286+
break;
287+
}
284288
}
285-
#endif
286289
blockSize = divUp(threadsNeeded, clusterSize * loadsPerThread);
287290
}
288291
return {blockSize, clusterSize, loadsPerThread};
@@ -420,9 +423,9 @@ __global__ void __launch_bounds__(1024) oneshotAllreduceFusionKernel(T* outputPt
420423
}
421424
float blockSum = blockReduceSum<float, true>(threadSum);
422425

423-
__shared__ float sharedVal[8]; // Temporary variable to share the sum within block
424426
float fullSum = blockSum;
425427
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
428+
__shared__ float sharedVal[8]; // Temporary variable to share the sum within block
426429
namespace cg = cooperative_groups;
427430
cg::cluster_group cluster = cg::this_cluster();
428431
int const numBlocks = cluster.num_blocks();
@@ -459,45 +462,40 @@ using detail::adjustGridConfig;
459462

460463
void oneshotAllreduceFusionOp(AllReduceFusionParams const& params)
461464
{
465+
466+
static int const kSMVersion = tensorrt_llm::common::getSMVersion();
462467
int const numTokens = params.numTokens;
463468
int const tokenDim = params.tokenDim;
464469
int const eltsPerThread = sizeof(float4) / getDTypeSize(params.dType);
465470

466471
auto [blockSize, clusterSize, loadsPerThread] = adjustGridConfig(numTokens, tokenDim, eltsPerThread);
467472
dim3 grid(numTokens, clusterSize, 1);
468473

469-
TLLM_CHECK_WITH_INFO(blockSize <= 1024 && loadsPerThread == 1,
470-
"Hidden Dimension %d exceeds the maximum supported hidden dimension (%d)", tokenDim,
471-
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
472-
1024 * 8 * eltsPerThread);
473-
#else
474-
1024 * eltsPerThread);
475-
#endif
476-
477474
TLLM_LOG_DEBUG(
478475
"[MNNVL AllReduceOneShot] Dispatch: grid size: (%d, %d, 1), block_size: %d, cluster_size: %d, "
479476
"loads_per_thread: %d, "
480477
"threads_needed: %d",
481478
numTokens, clusterSize, blockSize, clusterSize, loadsPerThread, divUp(tokenDim, eltsPerThread));
482479

480+
TLLM_CHECK_WITH_INFO(blockSize <= 1024 && loadsPerThread == 1,
481+
"Hidden Dimension %d exceeds the maximum supported hidden dimension (%d)", tokenDim,
482+
1024 * (kSMVersion >= 90 ? 8 : 1) * eltsPerThread);
483+
483484
cudaLaunchAttribute attrs[2];
484485
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
485486
attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL() ? 1 : 0;
486-
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
487487
attrs[1].id = cudaLaunchAttributeClusterDimension;
488488
attrs[1].val.clusterDim.x = 1;
489489
attrs[1].val.clusterDim.y = clusterSize;
490490
attrs[1].val.clusterDim.z = 1;
491-
#endif
492491

493-
cudaLaunchConfig_t config
494-
{
495-
.gridDim = grid, .blockDim = blockSize, .dynamicSmemBytes = 0, .stream = params.stream, .attrs = attrs,
496-
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
497-
.numAttrs = 2,
498-
#else
499-
.numAttrs = 1,
500-
#endif
492+
cudaLaunchConfig_t config{
493+
.gridDim = grid,
494+
.blockDim = blockSize,
495+
.dynamicSmemBytes = 0,
496+
.stream = params.stream,
497+
.attrs = attrs,
498+
.numAttrs = kSMVersion >= 90 ? 2U : 1U,
501499
};
502500

503501
#define LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, T, RMSNORM) \
@@ -831,9 +829,9 @@ __global__ __launch_bounds__(1024) void rmsNormLamport(T_IN* outputPreNorm, T_OU
831829
float blockSum = blockReduceSum<float, true>(threadSum);
832830

833831
float fullSum = blockSum;
834-
__shared__ float sharedVal[8];
835832
// Use CGA Reduction if supported
836833
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
834+
__shared__ float sharedVal[8];
837835
int const numBlocks = cluster.num_blocks();
838836
if (numBlocks > 1)
839837
{
@@ -876,13 +874,19 @@ __global__ __launch_bounds__(1024) void rmsNormLamport(T_IN* outputPreNorm, T_OU
876874
}
877875
constexpr int kELTS_SIZE = sizeof(T_IN);
878876

877+
// Issue ACQBLK at the end. Assuming preceding kernel will not modify the buffer_flags.
878+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
879+
cudaGridDependencySynchronize();
880+
#endif
881+
879882
// Update the buffer pointers
880883
flag.waitAndUpdate({static_cast<uint32_t>(divUp<uint32_t>(numTokens, worldSize) * worldSize * dim * kELTS_SIZE),
881884
static_cast<uint32_t>(numTokens * dim * kELTS_SIZE), 0, 0});
882885
}
883886

884887
void twoshotAllreduceFusionOp(AllReduceFusionParams const& params)
885888
{
889+
static int const kSMVersion = tensorrt_llm::common::getSMVersion();
886890
int const numTokens = params.numTokens;
887891
int const tokenDim = params.tokenDim;
888892
int const numEltsPerThread = sizeof(float4) / getDTypeSize(params.dType);
@@ -959,17 +963,13 @@ void twoshotAllreduceFusionOp(AllReduceFusionParams const& params)
959963
rnConfig.attrs = rnAttrs;
960964
rnAttrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
961965
rnAttrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL() ? 1 : 0;
962-
#ifndef DISABLE_CGA
963966
rnAttrs[1].id = cudaLaunchAttributeClusterDimension;
964967
rnAttrs[1].val.clusterDim.x = 1;
965968
rnAttrs[1].val.clusterDim.y = rnClusterSize;
966969
rnAttrs[1].val.clusterDim.z = 1;
967-
rnConfig.numAttrs = 2;
968-
#else
969-
rnConfig.numAttrs = 1;
970-
#endif
970+
rnConfig.numAttrs = (kSMVersion >= 90) ? 2U : 1U;
971971

972-
bool const rnUseCGA = rnClusterSize > 1;
972+
bool const rnUseCGA = kSMVersion >= 90 && rnClusterSize > 1;
973973
int const dimPadded = divUp(tokenDim, numEltsPerThread * rnNumThreads) * numEltsPerThread * rnNumThreads;
974974
int const iters = dimPadded / rnNumThreads;
975975

tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def func(input, residual, norm_weight, eps, enable_fusion):
179179
], # Test for max_num_token fallback
180180
ids=lambda x: f"seqlen:{x}",
181181
)
182-
@pytest.mark.parametrize("hidden_size", [8, 2880, 7168, 7176, 8192],
182+
@pytest.mark.parametrize("hidden_size", [8, 2880, 7168, 7176, 8192, 16384],
183183
ids=lambda x: f"hidden:{x}")
184184
@pytest.mark.parametrize("dtype", [torch.bfloat16],
185185
ids=lambda x: f"dtype:{torch.finfo(x).dtype}")

0 commit comments

Comments (0)