Skip to content

Commit 0ae7017

Browse files
authored
Unify two versions of AllReduce custom op (NVIDIA#3032)
* Rewrite the unit test for the unified allreduce op, removing the legacy unit test.
* Revise formats and fusion_op bindings; make all tensors optional inputs.
* Move the MoeAllreduceOp to a separate custom op.
* Move all fusion patterns to the new version of the AllReduce fusion kernel; remove the AllReduce strategy config; revise the AllReduce strategies and fusion pattern definitions.
* Add more TODOs, fix minor bugs, and remove legacy code.

Signed-off-by: Yukun He <[email protected]>
1 parent b87f26e commit 0ae7017

File tree

18 files changed

+1006
-492
lines changed

18 files changed

+1006
-492
lines changed

cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -615,7 +615,7 @@ void allreduce_fusion_kernel_launcher(AllReduceFusionParams const& params)
615615
TLLM_CHECK(params.hidden_dim % kElemsPerAccess<DType> == 0);
616616
static int SM = tensorrt_llm::common::getSMVersion();
617617
int token_num = params.size / params.hidden_dim;
618-
bool oneshot = use_oneshot(token_num);
618+
bool oneshot = params.use_oneshot;
619619
int cluster_num = token_num;
620620
std::array<int, NRanks> begin_tokens, token_num_per_ranks;
621621
if (!oneshot)

cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -131,6 +131,7 @@ struct AllReduceFusionParams
131131
void* rms_gamma;
132132
float rms_eps;
133133
float* scale_factor;
134+
bool use_oneshot;
134135
FP4QuantizationSFLayout layout = FP4QuantizationSFLayout::SWIZZLED;
135136
cudaStream_t stream;
136137
AllReduceFusionPattern pattern;

cpp/tensorrt_llm/kernels/customAllReduceKernels.h

Lines changed: 33 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -46,10 +46,11 @@ static constexpr int kLamportHiddenSizeThreshold = 256;
4646
enum class AllReduceStrategyType : int8_t
4747
{
4848
NCCL = 0,
49-
ONESHOT = 1,
50-
TWOSHOT = 2,
51-
UB = 3,
52-
AUTO = 4,
49+
MIN_LATENCY = 1,
50+
UB = 2,
51+
AUTO = 3,
52+
ONESHOT = 4,
53+
TWOSHOT = 5,
5354
};
5455

5556
enum class AllReduceStrategyConfig : int8_t
@@ -66,10 +67,36 @@ enum class AllReduceFusionOp : int8_t
6667
RESIDUAL_RMS_PREPOST_NORM = 3,
6768
RESIDUAL_RMS_NORM_QUANT_FP8 = 4,
6869
RESIDUAL_RMS_NORM_QUANT_NVFP4 = 5,
69-
MOE_ALLREDUCE_RESIDUAL_RMS_NORM = 6,
70-
RESIDUAL_RMS_NORM_AND_QUANT_NVFP4 = 7,
70+
RESIDUAL_RMS_NORM_OUT_QUANT_FP8 = 6,
71+
RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4 = 7,
72+
MOE_ALLREDUCE_RESIDUAL_RMS_NORM = 8,
7173
};
7274

75+
inline std::ostream& operator<<(std::ostream& os, AllReduceFusionOp op)
76+
{
77+
switch (op)
78+
{
79+
case AllReduceFusionOp::NONE: os << "NONE"; break;
80+
case AllReduceFusionOp::RESIDUAL_RMS_NORM: os << "RESIDUAL_RMS_NORM"; break;
81+
case AllReduceFusionOp::LAST_PROCESS_FOR_UB: os << "LAST_PROCESS_FOR_UB"; break;
82+
case AllReduceFusionOp::RESIDUAL_RMS_PREPOST_NORM: os << "RESIDUAL_RMS_PREPOST_NORM"; break;
83+
case AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_FP8: os << "RESIDUAL_RMS_NORM_QUANT_FP8"; break;
84+
case AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_NVFP4: os << "RESIDUAL_RMS_NORM_QUANT_NVFP4"; break;
85+
case AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_FP8: os << "RESIDUAL_RMS_NORM_OUT_QUANT_FP8"; break;
86+
case AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4: os << "RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4"; break;
87+
case AllReduceFusionOp::MOE_ALLREDUCE_RESIDUAL_RMS_NORM: os << "MOE_ALLREDUCE_RESIDUAL_RMS_NORM"; break;
88+
default: os << "UNKNOWN"; break;
89+
}
90+
return os;
91+
}
92+
93+
inline std::string toString(AllReduceFusionOp op)
94+
{
95+
std::ostringstream oss;
96+
oss << op;
97+
return oss.str();
98+
}
99+
73100
struct AllReduceFusionParams
74101
{
75102
AllReduceFusionParams()

cpp/tensorrt_llm/pybind/runtime/bindings.cpp

Lines changed: 22 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,9 @@
1616
*/
1717

1818
#include "bindings.h"
19+
#include "tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.h"
1920
#include "tensorrt_llm/kernels/communicationKernels/allReduceWorkspace.h"
21+
#include "tensorrt_llm/kernels/customAllReduceKernels.h"
2022
#include "tensorrt_llm/kernels/delayStream.h"
2123
#include "tensorrt_llm/runtime/cudaStream.h"
2224
#include "tensorrt_llm/runtime/decodingInput.h"
@@ -413,6 +415,26 @@ void initBindings(pybind11::module_& m)
413415
tensorrt_llm::kernels::invokeDelayStreamKernel(delay_micro_secs, stream);
414416
},
415417
"Delay kernel launch on the default stream");
418+
419+
py::enum_<tensorrt_llm::kernels::AllReduceFusionOp>(m, "AllReduceFusionOp")
420+
.value("NONE", tensorrt_llm::kernels::AllReduceFusionOp::NONE)
421+
.value("RESIDUAL_RMS_NORM", tensorrt_llm::kernels::AllReduceFusionOp::RESIDUAL_RMS_NORM)
422+
.value("LAST_PROCESS_FOR_UB", tensorrt_llm::kernels::AllReduceFusionOp::LAST_PROCESS_FOR_UB)
423+
.value("RESIDUAL_RMS_PREPOST_NORM", tensorrt_llm::kernels::AllReduceFusionOp::RESIDUAL_RMS_PREPOST_NORM)
424+
.value("RESIDUAL_RMS_NORM_QUANT_FP8", tensorrt_llm::kernels::AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_FP8)
425+
.value("RESIDUAL_RMS_NORM_QUANT_NVFP4", tensorrt_llm::kernels::AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_NVFP4)
426+
.value("RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4",
427+
tensorrt_llm::kernels::AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4)
428+
.value("RESIDUAL_RMS_NORM_OUT_QUANT_FP8",
429+
tensorrt_llm::kernels::AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_FP8);
430+
431+
py::enum_<tensorrt_llm::kernels::AllReduceStrategyType>(m, "AllReduceStrategy")
432+
.value("NCCL", tensorrt_llm::kernels::AllReduceStrategyType::NCCL)
433+
.value("MIN_LATENCY", tensorrt_llm::kernels::AllReduceStrategyType::MIN_LATENCY)
434+
.value("AUTO", tensorrt_llm::kernels::AllReduceStrategyType::AUTO)
435+
.value("UB", tensorrt_llm::kernels::AllReduceStrategyType::UB)
436+
.value("ONESHOT", tensorrt_llm::kernels::AllReduceStrategyType::ONESHOT)
437+
.value("TWOSHOT", tensorrt_llm::kernels::AllReduceStrategyType::TWOSHOT);
416438
}
417439

418440
} // namespace tensorrt_llm::pybind::runtime

0 commit comments

Comments (0)