@@ -483,11 +483,8 @@ class AllreduceOp
483483 torch::optional<torch::Tensor> const & scale, torch::optional<torch::Tensor> const & bias)
484484 {
485485 auto const myRank = getRank ();
486- if (myRank == 0 )
487- {
488- TLLM_LOG_INFO (" [RANK 0] *** ENTERED runNCCLAllReduceDeviceFusion ***" );
489- TLLM_LOG_INFO (" [RANK 0] Fusion op: %s" , tensorrt_llm::kernels::toString (mOp ).c_str ());
490- }
486+ TLLM_LOG_DEBUG (" runNCCLAllReduceDeviceFusion: rank=%d, fusion_op=%s" , myRank,
487+ tensorrt_llm::kernels::toString (mOp ).c_str ());
491488
492489 TLLM_CHECK_WITH_INFO (tensorrt_llm::runtime::ub::ub_is_initialized (),
493490 " UserBuffer has not been initialized (required for NCCL_DEVICE)" );
@@ -532,52 +529,26 @@ class AllreduceOp
532529 return {norm_out};
533530 case AllReduceFusionOp::RESIDUAL_RMS_NORM:
534531 {
535- if (myRank == 0 )
536- {
537- TLLM_LOG_INFO (" [RANK 0] NCCL_DEVICE: Processing RESIDUAL_RMS_NORM fusion" );
538- }
539-
540532 TORCH_CHECK (norm_weight, " norm_weight is required for residual rms norm allreduce" );
541533 TORCH_CHECK (residual, " residual is required for residual rms norm allreduce" );
542534 TORCH_CHECK (!bias, " bias is not supported for residual rms norm allreduce" );
543535
544536 int const hidden_size = input.size (-1 );
545537 int const num_tokens = size / hidden_size;
546- if (myRank == 0 )
547- {
548- TLLM_LOG_INFO (
549- " [RANK 0] NCCL_DEVICE: hidden_size=%d, num_tokens=%d, nRanks=%d" , hidden_size, num_tokens, nRanks);
550- }
551- // Get cached launch config from NCCLUserBufferAllocator
552- if (myRank == 0 )
553- {
554- TLLM_LOG_INFO (
555- " [RANK 0] Getting cached NCCL device launch config with: dtype=%d, hidden_size=%d, num_tokens=%d, "
556- " nRanks=%d" ,
557- static_cast <int >(mType ), hidden_size, num_tokens, nRanks);
558- }
538+
539+ TLLM_LOG_DEBUG (" NCCL_DEVICE RESIDUAL_RMS_NORM: rank=%d, hidden_size=%d, num_tokens=%d, nRanks=%d, dtype=%d" ,
540+ myRank, hidden_size, num_tokens, nRanks, static_cast <int >(mType ));
541+
559542 std::shared_ptr<tensorrt_llm::kernels::nccl_device::LaunchConfig> launchConfig
560543 = nccl_ub_allocator.getCachedNCCLDeviceLaunchConfig (
561544 mType , hidden_size, num_tokens, myRank, nRanks, true , false );
562545
563546 // Check if multimem is supported for this data type
564547 bool multimemSupported = launchConfig->supportsMultimem ();
565- if (myRank == 0 )
566- {
567- TLLM_LOG_INFO (" [RANK 0] NCCL_DEVICE: Checking multimem support..." );
568- TLLM_LOG_INFO (" [RANK 0] NCCL_DEVICE: - supportsMultimem() = %s" , multimemSupported ? " TRUE" : " FALSE" );
569- TLLM_LOG_INFO (" [RANK 0] NCCL_DEVICE: - nRanks = %d" , nRanks);
570- TLLM_LOG_INFO (
571- " [RANK 0] NCCL_DEVICE: - dataType = %d (0=FLOAT, 1=HALF, 7=BF16, 9=FP8)" , static_cast <int >(mType ));
572- TLLM_LOG_INFO (" [RANK 0] NCCL_DEVICE: - hidden_size = %d, num_tokens = %d" , hidden_size, num_tokens);
573- }
548+ TLLM_LOG_DEBUG (" NCCL_DEVICE: rank=%d, supportsMultimem=%s" , myRank, multimemSupported ? " true" : " false" );
549+
574550 if (multimemSupported)
575551 {
576- if (myRank == 0 )
577- {
578- TLLM_LOG_INFO (
579- " [RANK 0] NCCL_DEVICE: *** Multimem IS SUPPORTED - launching fused NCCL device kernel ***" );
580- }
581552 ncclWindow_t inWindow = ub_buffer0.window ;
582553 ncclWindow_t outWindow = ub_buffer1.window ;
583554 TLLM_CHECK (inWindow != nullptr );
@@ -590,10 +561,8 @@ class AllreduceOp
590561
591562 launchConfig->launchRMSNorm (inWindow, outWindow, residual.value ().data_ptr (), ub_buffer2.window ,
592563 norm_weight.value ().data_ptr (), nullptr , devComm, mEps , stream);
593- if (myRank == 0 )
594- {
595- TLLM_LOG_INFO (" [RANK 0] NCCL_DEVICE: Fused kernel launched successfully" );
596- }
564+
565+ TLLM_LOG_DEBUG (" NCCL_DEVICE: rank=%d, fused kernel launched successfully" , myRank);
597566 return {norm_out, residual_out};
598567 }
599568 // Fall back to old strategy with warning
@@ -1142,9 +1111,11 @@ class AllreduceOp
11421111 {
11431112 if (mStrategy != AllReduceStrategyType::AUTO)
11441113 {
1145- // For UB,NCCL,NCCL_SYMMETRIC, the correctness of the strategy dispatching is guaranteed by the user.
1114+ // For UB, NCCL, NCCL_SYMMETRIC, and NCCL_DEVICE, the correctness of the strategy dispatching is guaranteed by
1115+ // the user.
11461116 if (mStrategy == AllReduceStrategyType::UB || mStrategy == AllReduceStrategyType::NCCL
1147- || mStrategy == AllReduceStrategyType::NCCL_SYMMETRIC)
1117+ || mStrategy == AllReduceStrategyType::NCCL_SYMMETRIC
1118+ || mStrategy == AllReduceStrategyType::NCCL_DEVICE)
11481119 {
11491120 return mStrategy ;
11501121 }
0 commit comments