@@ -25,6 +25,36 @@ namespace distributed {
25
25
std::shared_ptr<HeterClient> HeterClient::s_instance_ = NULL ;
26
26
bool HeterClient::is_initialized_ = false ;
27
27
28
+ int GetMicroId (const platform::DeviceContext& ctx,
29
+ const framework::Scope* scope) {
30
+ framework::Variable* var = scope->FindVar (" microbatch_id" );
31
+ PADDLE_ENFORCE_EQ (var->IsType <framework::LoDTensor>(), true ,
32
+ platform::errors::InvalidArgument (
33
+ " the type of micro id shoulde be LoDTensor." ));
34
+ auto micro_id = -1 ;
35
+ auto * tensor = var->GetMutable <framework::LoDTensor>();
36
+ if (platform::is_cpu_place (tensor->place ())) {
37
+ auto data = reinterpret_cast <const float *>(tensor->data <void >());
38
+ micro_id = static_cast <int >(data[0 ]);
39
+ } else {
40
+ #ifdef PADDLE_WITH_CUDA
41
+ std::vector<char > temp;
42
+ temp.resize (tensor->numel () * framework::SizeOfType (tensor->type ()));
43
+ char * temp_ptr = temp.data ();
44
+ auto stream =
45
+ reinterpret_cast <const platform::CUDADeviceContext&>(ctx).stream ();
46
+ memory::Copy (platform::CPUPlace (), temp_ptr,
47
+ BOOST_GET_CONST (platform::CUDAPlace, tensor->place ()),
48
+ tensor->data <void >(),
49
+ tensor->numel () * framework::SizeOfType (tensor->type ()),
50
+ stream);
51
+ float * temp_ptr_float = reinterpret_cast <float *>(temp_ptr);
52
+ micro_id = static_cast <int >(temp_ptr_float[0 ]);
53
+ #endif
54
+ }
55
+ return micro_id;
56
+ }
57
+
28
58
void HeterClient::MainThread () {
29
59
while (running_) {
30
60
RpcProfilerControl ();
@@ -99,43 +129,68 @@ void HeterClient::CreateClient2XpuConnection() {
99
129
}
100
130
}
101
131
}
132
+ previous_xpu_channels_.resize (previous_xpu_list_.size ());
133
+ for (size_t i = 0 ; i < previous_xpu_list_.size (); ++i) {
134
+ previous_xpu_channels_[i].reset (new brpc::Channel ());
135
+ if (previous_xpu_channels_[i]->Init (previous_xpu_list_[i].c_str (), " " ,
136
+ &options) != 0 ) {
137
+ VLOG (0 ) << " HeterClient channel init fail. Try Again" ;
138
+ auto ip_port = paddle::string::Split (previous_xpu_list_[i], ' :' );
139
+ std::string ip = ip_port[0 ];
140
+ int port = std::stoi (ip_port[1 ]);
141
+ std::string int_ip_port = GetIntTypeEndpoint (ip, port);
142
+ if (previous_xpu_channels_[i]->Init (int_ip_port.c_str (), " " , &options) !=
143
+ 0 ) {
144
+ LOG (ERROR) << " BrpcPsServer start failed, ip_port= " << int_ip_port;
145
+ }
146
+ }
147
+ }
102
148
}
103
149
104
150
void HeterClient::SendAndRecvAsync (
105
- const std::vector<std::string>& ep , const platform::DeviceContext& ctx ,
106
- const framework::Scope& scope, const std::string& message_name,
151
+ const platform::DeviceContext& ctx , const framework::Scope& scope ,
152
+ const std::string& message_name,
107
153
const std::vector<std::string>& send_var_name,
108
- const std::vector<std::string>& recv_var_name) {
154
+ const std::vector<std::string>& recv_var_name, const std::string& mode ) {
109
155
platform::RecordEvent record_event (" HeterClient->SendAndRecvAsync" );
110
156
const platform::DeviceContext* p_ctx = &ctx;
111
157
const framework::Scope* p_scope = &scope;
112
158
const std::string message_name_val = message_name;
113
159
const std::vector<std::string> send_var_name_val = send_var_name;
114
160
const std::vector<std::string> recv_var_name_val = recv_var_name;
115
-
116
- VLOG (3 ) << " GRPCClient::SendAndRecv Begin, message_name: "
161
+ VLOG (3 ) << " BRPCClient::SendAndRecv Begin, message_name: "
117
162
<< message_name_val;
118
- // Todo: get correct channel
119
- int num = trainer_id_ % xpu_channels_.size ();
120
-
121
- brpc::Controller cntl;
122
- cntl.set_timeout_ms (FLAGS_pserver_timeout_ms);
123
- distributed::MultiVarMsg request, response;
124
- auto & request_io_buffer = cntl.request_attachment ();
125
- ::paddle::distributed::PsService_Stub stub (xpu_channels_[num].get ());
163
+ brpc::Channel* channel = nullptr ;
164
+ distributed::MultiVarMsg request;
165
+ OnHeterRpcDone* closure = new OnHeterRpcDone ([p_ctx, p_scope](void * done) {
166
+ auto * closure = reinterpret_cast <OnHeterRpcDone*>(done);
167
+ PADDLE_ENFORCE_NE (
168
+ closure->cntl .Failed (), true ,
169
+ platform::errors::Unimplemented (
170
+ " HeterClient::SendAndRecv meets brpc error, error message is %s" ,
171
+ closure->cntl .ErrorText ()));
172
+
173
+ VLOG (4 ) << " call heter_worker success" ;
174
+ });
175
+ closure->cntl .set_timeout_ms (FLAGS_pserver_timeout_ms);
176
+ auto & request_io_buffer = closure->cntl .request_attachment ();
126
177
distributed::SerializeToMultiVarMsgAndIOBuf (
127
178
message_name_val, send_var_name_val, recv_var_name_val, *p_ctx, p_scope,
128
179
&request, &request_io_buffer);
129
- stub.SendAndRecvVariable (&cntl, &request, &response, NULL );
130
- PADDLE_ENFORCE_NE (
131
- cntl.Failed (), true ,
132
- platform::errors::Unimplemented (
133
- " HeterClient::SendAndRecv meets brpc error, error message is %s" ,
134
- cntl.ErrorText ()));
135
- VLOG (4 ) << " call heter_worker success" ;
136
- auto & response_io_buffer = cntl.response_attachment ();
137
- distributed::DeserializeFromMultiVarMsgAndIOBuf (response, &response_io_buffer,
138
- ctx, p_scope);
180
+
181
+ int micro_id = GetMicroId (ctx, p_scope);
182
+ auto minibatch_id = micro_id / 10 ;
183
+ // select channel according to micro id
184
+ if (mode == " forward" ) {
185
+ int num = minibatch_id % xpu_channels_.size ();
186
+ channel = xpu_channels_[num].get ();
187
+ } else if (mode == " backward" ) {
188
+ int num = minibatch_id % previous_xpu_channels_.size ();
189
+ channel = previous_xpu_channels_[num].get ();
190
+ }
191
+ ::paddle::distributed::PsService_Stub stub (channel);
192
+ stub.SendAndRecvVariable (&closure->cntl , &request, &closure->response ,
193
+ closure);
139
194
}
140
195
141
196
std::future<int32_t > HeterClient::SendCmd (
0 commit comments