
Commit 0adcd45

Merge branch 'main' into feature/adopt-ds-v3.2-encode
2 parents: 3e5ef4b + 80e5be6

File tree: 16 files changed (+347, -31 lines)

rtp_llm/cpp/config/ConfigModules.cc

Lines changed: 2 additions & 1 deletion

@@ -152,7 +152,8 @@ std::string HWKernelConfig::to_string() const {
         << "num_native_cuda_graph: " << num_native_cuda_graph << "\n"
         << "prefill_capture_seq_lens size: " << prefill_capture_seq_lens.size() << "\n"
         << "decode_capture_batch_sizes size: " << decode_capture_batch_sizes.size() << "\n"
-        << "disable_dpc_random: " << disable_dpc_random;
+        << "disable_dpc_random: " << disable_dpc_random << "\n"
+        << "rocm_disable_custom_ag: " << rocm_disable_custom_ag;
     return oss.str();
 }

rtp_llm/cpp/config/ConfigModules.h

Lines changed: 1 addition & 0 deletions

@@ -153,6 +153,7 @@ struct HWKernelConfig {
     // Comma-separated list of batch sizes, e.g., "1,2,4,8,16,32"
     std::vector<int> decode_capture_batch_sizes;
     bool disable_dpc_random = false;
+    bool rocm_disable_custom_ag = true;
     std::string to_string() const;
 };
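The flag defaults to true, so the custom all-gather path stays off unless a deployment opts in. The log message in custom_ar_comm.cc below refers to a ROCM_DISABLE_CUSTOM_AG environment variable, which suggests the config field is populated from the environment elsewhere in the stack; a minimal sketch of that kind of parsing, assuming a hypothetical boolFromEnv helper that is not part of this commit:

#include <cstdlib>
#include <cstring>

// Hypothetical helper: read a "0"/"1" flag from the environment,
// keeping the compiled-in default when the variable is unset or empty.
static bool boolFromEnv(const char* name, bool default_value) {
    const char* value = std::getenv(name);
    if (value == nullptr || std::strlen(value) == 0) {
        return default_value;
    }
    return std::strcmp(value, "1") == 0;
}

// Sketch of populating the new field (an assumption, not shown in this diff):
// hw_kernel_config.rocm_disable_custom_ag = boolFromEnv("ROCM_DISABLE_CUSTOM_AG", true);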

rtp_llm/cpp/devices/rocm_impl/ROCmDevice.cc

Lines changed: 1 addition & 1 deletion

@@ -71,7 +71,7 @@ ROCmDevice::ROCmDevice(const DeviceInitParams& params): DeviceBase(params) {
     auto&               nccl_param = tp_nccl_param_;
     std::vector<size_t> tp_ranks   = fcNcclGatherRanks(nccl_param, stream_);
     // Initialization may fail, and the variable will still be nullptr. When allreduce is called, it will fall back to the normal allreduce.
-    custom_allreduce_comm_ = initCustomAllReduceComm(nccl_param, tp_ranks, stream_);
+    custom_allreduce_comm_ = initCustomAllReduceComm(nccl_param, tp_ranks, stream_, params.hw_kernel_config);
     quick_allreduce_comm_ = initQuickAllReduceComm(nccl_param, tp_ranks, stream_);
 }
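The comment kept in this hunk states the contract: initCustomAllReduceComm may legitimately return nullptr, and the all-reduce call site then falls back to the plain NCCL path. A self-contained toy of that optional-fast-path pattern (FastComm and tryInitFastComm are illustrative names, not the real types):

#include <iostream>
#include <memory>

// Toy stand-in for CustomAllReduceComm: construction can "fail" and the
// caller must treat nullptr as "use the generic path".
struct FastComm {
    void allReduce() { std::cout << "custom all-reduce\n"; }
};

std::unique_ptr<FastComm> tryInitFastComm(bool init_ok) {
    return init_ok ? std::make_unique<FastComm>() : nullptr;
}

int main() {
    auto comm = tryInitFastComm(false);  // initialization failed or was disabled
    if (comm) {
        comm->allReduce();
    } else {
        std::cout << "fallback: plain NCCL all-reduce\n";  // the documented fallback
    }
}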

rtp_llm/cpp/devices/rocm_impl/ROCmDistributedOp.cc

Lines changed: 25 additions & 0 deletions

@@ -162,6 +162,31 @@ void ROCmDevice::allGather(const AllGatherParams& params) {
                      "Buffer size %ld must be divisible by world size %d",
                      recv_buffer->size(),
                      nccl_param.world_size_);
+
+        // invoke aiter custom all-gather
+        // custom all-gather is integrated into custom all-reduce
+        bool use_custom_ag =
+            params.mode == ParallelMode::TP
+            and custom_allreduce_comm_
+            and custom_allreduce_comm_->checkAllGatherAvailable();
+
+        if (use_custom_ag) {
+            torch::Tensor input_tensor;
+
+            if (params.inplace) {
+                auto option_ = torch::dtype(dataTypeToTorchType(recv_buffer->type())).device(memoryTypeToTorchDevice(recv_buffer->where())).requires_grad(false);
+                std::vector<int64_t> shape_{static_cast<int64_t>(data_num)};
+                input_tensor = torch::from_blob(recv_buffer->dataWithOffset(nccl_param.rank_ * data_num), shape_, option_);
+            } else {
+                input_tensor = Buffer2torchTensor(*(params.send_buffers[i]), false);
+            }
+            torch::Tensor output_tensor = Buffer2torchTensor(*recv_buffer, false);
+
+            custom_allreduce_comm_->allGather(input_tensor, output_tensor);
+
+            continue;
+        }
+
         if (params.inplace) {
             const auto data_size = data_num * recv_buffer->typeSize();
             NCCLCHECK(ncclAllGather((char*)(recv_buffer->data()) + nccl_param.rank_ * data_size,
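In the inplace branch above, the sender's slice of recv_buffer is wrapped as a tensor with torch::from_blob, which views existing memory without copying or taking ownership; dataWithOffset supplies the pointer to this rank's segment. A self-contained sketch of the same zero-copy aliasing, on plain CPU floats (the diff builds the options from the buffer's real dtype and device):

#include <torch/torch.h>
#include <iostream>
#include <vector>

int main() {
    // Flat buffer standing in for recv_buffer: world_size=2, data_num=4.
    std::vector<float> recv(8, 0.0f);
    const int64_t rank = 1, data_num = 4;

    // View this rank's slice without copying, as the inplace branch does.
    auto options = torch::dtype(torch::kFloat32).requires_grad(false);
    torch::Tensor slice =
        torch::from_blob(recv.data() + rank * data_num, {data_num}, options);

    slice.fill_(3.0f);             // writes land in the underlying vector
    std::cout << recv[4] << "\n";  // prints 3: the tensor aliases recv[4..7]
}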

rtp_llm/cpp/model_rpc/PrefillRpcServer.cc

Lines changed: 4 additions & 7 deletions

@@ -131,8 +131,7 @@ void PrefillRpcServer::getRpcConnection(PrefillGenerateContext& prefill_context)

     // If no host specified in request, check if there's a master role
     char* remote_rpc_server_ip_env = std::getenv("REMOTE_RPC_SERVER_IP");
-    bool  has_master_role =
-        (remote_rpc_server_ip_env != nullptr && strlen(remote_rpc_server_ip_env) > 0);
+    bool  has_master_role = (remote_rpc_server_ip_env != nullptr && strlen(remote_rpc_server_ip_env) > 0);

     // If no host specified in request and no master role, this is a direct prefill request
     // In this case, we still need to select decode machines as specified in the requirements

@@ -293,7 +292,6 @@ void PrefillRpcServer::remoteGenerate(PrefillGenerateContext& prefill_context) {
     generate_request.mutable_propose_token_ids()->CopyFrom(
         {stream->getProposeToken().begin(), stream->getProposeToken().end()});

-    // TODO(yinzhi): trans propose probs and hidden states
     auto sp_output_buffer = stream->getSPOutputBuffer();

     if (sp_output_buffer) {

@@ -407,11 +405,10 @@ grpc::Status PrefillRpcServer::GenerateStreamCall(grpc::ServerContext*
                                            meta_);
     prefill_context.onflight_requests      = onflight_requests_;
     prefill_context.loading_cache_requests = loading_cache_requests_;
-

-    auto max_retry_times       = maga_init_params_.pd_sep_config.prefill_retry_times;
-    auto max_retry_timeout_ms  = maga_init_params_.pd_sep_config.prefill_retry_timeout_ms;
-    int  retry_interval_ms     = 1;
+    auto max_retry_times      = maga_init_params_.pd_sep_config.prefill_retry_times;
+    auto max_retry_timeout_ms = maga_init_params_.pd_sep_config.prefill_retry_timeout_ms;
+    int  retry_interval_ms    = 1;

     try {
         EXECUTE_WITH_RETRY(
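The three variables feed EXECUTE_WITH_RETRY, whose definition is not part of this diff; presumably it re-invokes the prefill up to prefill_retry_times attempts within a prefill_retry_timeout_ms budget, sleeping retry_interval_ms between attempts. A generic sketch of such a bounded retry loop, under those assumptions only:

#include <chrono>
#include <functional>
#include <thread>

// Assumed semantics only: the real EXECUTE_WITH_RETRY macro is not shown here.
bool retryWithTimeout(const std::function<bool()>& attempt,
                      int max_retry_times,
                      int max_retry_timeout_ms,
                      int retry_interval_ms) {
    const auto deadline = std::chrono::steady_clock::now()
                          + std::chrono::milliseconds(max_retry_timeout_ms);
    for (int i = 0; i < max_retry_times; ++i) {
        if (attempt()) {
            return true;  // success, stop retrying
        }
        if (std::chrono::steady_clock::now() >= deadline) {
            break;        // time budget exhausted
        }
        std::this_thread::sleep_for(std::chrono::milliseconds(retry_interval_ms));
    }
    return false;
}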

rtp_llm/cpp/pybind/ConfigInit.cc

Lines changed: 5 additions & 2 deletions

@@ -411,6 +411,7 @@ PYBIND11_MODULE(libth_transformer_config, m) {
         .def_readwrite("prefill_capture_seq_lens", &HWKernelConfig::prefill_capture_seq_lens)
         .def_readwrite("decode_capture_batch_sizes", &HWKernelConfig::decode_capture_batch_sizes)
         .def_readwrite("disable_dpc_random", &HWKernelConfig::disable_dpc_random)
+        .def_readwrite("rocm_disable_custom_ag", &HWKernelConfig::rocm_disable_custom_ag)
         .def("to_string", &HWKernelConfig::to_string)
         .def(py::pickle(
             [](const HWKernelConfig& self) {

@@ -427,10 +428,11 @@ PYBIND11_MODULE(libth_transformer_config, m) {
                     self.num_native_cuda_graph,
                     self.prefill_capture_seq_lens,
                     self.decode_capture_batch_sizes,
-                    self.disable_dpc_random);
+                    self.disable_dpc_random,
+                    self.rocm_disable_custom_ag);
             },
             [](py::tuple t) {
-                if (t.size() != 14)
+                if (t.size() != 15)
                     throw std::runtime_error("Invalid state!");
                 HWKernelConfig c;
                 try {

@@ -448,6 +450,7 @@ PYBIND11_MODULE(libth_transformer_config, m) {
                     c.prefill_capture_seq_lens = t[11].cast<std::vector<int>>();
                     c.decode_capture_batch_sizes = t[12].cast<std::vector<int>>();
                     c.disable_dpc_random = t[13].cast<bool>();
+                    c.rocm_disable_custom_ag = t[14].cast<bool>();
                 } catch (const std::exception& e) {
                     throw std::runtime_error(std::string("HWKernelConfig unpickle error: ") + e.what());
                 }
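Extending a pickled struct means touching three places in lockstep: the def_readwrite binding, the __getstate__ tuple, and the __setstate__ size check plus cast. A minimal self-contained py::pickle example of the same pattern (toy Config type, not the real HWKernelConfig binding):

#include <pybind11/pybind11.h>
#include <stdexcept>

namespace py = pybind11;

struct Config {
    bool flag_a = false;
    bool flag_b = true;  // newly added field -> one more tuple slot below
};

PYBIND11_MODULE(toy_config, m) {
    py::class_<Config>(m, "Config")
        .def(py::init<>())
        .def_readwrite("flag_a", &Config::flag_a)
        .def_readwrite("flag_b", &Config::flag_b)
        .def(py::pickle(
            // __getstate__: one slot per field, in a fixed order.
            [](const Config& self) { return py::make_tuple(self.flag_a, self.flag_b); },
            // __setstate__: the size check must track the tuple above.
            [](py::tuple t) {
                if (t.size() != 2)
                    throw std::runtime_error("Invalid state!");
                Config c;
                c.flag_a = t[0].cast<bool>();
                c.flag_b = t[1].cast<bool>();
                return c;
            }));
}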

rtp_llm/cpp/rocm/custom_ar/custom_ar_comm.cc

Lines changed: 26 additions & 9 deletions

@@ -14,13 +14,15 @@ using namespace std;

 namespace rtp_llm {

-CustomAllReduceComm::CustomAllReduceComm(const std::vector<size_t>& tp_ranks, size_t rank, size_t rank_index):
+CustomAllReduceComm::CustomAllReduceComm(const std::vector<size_t>& tp_ranks, size_t rank, size_t rank_index, const HWKernelConfig& hw_kernel_config):
     rank_(rank),
     rank_index_(rank_index),
     world_size_(tp_ranks.size()),
     support_nv_link_(true),  // TODO(liyangcheng.lyc): add check function
     comm_buf_threshold_(getCommBufThreshold()),
-    tp_ranks_(std::move(tp_ranks)) {}
+    tp_ranks_(std::move(tp_ranks)),
+    ft_disable_custom_ar_(hw_kernel_config.ft_disable_custom_ar),
+    rocm_disable_custom_ag_(hw_kernel_config.rocm_disable_custom_ag) {}

 CustomAllReduceComm::~CustomAllReduceComm() {
     aiter::dispose(fa_);

@@ -41,6 +43,15 @@ bool CustomAllReduceComm::checkAllReduceAvailable(size_t elts_total_num, DataTyp
     return false;
 }

+bool CustomAllReduceComm::checkAllGatherAvailable() {
+    if (rocm_disable_custom_ag_) {
+        RTP_LLM_LOG_INFO("Disable custom ag since ROCM_DISABLE_CUSTOM_AG is set");
+        return false;
+    }
+
+    return true;
+}
+
 void CustomAllReduceComm::allReduce(torch::Tensor& input_tensor, torch::Tensor& output_tensor) {
     if (at::hip::currentStreamCaptureStatusMayInitCtx() != at::hip::CaptureStatus::None) {
         aiter::all_reduce(fa_, input_tensor, output_tensor, false, std::nullopt);

@@ -49,6 +60,14 @@ void CustomAllReduceComm::allReduce(torch::Tensor& input_tensor, torch::Tensor&
     }
 }

+void CustomAllReduceComm::allGather(torch::Tensor& input_tensor, torch::Tensor& output_tensor) {
+    if (at::hip::currentStreamCaptureStatusMayInitCtx() != at::hip::CaptureStatus::None) {
+        aiter::all_gather_reg(fa_, input_tensor, output_tensor);
+    } else {
+        aiter::all_gather_unreg(fa_, input_tensor, buffer_, output_tensor);
+    }
+}
+
 void CustomAllReduceComm::registerGraphBuffers() {
     auto handle_and_offset = aiter::get_graph_buffer_ipc_meta(fa_);  // tuple<tensor, vector<int64_t>> -> vector<tensor> size=2
     auto handle = std::get<0>(handle_and_offset);

@@ -144,7 +163,7 @@ CustomAllReduceComm::prepareP2PBuffer_(const NcclParam& nccl_para, torch::Tensor
     return handles;
 }

-bool CustomAllReduceComm::shouldCustomAR(const std::vector<size_t>& tp_ranks, size_t rank) {
+bool CustomAllReduceComm::shouldCustomAR(const std::vector<size_t>& tp_ranks, size_t rank, const HWKernelConfig& hw_kernel_config) {
     size_t world_size       = tp_ranks.size();
     size_t local_world_size = rocm::getDeviceCount();

@@ -158,9 +177,7 @@ bool CustomAllReduceComm::shouldCustomAR(const std::vector<size_t>& tp_ranks, si
     }

     // 2. check whether disabled flag is set
-    char* disable_custom_ar_str = std::getenv("FT_DISABLE_CUSTOM_AR");
-    bool  disable_custom_ar     = disable_custom_ar_str != nullptr && std::string(disable_custom_ar_str) == "1";
-    if (disable_custom_ar) {
+    if (hw_kernel_config.ft_disable_custom_ar) {
         RTP_LLM_LOG_INFO("Disable custom ar since FT_DISABLE_CUSTOM_AR is set");
         return false;
     }

@@ -186,7 +203,7 @@ size_t CustomAllReduceComm::getCommBufThreshold() {
 }

 std::unique_ptr<CustomAllReduceComm>
-initCustomAllReduceComm(const NcclParam& nccl_para, const std::vector<size_t>& tp_ranks, hipStream_t stream) {
+initCustomAllReduceComm(const NcclParam& nccl_para, const std::vector<size_t>& tp_ranks, hipStream_t stream, const HWKernelConfig& hw_kernel_config) {
     size_t rank_index = 0;
     for (size_t i = 0; i < tp_ranks.size(); i++) {
         if (tp_ranks[i] == nccl_para.rank_) {

@@ -195,11 +212,11 @@ initCustomAllReduceComm(const NcclParam& nccl_para, const std::vector<size_t>& t
         }
     }

-    if (!CustomAllReduceComm::shouldCustomAR(tp_ranks, nccl_para.rank_)) {
+    if (!CustomAllReduceComm::shouldCustomAR(tp_ranks, nccl_para.rank_, hw_kernel_config)) {
         return nullptr;
     }

-    auto comm = std::make_unique<CustomAllReduceComm>(tp_ranks, nccl_para.rank_, rank_index);
+    auto comm = std::make_unique<CustomAllReduceComm>(tp_ranks, nccl_para.rank_, rank_index, hw_kernel_config);
     comm->init(nccl_para, stream);
     RTP_LLM_LOG_INFO("Custom all reduce is enabled on rank %d of %d", nccl_para.rank_, tp_ranks.size());
     return comm;
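Besides adding the all-gather gate, this file moves the FT_DISABLE_CUSTOM_AR decision from a getenv lookup inside shouldCustomAR to a flag carried in HWKernelConfig and cached once in the constructor. A self-contained toy of that read-once, config-driven gating (sketch types, not the real classes):

#include <iostream>

// Toy config mirroring the two flags wired through in this commit.
struct HWKernelConfigSketch {
    bool ft_disable_custom_ar   = false;
    bool rocm_disable_custom_ag = true;  // default matches the header: AG off
};

// The communicator copies the flags at construction instead of consulting
// the environment on every availability check.
class CommSketch {
public:
    explicit CommSketch(const HWKernelConfigSketch& cfg)
        : disable_ag_(cfg.rocm_disable_custom_ag) {}
    bool checkAllGatherAvailable() const { return !disable_ag_; }

private:
    bool disable_ag_;
};

int main() {
    HWKernelConfigSketch cfg;
    cfg.rocm_disable_custom_ag = false;  // deployment opts in to custom AG
    CommSketch comm(cfg);
    std::cout << (comm.checkAllGatherAvailable() ? "custom AG" : "NCCL AG") << "\n";
}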

rtp_llm/cpp/rocm/custom_ar/custom_ar_comm.h

Lines changed: 11 additions & 3 deletions

@@ -10,6 +10,7 @@
 #include "rtp_llm/cpp/core/Types.h"
 #include "rtp_llm/cpp/cuda/nccl/nccl_utils.h"
 #include "rtp_llm/cpp/utils/Logger.h"
+#include "rtp_llm/cpp/config/ConfigModules.h"

 // aiter custom all reduce kernel
 #include "custom_all_reduce.h"

@@ -18,17 +19,22 @@
 namespace rtp_llm {
 class CustomAllReduceComm {
 public:
-    CustomAllReduceComm(const std::vector<size_t>& tp_ranks, size_t rank, size_t rank_index);
+    CustomAllReduceComm(const std::vector<size_t>& tp_ranks, size_t rank, size_t rank_index, const HWKernelConfig& hw_kernel_config);

     ~CustomAllReduceComm();

     void init(const NcclParam& nccl_para, hipStream_t stream);

     void allReduce(torch::Tensor& input_tensor, torch::Tensor& output_tensor);

+    // NOTE(liyangcheng.lyc): the implementation of custom all gather is placed together with custom all reduce
+    void allGather(torch::Tensor& input_tensor, torch::Tensor& output_tensor);
+
     bool checkAllReduceAvailable(size_t elts_total_num, DataType data_type, size_t world_size);

-    static bool shouldCustomAR(const std::vector<size_t>& tp_ranks, size_t rank);
+    bool checkAllGatherAvailable();
+
+    static bool shouldCustomAR(const std::vector<size_t>& tp_ranks, size_t rank, const HWKernelConfig& hw_kernel_config);

     void registerGraphBuffers();

@@ -55,9 +61,11 @@ class CustomAllReduceComm {
     torch::Tensor rank_data_;
     int64_t       fa_;
     NcclParam     nccl_para_;
+    bool          ft_disable_custom_ar_   = true;
+    bool          rocm_disable_custom_ag_ = true;
 };

 std::unique_ptr<CustomAllReduceComm>
-initCustomAllReduceComm(const NcclParam& nccl_para, const std::vector<size_t>& tp_ranks, hipStream_t stream);
+initCustomAllReduceComm(const NcclParam& nccl_para, const std::vector<size_t>& tp_ranks, hipStream_t stream, const HWKernelConfig& hw_kernel_config);

 } // namespace rtp_llm

rtp_llm/cpp/speculative_engine/SpeculativeEngine.cc

Lines changed: 50 additions & 6 deletions

@@ -563,7 +563,12 @@ absl::Status SpeculativeEngine::prefillMtpStep(std::list<GenerateStreamPtr>& str

     RTP_LLM_LOG_DEBUG("update stream");
     for (GenerateStreamPtr& stream : streams) {
-        SpeculativeExecutorStreamOutputPtr score_output = stream->getScoreStream()->getSPOutputBuffer();
+        GenerateStreamPtr score_stream = stream->getScoreStream();
+        if (checkStopAndSetError(score_stream, stream)) {
+            continue;
+        }
+
+        SpeculativeExecutorStreamOutputPtr score_output = score_stream->getSPOutputBuffer();
         StreamUpdateInfo update_info{score_output->tokens,
                                      (int)1,
                                      nullptr,

@@ -586,10 +591,17 @@ absl::Status SpeculativeEngine::prefillMtpStep(std::list<GenerateStreamPtr>& str

     propose_begin_time_us = autil::TimeUtility::currentTimeInMicroSeconds();
     RTP_LLM_LOG_DEBUG("propose model prefill");
-    THROW_IF_STATUS_ERROR(propose_executor_->propose(streams, true));
+    THROW_IF_STATUS_ERROR(propose_executor_->propose(streams));

     for (const GenerateStreamPtr& stream : streams) {
-        BufferPtr propose_tokens = stream->getProposeStream()->getSPOutputBuffer()->tokens;
+        GenerateStreamPtr propose_stream = stream->getProposeStream();
+
+        // check propose stream status
+        if (checkStopAndSetError(propose_stream, stream)) {
+            continue;
+        }
+
+        BufferPtr propose_tokens = propose_stream->getSPOutputBuffer()->tokens;
         vector<int> propose_tokens_vec;
         for (int i = 0; i < propose_tokens->shape()[1]; ++i) {
             propose_tokens_vec.push_back(propose_tokens->data<int>()[i]);

@@ -606,8 +618,7 @@ absl::Status SpeculativeEngine::prefillMtpStep(std::list<GenerateStreamPtr>& str
             RTP_LLM_LOG_DEBUG("stream [%ld] set setNeedRemoteGenerate", stream->streamId());
             stream->setNeedRemoteGenerate(true);
         }
-        auto score_stream   = stream->getScoreStream();
-        auto propose_stream = stream->getProposeStream();
+        auto score_stream = stream->getScoreStream();
         if (score_stream) {
             score_stream->setLastHiddenStates(nullptr);
             score_stream->setSPOutputBuffer(nullptr);

@@ -735,16 +746,31 @@ absl::Status SpeculativeEngine::mtpStep(std::list<GenerateStreamPtr>& streams) {
         }
     }

+    // check propose stream status
+    for (const GenerateStreamPtr& stream : streams) {
+        GenerateStreamPtr propose_stream = stream->getProposeStream();
+        checkStopAndSetError(propose_stream, stream);
+    }
+
     // base model score propose new tokens.
     {
         RTP_LLM_LOG_DEBUG("score step");
        score_begin_time_us = autil::TimeUtility::currentTimeInMicroSeconds();
         THROW_IF_STATUS_ERROR(score_executor_->score(streams));

+        std::list<GenerateStreamPtr> sample_streams;
+        for (const GenerateStreamPtr& stream : streams) {
+            GenerateStreamPtr score_stream = stream->getScoreStream();
+            if (checkStopAndSetError(score_stream, stream)) {
+                continue;
+            }
+            sample_streams.emplace_back(stream);
+        }
+
         if (device_->getDeviceProperties().tp_rank == 0) {
             RTP_LLM_LOG_DEBUG("sample step");
             sampler_begin_time_us = autil::TimeUtility::currentTimeInMicroSeconds();
-            CHECK_AND_RETURN_REF(sampler_output, speculative_sampler_->sample(streams));
+            CHECK_AND_RETURN_REF(sampler_output, speculative_sampler_->sample(sample_streams));
             RTP_LLM_LOG_DEBUG("speculative sample done");

             metrics_.propose_token_num += sampler_output.propose_token_num;

@@ -806,4 +832,22 @@ KVCacheInfo SpeculativeEngine::getCacheStatusInfo(int64_t latest_version, bool n
     return resource_context_.cache_manager->getKVCacheInfo(latest_version, need_cache_keys);
 }

+bool SpeculativeEngine::checkStopAndSetError(const GenerateStreamPtr& check_stream,
+                                             const GenerateStreamPtr& target_stream) {
+    if (target_stream->stopped()) {
+        return true;
+    }
+
+    if (check_stream && check_stream->stopped()) {
+        ErrorInfo error_info = check_stream->statusInfo();
+        if (error_info.hasError()) {
+            target_stream->setStop(error_info.code(), error_info.ToString());
+            RTP_LLM_LOG_ERROR(
+                "stream [%ld] stopped with error: %s", target_stream->streamId(), error_info.ToString().c_str());
+        }
+        return true;
+    }
+    return false;
+}
+
 } // namespace rtp_llm
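checkStopAndSetError gives every propose/score call site one uniform rule: skip streams that are already stopped, and when a child (propose or score) stream died with an error, copy that error onto the parent before skipping it, so stopped streams never reach the sampler (hence the sample_streams filtering above). A self-contained toy of that propagation contract (sketch types, not the real GenerateStream):

#include <iostream>
#include <memory>
#include <string>

// Minimal stand-in for GenerateStream: a stop flag plus an error message.
struct StreamSketch {
    bool        stopped = false;
    std::string error;
    void setStop(const std::string& err) { stopped = true; error = err; }
};
using StreamPtr = std::shared_ptr<StreamSketch>;

// Mirrors the helper's contract: true means "skip this stream".
bool checkStopAndSetError(const StreamPtr& check_stream, const StreamPtr& target_stream) {
    if (target_stream->stopped) {
        return true;  // already stopped, nothing to do
    }
    if (check_stream && check_stream->stopped) {
        if (!check_stream->error.empty()) {
            target_stream->setStop(check_stream->error);  // propagate child error
        }
        return true;
    }
    return false;  // healthy: keep processing
}

int main() {
    auto parent = std::make_shared<StreamSketch>();
    auto child  = std::make_shared<StreamSketch>();
    child->setStop("propose step failed");  // hypothetical child failure
    if (checkStopAndSetError(child, parent)) {
        std::cout << "parent stopped with: " << parent->error << "\n";
    }
}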

rtp_llm/cpp/speculative_engine/SpeculativeEngine.h

Lines changed: 2 additions & 0 deletions

@@ -142,6 +142,8 @@ class SpeculativeEngine: public EngineBase {

     bool updateEplbConfig(const EPLBConfig& config) override;

+    bool checkStopAndSetError(const GenerateStreamPtr& check_stream, const GenerateStreamPtr& target_stream);
+
 private:
     kmonitor::MetricsReporterPtr                  metrics_reporter_ = nullptr;
     std::unique_ptr<ProposeModelEngineInitParams> propose_model_params_;
