Skip to content

Commit 6b9ec8d

Browse files
committed
exclude NCCL_DEVICE from automatic strategy validation checks
Signed-off-by: Ludwig Schneider <lschneider@nvidia.com>
1 parent dc090ea commit 6b9ec8d

File tree

1 file changed

+14
-43
lines changed

1 file changed

+14
-43
lines changed

cpp/tensorrt_llm/thop/allreduceOp.cpp

Lines changed: 14 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -483,11 +483,8 @@ class AllreduceOp
483483
torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias)
484484
{
485485
auto const myRank = getRank();
486-
if (myRank == 0)
487-
{
488-
TLLM_LOG_INFO("[RANK 0] *** ENTERED runNCCLAllReduceDeviceFusion ***");
489-
TLLM_LOG_INFO("[RANK 0] Fusion op: %s", tensorrt_llm::kernels::toString(mOp).c_str());
490-
}
486+
TLLM_LOG_DEBUG("runNCCLAllReduceDeviceFusion: rank=%d, fusion_op=%s", myRank,
487+
tensorrt_llm::kernels::toString(mOp).c_str());
491488

492489
TLLM_CHECK_WITH_INFO(tensorrt_llm::runtime::ub::ub_is_initialized(),
493490
"UserBuffer has not been initialized (required for NCCL_DEVICE)");
@@ -532,52 +529,26 @@ class AllreduceOp
532529
return {norm_out};
533530
case AllReduceFusionOp::RESIDUAL_RMS_NORM:
534531
{
535-
if (myRank == 0)
536-
{
537-
TLLM_LOG_INFO("[RANK 0] NCCL_DEVICE: Processing RESIDUAL_RMS_NORM fusion");
538-
}
539-
540532
TORCH_CHECK(norm_weight, "norm_weight is required for residual rms norm allreduce");
541533
TORCH_CHECK(residual, "residual is required for residual rms norm allreduce");
542534
TORCH_CHECK(!bias, "bias is not supported for residual rms norm allreduce");
543535

544536
int const hidden_size = input.size(-1);
545537
int const num_tokens = size / hidden_size;
546-
if (myRank == 0)
547-
{
548-
TLLM_LOG_INFO(
549-
"[RANK 0] NCCL_DEVICE: hidden_size=%d, num_tokens=%d, nRanks=%d", hidden_size, num_tokens, nRanks);
550-
}
551-
// Get cached launch config from NCCLUserBufferAllocator
552-
if (myRank == 0)
553-
{
554-
TLLM_LOG_INFO(
555-
"[RANK 0] Getting cached NCCL device launch config with: dtype=%d, hidden_size=%d, num_tokens=%d, "
556-
"nRanks=%d",
557-
static_cast<int>(mType), hidden_size, num_tokens, nRanks);
558-
}
538+
539+
TLLM_LOG_DEBUG("NCCL_DEVICE RESIDUAL_RMS_NORM: rank=%d, hidden_size=%d, num_tokens=%d, nRanks=%d, dtype=%d",
540+
myRank, hidden_size, num_tokens, nRanks, static_cast<int>(mType));
541+
559542
std::shared_ptr<tensorrt_llm::kernels::nccl_device::LaunchConfig> launchConfig
560543
= nccl_ub_allocator.getCachedNCCLDeviceLaunchConfig(
561544
mType, hidden_size, num_tokens, myRank, nRanks, true, false);
562545

563546
// Check if multimem is supported for this data type
564547
bool multimemSupported = launchConfig->supportsMultimem();
565-
if (myRank == 0)
566-
{
567-
TLLM_LOG_INFO("[RANK 0] NCCL_DEVICE: Checking multimem support...");
568-
TLLM_LOG_INFO("[RANK 0] NCCL_DEVICE: - supportsMultimem() = %s", multimemSupported ? "TRUE" : "FALSE");
569-
TLLM_LOG_INFO("[RANK 0] NCCL_DEVICE: - nRanks = %d", nRanks);
570-
TLLM_LOG_INFO(
571-
"[RANK 0] NCCL_DEVICE: - dataType = %d (0=FLOAT, 1=HALF, 7=BF16, 9=FP8)", static_cast<int>(mType));
572-
TLLM_LOG_INFO("[RANK 0] NCCL_DEVICE: - hidden_size = %d, num_tokens = %d", hidden_size, num_tokens);
573-
}
548+
TLLM_LOG_DEBUG("NCCL_DEVICE: rank=%d, supportsMultimem=%s", myRank, multimemSupported ? "true" : "false");
549+
574550
if (multimemSupported)
575551
{
576-
if (myRank == 0)
577-
{
578-
TLLM_LOG_INFO(
579-
"[RANK 0] NCCL_DEVICE: *** Multimem IS SUPPORTED - launching fused NCCL device kernel ***");
580-
}
581552
ncclWindow_t inWindow = ub_buffer0.window;
582553
ncclWindow_t outWindow = ub_buffer1.window;
583554
TLLM_CHECK(inWindow != nullptr);
@@ -590,10 +561,8 @@ class AllreduceOp
590561

591562
launchConfig->launchRMSNorm(inWindow, outWindow, residual.value().data_ptr(), ub_buffer2.window,
592563
norm_weight.value().data_ptr(), nullptr, devComm, mEps, stream);
593-
if (myRank == 0)
594-
{
595-
TLLM_LOG_INFO("[RANK 0] NCCL_DEVICE: Fused kernel launched successfully");
596-
}
564+
565+
TLLM_LOG_DEBUG("NCCL_DEVICE: rank=%d, fused kernel launched successfully", myRank);
597566
return {norm_out, residual_out};
598567
}
599568
// Fall back to old strategy with warning
@@ -1142,9 +1111,11 @@ class AllreduceOp
11421111
{
11431112
if (mStrategy != AllReduceStrategyType::AUTO)
11441113
{
1145-
// For UB,NCCL,NCCL_SYMMETRIC, the correctness of the strategy dispatching is guaranteed by the user.
1114+
// For UB,NCCL,NCCL_SYMMETRIC,NCCL_DEVICE, the correctness of the strategy dispatching is guaranteed by the
1115+
// user.
11461116
if (mStrategy == AllReduceStrategyType::UB || mStrategy == AllReduceStrategyType::NCCL
1147-
|| mStrategy == AllReduceStrategyType::NCCL_SYMMETRIC)
1117+
|| mStrategy == AllReduceStrategyType::NCCL_SYMMETRIC
1118+
|| mStrategy == AllReduceStrategyType::NCCL_DEVICE)
11481119
{
11491120
return mStrategy;
11501121
}

0 commit comments

Comments (0)