@@ -14,13 +14,15 @@ using namespace std;
1414
1515namespace rtp_llm {
1616
17- CustomAllReduceComm::CustomAllReduceComm (const std::vector<size_t >& tp_ranks, size_t rank, size_t rank_index):
17+ CustomAllReduceComm::CustomAllReduceComm (const std::vector<size_t >& tp_ranks, size_t rank, size_t rank_index, const HWKernelConfig& hw_kernel_config ):
1818 rank_ (rank),
1919 rank_index_ (rank_index),
2020 world_size_ (tp_ranks.size()),
2121 support_nv_link_ (true ), // TODO(liyangcheng.lyc): add check function
2222 comm_buf_threshold_ (getCommBufThreshold()),
23- tp_ranks_ (std::move(tp_ranks)) {}
23+ tp_ranks_ (std::move(tp_ranks)),
24+ ft_disable_custom_ar_ (hw_kernel_config.ft_disable_custom_ar),
25+ rocm_disable_custom_ag_ (hw_kernel_config.rocm_disable_custom_ag) {}
2426
2527CustomAllReduceComm::~CustomAllReduceComm () {
2628 aiter::dispose (fa_);
@@ -41,6 +43,15 @@ bool CustomAllReduceComm::checkAllReduceAvailable(size_t elts_total_num, DataTyp
4143 return false ;
4244}
4345
46+ bool CustomAllReduceComm::checkAllGatherAvailable () {
47+ if (rocm_disable_custom_ag_) {
48+ RTP_LLM_LOG_INFO (" Disable custom ag since ROCM_DISABLE_CUSTOM_AG is set" );
49+ return false ;
50+ }
51+
52+ return true ;
53+ }
54+
4455void CustomAllReduceComm::allReduce (torch::Tensor& input_tensor, torch::Tensor& output_tensor) {
4556 if (at::hip::currentStreamCaptureStatusMayInitCtx () != at::hip::CaptureStatus::None) {
4657 aiter::all_reduce (fa_, input_tensor, output_tensor, false , std::nullopt );
@@ -49,6 +60,14 @@ void CustomAllReduceComm::allReduce(torch::Tensor& input_tensor, torch::Tensor&
4960 }
5061}
5162
63+ void CustomAllReduceComm::allGather (torch::Tensor& input_tensor, torch::Tensor& output_tensor) {
64+ if (at::hip::currentStreamCaptureStatusMayInitCtx () != at::hip::CaptureStatus::None) {
65+ aiter::all_gather_reg (fa_, input_tensor, output_tensor);
66+ } else {
67+ aiter::all_gather_unreg (fa_, input_tensor, buffer_, output_tensor);
68+ }
69+ }
70+
5271void CustomAllReduceComm::registerGraphBuffers () {
5372 auto handle_and_offset = aiter::get_graph_buffer_ipc_meta (fa_); // tuple<tensor, vector<int64_t>> -> vector<tensor> size=2
5473 auto handle = std::get<0 >(handle_and_offset);
@@ -144,7 +163,7 @@ CustomAllReduceComm::prepareP2PBuffer_(const NcclParam& nccl_para, torch::Tensor
144163 return handles;
145164}
146165
147- bool CustomAllReduceComm::shouldCustomAR (const std::vector<size_t >& tp_ranks, size_t rank) {
166+ bool CustomAllReduceComm::shouldCustomAR (const std::vector<size_t >& tp_ranks, size_t rank, const HWKernelConfig& hw_kernel_config ) {
148167 size_t world_size = tp_ranks.size ();
149168 size_t local_world_size = rocm::getDeviceCount ();
150169
@@ -158,9 +177,7 @@ bool CustomAllReduceComm::shouldCustomAR(const std::vector<size_t>& tp_ranks, si
158177 }
159178
160179 // 2. check whether disabled flag is set
161- char * disable_custom_ar_str = std::getenv (" FT_DISABLE_CUSTOM_AR" );
162- bool disable_custom_ar = disable_custom_ar_str != nullptr && std::string (disable_custom_ar_str) == " 1" ;
163- if (disable_custom_ar) {
180+ if (hw_kernel_config.ft_disable_custom_ar ) {
164181 RTP_LLM_LOG_INFO (" Disable custom ar since FT_DISABLE_CUSTOM_AR is set" );
165182 return false ;
166183 }
@@ -186,7 +203,7 @@ size_t CustomAllReduceComm::getCommBufThreshold() {
186203}
187204
188205std::unique_ptr<CustomAllReduceComm>
189- initCustomAllReduceComm (const NcclParam& nccl_para, const std::vector<size_t >& tp_ranks, hipStream_t stream) {
206+ initCustomAllReduceComm (const NcclParam& nccl_para, const std::vector<size_t >& tp_ranks, hipStream_t stream, const HWKernelConfig& hw_kernel_config ) {
190207 size_t rank_index = 0 ;
191208 for (size_t i = 0 ; i < tp_ranks.size (); i++) {
192209 if (tp_ranks[i] == nccl_para.rank_ ) {
@@ -195,11 +212,11 @@ initCustomAllReduceComm(const NcclParam& nccl_para, const std::vector<size_t>& t
195212 }
196213 }
197214
198- if (!CustomAllReduceComm::shouldCustomAR (tp_ranks, nccl_para.rank_ )) {
215+ if (!CustomAllReduceComm::shouldCustomAR (tp_ranks, nccl_para.rank_ , hw_kernel_config )) {
199216 return nullptr ;
200217 }
201218
202- auto comm = std::make_unique<CustomAllReduceComm>(tp_ranks, nccl_para.rank_ , rank_index);
219+ auto comm = std::make_unique<CustomAllReduceComm>(tp_ranks, nccl_para.rank_ , rank_index, hw_kernel_config );
203220 comm->init (nccl_para, stream);
204221 RTP_LLM_LOG_INFO (" Custom all reduce is enabled on rank %d of %d" , nccl_para.rank_ , tp_ranks.size ());
205222 return comm;
0 commit comments