Skip to content

Commit 4dc426f

Browse files
authored
[cherry-pick] Refactor Heterogeneous Pipeline Parameter Server (#37446)
* bug fix for DeserializeSelectedRows. test=develop (#36520) * fix SerializeSelectedRows (#36543) * bug fix for DeserializeSelectedRows. test=develop * fix bug for SerializeSelectedRows. test=develop * update. test=develop * [Heterps]Refactor Heter Pipeline Parameter Server (#36845) * change username * fix * fix * fix * fix * fix * update * update * update unittests * fix * update * fix * update * fix * fix * fix * update * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update send_and_recv op. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix ut. test=develop * fix unit. notest,test=coverage * fix ut. notest, test=coverage * update. notest,test=coverage * fix ut. notest, test=coverage * fix ut. notest, test=coverage * fix. notest, test=coverage * fix. notest, test=coverage * fix ut. notest, test=coverage * fix ut. notest, test=coverage * fix ut. notest, test=coverage * fix ut. notest, test=coverage * add func. notest, test=coverage * fix ut. notest, test=coverage * fix. test=develop * fix. test=develop * Fix unit test for send_and_recv_cpu & send_and_recv_gpu (#37129) * [heterps]fix ut for heter_pipeline_trainer.cc (#37136) * fix ut. test=develop * fix ut. 
test=develop * [heterps]bug fix for local training with --heter_worker_num (#37166) * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix ut. test=develop * fix ut. test=develop * fix ut. test=develop * [heterps]Refactor heterogenous worker (#37244) * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix ut. test=develop * fix ut. test=develop * fix ut. test=develop * refactor heter trainer. test=develop * fix. test=develop * fix ut. test=develop * fix ut. test=develop * fix ut. test=develop * fix ut. test=develop * fix ut. test=develop * fix ut. test=develop * fix ut. test=develop * fix ut. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix ut. test=develop * fix ut. test=develop * fix ut. test=develop * [heterps]add heterps mode judgement (#37298) * [heterps]change default executor for heter trainer (#37314) * fix pslib. test=develop * add device to train_from_dataset. test=develop * refine fleet.stop_worker. test=develop * fix ut. test=develop * fix ut. test=develop * fix executor & ut. test=develop * fix executor & ut. test=develop * fix executor & ut. test=develop * [heterps]remove api for heter pipeline ps (#37396) * fix api. test=develop * fix api. test=develop * fix code style. test=release/2.2 * fix CMakeLists. test=develop (#37454)
1 parent 436808c commit 4dc426f

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

51 files changed

+4106
-767
lines changed

paddle/fluid/distributed/service/brpc_utils.cc

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -138,23 +138,11 @@ void SerializeSelectedRows(framework::Variable* var,
138138
var_data->clear();
139139
var_data->resize(rows->size() * sizeof(int64_t));
140140
char* data_ptr = const_cast<char*>(var_data->data());
141-
142-
if (platform::is_cpu_place(tensor->place())) {
143-
memcpy(data_ptr, &(*rows)[0], rows->size() * sizeof(int64_t));
144-
} else {
145-
#ifdef PADDLE_WITH_CUDA
146-
auto stream =
147-
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
148-
memory::Copy(platform::CPUPlace(), data_ptr,
149-
BOOST_GET_CONST(platform::CUDAPlace, tensor->place()),
150-
&(*rows)[0], rows->size() * sizeof(int64_t), stream);
151-
#endif
152-
}
141+
memcpy(data_ptr, &((*rows)[0]), rows->size() * sizeof(int64_t));
153142
var_msg->set_data_type(static_cast<VarMsg::Type>(tensor->type()));
154143
for (auto& dim : framework::vectorize(tensor->dims())) {
155144
var_msg->add_dims(dim);
156145
}
157-
158146
// IO Buffer
159147
if (platform::is_cpu_place(tensor->place())) {
160148
auto data_len = tensor->numel() * framework::SizeOfType(tensor->type());
@@ -273,8 +261,8 @@ void DeserializeSelectedRows(framework::Variable* var, const VarMsg& msg,
273261
auto* slr = var->GetMutable<framework::SelectedRows>();
274262
framework::Tensor* tensor = slr->mutable_value();
275263
slr->set_height(msg.slr_height());
276-
std::vector<int64_t> tmp_rows(msg.slr_height());
277-
memcpy(&tmp_rows[0], msg.data().data(), msg.slr_height() * sizeof(int64_t));
264+
std::vector<int64_t> tmp_rows(msg.dims()[0]);
265+
memcpy(tmp_rows.data(), msg.data().data(), msg.dims()[0] * sizeof(int64_t));
278266
slr->set_rows(tmp_rows);
279267
std::vector<int> vec_dim;
280268
for (auto& x : msg.dims()) {

paddle/fluid/distributed/service/communicator.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/distributed/service/communicator.h"
16-
1716
#include <google/protobuf/text_format.h>
1817

1918
#include "gflags/gflags.h"
@@ -361,6 +360,8 @@ void Communicator::InitParams(const RecvCtxMap &recv_varname_to_ctx) {
361360
<< " from 0' trainer done";
362361
}
363362
}
363+
std::this_thread::sleep_for(
364+
std::chrono::milliseconds(100 + trainer_id_ * 10));
364365
BarrierWithTable(1);
365366
return;
366367
}
@@ -518,7 +519,6 @@ void AsyncCommunicator::SendByCommunicator() {
518519
MergeVars<float>(var_name, vars[i], send_scope_.get(), 1);
519520
}
520521
}
521-
522522
if (ctx.is_tensor_table) {
523523
SendGlobalStep(ctx, merged_var_num, send_scope_.get());
524524
} else if (ctx.is_sparse) {

paddle/fluid/distributed/service/heter_client.cc

Lines changed: 78 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,36 @@ namespace distributed {
2525
std::shared_ptr<HeterClient> HeterClient::s_instance_ = NULL;
2626
bool HeterClient::is_initialized_ = false;
2727

28+
int GetMicroId(const platform::DeviceContext& ctx,
29+
const framework::Scope* scope) {
30+
framework::Variable* var = scope->FindVar("microbatch_id");
31+
PADDLE_ENFORCE_EQ(var->IsType<framework::LoDTensor>(), true,
32+
platform::errors::InvalidArgument(
33+
"the type of micro id shoulde be LoDTensor."));
34+
auto micro_id = -1;
35+
auto* tensor = var->GetMutable<framework::LoDTensor>();
36+
if (platform::is_cpu_place(tensor->place())) {
37+
auto data = reinterpret_cast<const float*>(tensor->data<void>());
38+
micro_id = static_cast<int>(data[0]);
39+
} else {
40+
#ifdef PADDLE_WITH_CUDA
41+
std::vector<char> temp;
42+
temp.resize(tensor->numel() * framework::SizeOfType(tensor->type()));
43+
char* temp_ptr = temp.data();
44+
auto stream =
45+
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
46+
memory::Copy(platform::CPUPlace(), temp_ptr,
47+
BOOST_GET_CONST(platform::CUDAPlace, tensor->place()),
48+
tensor->data<void>(),
49+
tensor->numel() * framework::SizeOfType(tensor->type()),
50+
stream);
51+
float* temp_ptr_float = reinterpret_cast<float*>(temp_ptr);
52+
micro_id = static_cast<int>(temp_ptr_float[0]);
53+
#endif
54+
}
55+
return micro_id;
56+
}
57+
2858
void HeterClient::MainThread() {
2959
while (running_) {
3060
RpcProfilerControl();
@@ -99,43 +129,68 @@ void HeterClient::CreateClient2XpuConnection() {
99129
}
100130
}
101131
}
132+
previous_xpu_channels_.resize(previous_xpu_list_.size());
133+
for (size_t i = 0; i < previous_xpu_list_.size(); ++i) {
134+
previous_xpu_channels_[i].reset(new brpc::Channel());
135+
if (previous_xpu_channels_[i]->Init(previous_xpu_list_[i].c_str(), "",
136+
&options) != 0) {
137+
VLOG(0) << "HeterClient channel init fail. Try Again";
138+
auto ip_port = paddle::string::Split(previous_xpu_list_[i], ':');
139+
std::string ip = ip_port[0];
140+
int port = std::stoi(ip_port[1]);
141+
std::string int_ip_port = GetIntTypeEndpoint(ip, port);
142+
if (previous_xpu_channels_[i]->Init(int_ip_port.c_str(), "", &options) !=
143+
0) {
144+
LOG(ERROR) << "BrpcPsServer start failed, ip_port= " << int_ip_port;
145+
}
146+
}
147+
}
102148
}
103149

104150
void HeterClient::SendAndRecvAsync(
105-
const std::vector<std::string>& ep, const platform::DeviceContext& ctx,
106-
const framework::Scope& scope, const std::string& message_name,
151+
const platform::DeviceContext& ctx, const framework::Scope& scope,
152+
const std::string& message_name,
107153
const std::vector<std::string>& send_var_name,
108-
const std::vector<std::string>& recv_var_name) {
154+
const std::vector<std::string>& recv_var_name, const std::string& mode) {
109155
platform::RecordEvent record_event("HeterClient->SendAndRecvAsync");
110156
const platform::DeviceContext* p_ctx = &ctx;
111157
const framework::Scope* p_scope = &scope;
112158
const std::string message_name_val = message_name;
113159
const std::vector<std::string> send_var_name_val = send_var_name;
114160
const std::vector<std::string> recv_var_name_val = recv_var_name;
115-
116-
VLOG(3) << "GRPCClient::SendAndRecv Begin, message_name: "
161+
VLOG(3) << "BRPCClient::SendAndRecv Begin, message_name: "
117162
<< message_name_val;
118-
// Todo: get correct channel
119-
int num = trainer_id_ % xpu_channels_.size();
120-
121-
brpc::Controller cntl;
122-
cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
123-
distributed::MultiVarMsg request, response;
124-
auto& request_io_buffer = cntl.request_attachment();
125-
::paddle::distributed::PsService_Stub stub(xpu_channels_[num].get());
163+
brpc::Channel* channel = nullptr;
164+
distributed::MultiVarMsg request;
165+
OnHeterRpcDone* closure = new OnHeterRpcDone([p_ctx, p_scope](void* done) {
166+
auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
167+
PADDLE_ENFORCE_NE(
168+
closure->cntl.Failed(), true,
169+
platform::errors::Unimplemented(
170+
"HeterClient::SendAndRecv meets brpc error, error message is %s",
171+
closure->cntl.ErrorText()));
172+
173+
VLOG(4) << "call heter_worker success";
174+
});
175+
closure->cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
176+
auto& request_io_buffer = closure->cntl.request_attachment();
126177
distributed::SerializeToMultiVarMsgAndIOBuf(
127178
message_name_val, send_var_name_val, recv_var_name_val, *p_ctx, p_scope,
128179
&request, &request_io_buffer);
129-
stub.SendAndRecvVariable(&cntl, &request, &response, NULL);
130-
PADDLE_ENFORCE_NE(
131-
cntl.Failed(), true,
132-
platform::errors::Unimplemented(
133-
"HeterClient::SendAndRecv meets brpc error, error message is %s",
134-
cntl.ErrorText()));
135-
VLOG(4) << "call heter_worker success";
136-
auto& response_io_buffer = cntl.response_attachment();
137-
distributed::DeserializeFromMultiVarMsgAndIOBuf(response, &response_io_buffer,
138-
ctx, p_scope);
180+
181+
int micro_id = GetMicroId(ctx, p_scope);
182+
auto minibatch_id = micro_id / 10;
183+
// select channel according to micro id
184+
if (mode == "forward") {
185+
int num = minibatch_id % xpu_channels_.size();
186+
channel = xpu_channels_[num].get();
187+
} else if (mode == "backward") {
188+
int num = minibatch_id % previous_xpu_channels_.size();
189+
channel = previous_xpu_channels_[num].get();
190+
}
191+
::paddle::distributed::PsService_Stub stub(channel);
192+
stub.SendAndRecvVariable(&closure->cntl, &request, &closure->response,
193+
closure);
139194
}
140195

141196
std::future<int32_t> HeterClient::SendCmd(

paddle/fluid/distributed/service/heter_client.h

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,20 +76,23 @@ class HeterClient {
7676

7777
void CreateClient2XpuConnection();
7878

79-
void SendAndRecvAsync(const std::vector<std::string>& ep,
80-
const platform::DeviceContext& ctx,
79+
void SendAndRecvAsync(const platform::DeviceContext& ctx,
8180
const framework::Scope& scope,
8281
const std::string& message_name,
8382
const std::vector<std::string>& send_var_name,
84-
const std::vector<std::string>& recv_var_name);
83+
const std::vector<std::string>& recv_var_name,
84+
const std::string& mode = "forward");
8585

8686
// HeterClient singleton
8787
static std::shared_ptr<HeterClient> GetInstance(
88-
const std::vector<std::string>& endpoint, const int& trainer_id) {
88+
const std::vector<std::string>& endpoint,
89+
const std::vector<std::string>& previous_endpoint,
90+
const int& trainer_id) {
8991
if (NULL == s_instance_) {
9092
is_initialized_ = true;
9193
s_instance_.reset(new paddle::distributed::HeterClient());
9294
s_instance_->SetXpuList(endpoint);
95+
s_instance_->SetPreviousXpuList(previous_endpoint);
9396
s_instance_->SetTrainerID(trainer_id);
9497
s_instance_->CreateClient2XpuConnection();
9598
}
@@ -118,16 +121,22 @@ class HeterClient {
118121
xpu_list_ = xpu_list;
119122
}
120123

124+
void SetPreviousXpuList(const std::vector<std::string>& xpu_list) {
125+
previous_xpu_list_ = xpu_list;
126+
}
127+
121128
void SetTrainerID(const int& trainer_id) { trainer_id_ = trainer_id; }
122129

123130
private:
124131
static std::shared_ptr<HeterClient> s_instance_;
125132
static bool is_initialized_;
126133
std::unique_ptr<std::thread> main_thread_{nullptr};
127134
std::vector<std::shared_ptr<brpc::Channel>> xpu_channels_;
135+
std::vector<std::shared_ptr<brpc::Channel>> previous_xpu_channels_;
128136

129137
DISABLE_COPY_AND_ASSIGN(HeterClient);
130138
std::vector<std::string> xpu_list_;
139+
std::vector<std::string> previous_xpu_list_;
131140

132141
bool running_ = false;
133142
int trainer_id_;

paddle/fluid/distributed/service/heter_server.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,20 +46,20 @@ void HeterServer::StartHeterService() {
4646
ready_ = 1;
4747
}
4848
condition_ready_.notify_all();
49-
5049
std::unique_lock<std::mutex> running_lock(mutex_);
50+
stoped_ = false;
5151
cv_.wait(running_lock, [&] {
5252
VLOG(1) << "Heter Server is Stop? " << stoped_;
5353
return stoped_;
5454
});
5555
}
5656

57-
void HeterServer::SetEndPoint(std::string& endpoint) {
57+
void HeterServer::SetEndPoint(const std::string& endpoint) {
5858
endpoint_ = endpoint;
5959
service_.SetEndpoint(endpoint);
6060
}
6161

62-
void HeterServer::SetFanin(int& fan_in) { service_.SetFanin(fan_in); }
62+
void HeterServer::SetFanin(const int& fan_in) { service_.SetFanin(fan_in); }
6363

6464
void HeterServer::WaitServerReady() {
6565
std::unique_lock<std::mutex> lock(this->mutex_ready_);

0 commit comments

Comments
 (0)