Skip to content

Commit f031555

Browse files
committed
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into add-merge-splited-ids
2 parents 6dd3f3c + 431491a commit f031555

36 files changed

+1401
-540
lines changed

cmake/configure.cmake

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,10 @@ endif()
118118
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${SIMD_FLAG}")
119119
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${SIMD_FLAG}")
120120

121+
if(WITH_DISTRIBUTE)
122+
add_definitions(-DPADDLE_WITH_DISTRIBUTE)
123+
endif()
124+
121125
if(WITH_GOLANG)
122126
# we need to symlink Paddle directory into GOPATH. If we
123127
# don't do it and we have code that depends on Paddle, go

doc/survey/dynamic_graph.md

Lines changed: 378 additions & 0 deletions
Large diffs are not rendered by default.

paddle/fluid/framework/CMakeLists.txt

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,13 @@ cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor)
8383

8484
cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glog)
8585

86-
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope
87-
framework_proto glog lod_rank_table feed_fetch_method)
86+
if(WITH_DISTRIBUTE)
87+
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr)
88+
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
89+
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
90+
else()
91+
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method)
92+
endif()
8893

8994

9095
cc_library(parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor)

paddle/fluid/framework/details/ssa_graph_checker.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
namespace paddle {
2020
namespace framework {
2121
namespace details {
22-
class SSAGraph;
22+
struct SSAGraph;
2323

2424
class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
2525
public:

paddle/fluid/framework/executor.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ limitations under the License. */
2020
#include "paddle/fluid/framework/lod_tensor_array.h"
2121
#include "paddle/fluid/framework/op_registry.h"
2222
#include "paddle/fluid/framework/reader.h"
23+
#ifdef PADDLE_WITH_DISTRIBUTE
24+
#include "paddle/fluid/operators/detail/grpc_client.h"
25+
#endif
2326
#include "paddle/fluid/platform/place.h"
2427
#include "paddle/fluid/platform/profiler.h"
2528

@@ -44,6 +47,14 @@ ExecutorPrepareContext::~ExecutorPrepareContext() {
4447

4548
Executor::Executor(const platform::Place& place) : place_(place) {}
4649

50+
#ifdef PADDLE_WITH_DISTRIBUTE
51+
void Executor::Complete() {
52+
::paddle::operators::detail::RPCClient::GetInstance<
53+
::paddle::operators::detail::GRPCClient>()
54+
->SendComplete();
55+
}
56+
#endif
57+
4758
void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
4859
if (var_type == proto::VarType::LOD_TENSOR) {
4960
var->GetMutable<LoDTensor>();

paddle/fluid/framework/executor.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,13 @@ class Executor {
4444

4545
explicit Executor(const platform::Place& place);
4646

47+
#ifdef PADDLE_WITH_DISTRIBUTE
48+
/*
49+
* Sending signal to pserver to mark current trainer stop.
50+
*/
51+
void Complete();
52+
#endif
53+
4754
/* @Brief
4855
* Runtime evaluation of the given ProgramDesc under certain Scope
4956
*

paddle/fluid/operators/detail/grpc_client.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@ void GRPCClient::InitEventLoop() {
3434
client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
3535
}
3636

37+
void GRPCClient::SendComplete() {
38+
for (auto& it : channels_) {
39+
this->AsyncSendComplete(it.first);
40+
}
41+
}
42+
3743
GRPCClient::~GRPCClient() {
3844
Wait();
3945
cq_.Shutdown();
@@ -210,6 +216,19 @@ void GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
210216
req_count_++;
211217
}
212218

219+
void GRPCClient::AsyncSendComplete(const std::string& ep, int64_t time_out) {
220+
const auto ch = GetChannel(ep);
221+
222+
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
223+
s->Prepare(time_out);
224+
225+
sendrecv::VariableMessage req;
226+
req.set_varname(COMPLETE_MESSAGE);
227+
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
228+
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
229+
req_count_++;
230+
}
231+
213232
void GRPCClient::Wait() {
214233
std::unique_lock<std::mutex> lk(sync_mutex_);
215234
sync_cond_.wait(lk, [this] { return req_count_ == 0; });

paddle/fluid/operators/detail/grpc_client.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,8 @@ class GRPCClient : public RPCClient {
195195

196196
void Wait() override;
197197

198+
void SendComplete() override;
199+
198200
protected:
199201
void InitImpl() override;
200202

@@ -204,6 +206,9 @@ class GRPCClient : public RPCClient {
204206

205207
void Proceed();
206208

209+
void AsyncSendComplete(const std::string& ep,
210+
int64_t time_out = RPCClient::rpc_time_out);
211+
207212
std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
208213

209214
private:

paddle/fluid/operators/detail/grpc_server.cc

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -162,16 +162,18 @@ class RequestPrefetch final : public RequestBase {
162162

163163
void Process() override {
164164
// prefetch process...
165-
std::string varname = request_->OutVarname();
166-
VLOG(3) << "RequestPrefetch " << varname;
165+
std::string in_var_name = request_->Varname();
166+
std::string out_var_name = request_->OutVarname();
167+
VLOG(3) << "RequestPrefetch, in_var_name: " << in_var_name
168+
<< " out_var_name: " << out_var_name;
167169

168170
auto scope = request_->GetMutableLocalScope();
169-
auto invar = scope->FindVar(varname);
170-
framework::Variable* outvar = nullptr;
171+
auto invar = scope->FindVar(in_var_name);
172+
framework::Variable* outvar = scope->FindVar(out_var_name);
171173

172-
request_handler_->Handle(varname, scope, invar, &outvar);
174+
request_handler_->Handle(in_var_name, scope, invar, &outvar, out_var_name);
173175

174-
SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(),
176+
SerializeToByteBuffer(out_var_name, outvar, *request_handler_->dev_ctx(),
175177
&reply_);
176178
Finish(reply_, &responder_);
177179
}
@@ -287,7 +289,7 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
287289
} else if (rpc_name == kRequestPrefetch) {
288290
b = new RequestPrefetch(&service_, cq.get(), handler, req_id);
289291
} else {
290-
PADDLE_ENFORCE(false, "not surpported rpc");
292+
PADDLE_ENFORCE(false, "not supported rpc");
291293
}
292294

293295
reqs[req_id] = b;

paddle/fluid/operators/detail/request_handler.h

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ constexpr char kRequestPrefetch[] = "RequestPrefetch";
4040
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
4141
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
4242
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
43+
#define COMPLETE_MESSAGE "COMPLETE@RECV"
4344

4445
class RPCServer;
4546

@@ -60,9 +61,12 @@ class RequestHandler {
6061
void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; }
6162
void SetProgram(framework::ProgramDesc* program) { program_ = program; }
6263
void SetExecutor(framework::Executor* executor) { executor_ = executor; }
64+
65+
// Used for dist lookup table prefetch
6366
void SetPrefetchPreparedCtx(
64-
std::unique_ptr<framework::ExecutorPrepareContext> prepared) {
65-
prefetch_ctx_.reset(prepared.release());
67+
std::unordered_map<
68+
std::string, std::shared_ptr<framework::ExecutorPrepareContext>>* g) {
69+
prefetch_var_name_to_prepared_ctx_ = g;
6670
}
6771

6872
// Used for async.
@@ -78,9 +82,6 @@ class RequestHandler {
7882
bool sync_mode() { return sync_mode_; }
7983
framework::Scope* scope() { return scope_; }
8084
const platform::DeviceContext* dev_ctx() { return dev_ctx_; }
81-
framework::ExecutorPrepareContext* prefetch_ctx() {
82-
return prefetch_ctx_.get();
83-
}
8485
framework::ProgramDesc* program() { return program_; }
8586
framework::Executor* executor() { return executor_; }
8687

@@ -99,8 +100,8 @@ class RequestHandler {
99100
// *request_handler_->dev_ctx(), &reply_);
100101
// }
101102
virtual bool Handle(const std::string& varname, framework::Scope* scope,
102-
framework::Variable* var,
103-
framework::Variable** outvar) = 0;
103+
framework::Variable* var, framework::Variable** outvar,
104+
const std::string& out_var_name = "") = 0;
104105

105106
protected:
106107
const bool sync_mode_;
@@ -109,12 +110,17 @@ class RequestHandler {
109110
framework::Executor* executor_;
110111
framework::Scope* scope_;
111112
framework::ProgramDesc* program_;
112-
std::unique_ptr<framework::ExecutorPrepareContext> prefetch_ctx_;
113+
114+
// used for distribute lookup table prefetch
115+
std::unordered_map<std::string,
116+
std::shared_ptr<framework::ExecutorPrepareContext>>*
117+
prefetch_var_name_to_prepared_ctx_;
113118

114119
// Used for async.
115120
std::unordered_map<std::string,
116121
std::shared_ptr<framework::ExecutorPrepareContext>>*
117122
grad_to_prepared_ctx_;
123+
118124
RPCServer* rpc_server_;
119125
};
120126

0 commit comments

Comments
 (0)