Skip to content

Commit 20c24c0

Browse files
committed
singleton rpc_client
1 parent 28596a3 commit 20c24c0

18 files changed

+161
-240
lines changed

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -146,15 +146,6 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
146146
checker(op.InputArgumentNames(), recv_vars);
147147
}
148148

149-
bool MultiDevSSAGraphBuilder::IsRPCOp(const OpDesc &op) const {
150-
for (auto &name : op.OutputNames()) {
151-
if (name == "RPCClient") {
152-
return true;
153-
}
154-
}
155-
return false;
156-
}
157-
158149
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
159150
const ProgramDesc &program) const {
160151
std::unordered_map<std::string, proto::VarType::Type> var_types;
@@ -184,7 +175,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
184175

185176
bool is_forwarding = true;
186177
for (auto *op : program.Block(0).AllOps()) {
187-
if (IsRPCOp(*op)) {
178+
if (boost::get<int>(
179+
op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
180+
static_cast<int>(OpRole::kRPC)) {
188181
// append rpc op if program is distributed trainer main program.
189182
// always use the first device
190183
CreateRPCOp(&result, *op);

paddle/fluid/framework/details/multi_devices_graph_builder.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
8080
std::vector<std::string> FindDistTrainRecvVars(
8181
const ProgramDesc &program) const;
8282

83-
bool IsRPCOp(const OpDesc &op) const;
84-
8583
void ConnectOp(SSAGraph *result, OpHandleBase *op,
8684
const std::string &prev_op_name) const;
8785

paddle/fluid/framework/op_proto_maker.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
6666
.InEnum(
6767
{static_cast<int>(OpRole::kForward),
6868
static_cast<int>(OpRole::kBackward),
69-
static_cast<int>(OpRole::kOptimize),
69+
static_cast<int>(OpRole::kOptimize), static_cast<int>(OpRole::kRPC),
7070
static_cast<int>(OpRole::kLoss) | static_cast<int>(OpRole::kForward),
7171
static_cast<int>(OpRole::kLoss) |
7272
static_cast<int>(OpRole::kBackward),

paddle/fluid/framework/op_proto_maker.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ enum class OpRole {
2424
kForward = 0x0000,
2525
kBackward = 0x0001,
2626
kOptimize = 0x0002,
27+
kRPC = 0x0003,
2728

2829
kLoss = 0x0100,
2930
// The default value of op's role. This should be only used for unittests and

paddle/fluid/inference/analysis/data_flow_graph_tester.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ TEST(DataFlowGraph, BFS) {
3535

3636
GraphTraits<DataFlowGraph> trait(&dfg);
3737
auto nodes = trait.nodes();
38-
int count = 0;
38+
size_t count = 0;
3939
for (auto it = nodes.begin(); it != nodes.end(); ++it) {
4040
LOG(INFO) << "visiting " << it->name();
4141
++count;
@@ -49,7 +49,7 @@ TEST(DataFlowGraph, DFS) {
4949
dfg.Build();
5050
GraphTraits<DataFlowGraph> trait(&dfg);
5151
auto nodes = trait.nodes_in_DFS();
52-
int count = 0;
52+
size_t count = 0;
5353
for (auto it = nodes.begin(); it != nodes.end(); ++it) {
5454
LOG(INFO) << "visiting " << it->name();
5555
++count;

paddle/fluid/operators/detail/grpc_client.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,21 @@ namespace paddle {
2525
namespace operators {
2626
namespace detail {
2727

28+
std::once_flag RPCClient::init_flag_;
29+
30+
std::unique_ptr<RPCClient> RPCClient::rpc_client_(nullptr);
31+
32+
RPCClient* RPCClient::GetInstance() {
33+
std::call_once(init_flag_, &RPCClient::Init);
34+
return rpc_client_.get();
35+
}
36+
37+
void RPCClient::Init() {
38+
if (rpc_client_.get() == nullptr) {
39+
rpc_client_.reset(new RPCClient());
40+
}
41+
}
42+
2843
bool RPCClient::AsyncSendVariable(const std::string& ep,
2944
const platform::DeviceContext& ctx,
3045
const framework::Scope& scope,

paddle/fluid/operators/detail/grpc_client.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ limitations under the License. */
3636
#include "paddle/fluid/framework/scope.h"
3737
#include "paddle/fluid/framework/selected_rows.h"
3838
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
39+
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
3940

4041
namespace paddle {
4142
namespace operators {
@@ -162,6 +163,10 @@ class FetchBarrierProcessor : public BaseProcessor {
162163

163164
class RPCClient {
164165
public:
166+
RPCClient() {}
167+
168+
static RPCClient* GetInstance();
169+
165170
bool AsyncSendVariable(const std::string& ep,
166171
const platform::DeviceContext& ctx,
167172
const framework::Scope& scope,
@@ -192,12 +197,17 @@ class RPCClient {
192197
private:
193198
bool Proceed();
194199
std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
200+
// Init is called by GetInstance.
201+
static void Init();
195202

196203
private:
197204
grpc::CompletionQueue cq_;
198205
std::map<std::string, std::shared_ptr<grpc::Channel>> channels_;
199206
std::atomic<int64_t> req_count_{0};
200207
std::mutex mutex_;
208+
static std::unique_ptr<RPCClient> rpc_client_;
209+
static std::once_flag init_flag_;
210+
DISABLE_COPY_AND_ASSIGN(RPCClient);
201211
};
202212

203213
} // namespace detail

paddle/fluid/operators/detail/grpc_server_test.cc

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,10 +121,13 @@ TEST(PREFETCH, DISABLED_CPU) {
121121
std::string in_var_name("ids");
122122
std::string out_var_name("out");
123123

124-
detail::RPCClient client;
125-
client.AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name,
126-
out_var_name);
127-
client.Wait();
124+
detail::RPCClient::GetInstance();
125+
126+
// detail::RPCClient::GetInstance();
127+
// client->Wait();
128+
// client->AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name,
129+
// out_var_name);
130+
// client->Wait();
128131

129132
auto var = scope.Var(out_var_name);
130133
auto value = var->GetMutable<framework::SelectedRows>()->value();

paddle/fluid/operators/fetch_barrier_op.cc

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,7 @@ class FetchBarrierOp : public framework::OperatorBase {
4343
// For profiling
4444
platform::RecordEvent record_event(Type(), &ctx);
4545

46-
auto client_var_name = Output("RPCClient");
47-
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
48-
"Can not find variable '%s' in the scope.",
49-
client_var_name);
50-
auto* client_var = scope.FindVar(client_var_name);
51-
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
46+
auto rpc_client = detail::RPCClient::GetInstance();
5247

5348
PADDLE_ENFORCE(rpc_client->Wait());
5449

@@ -63,9 +58,6 @@ class FetchBarrierOp : public framework::OperatorBase {
6358
class FetchBarrierOpMaker : public framework::OpProtoAndCheckerMaker {
6459
public:
6560
void Make() {
66-
AddOutput("RPCClient",
67-
"(RPCClient) The RPC client object which is"
68-
"initialized at most once.");
6961
AddComment(R"DOC(
7062
SendBarrier operator
7163
@@ -80,17 +72,6 @@ the Parameter Server would knew all variables have been sent.
8072
}
8173
};
8274

83-
class FetchBarrierOpVarTypeInference : public framework::VarTypeInference {
84-
public:
85-
void operator()(const framework::OpDesc& op_desc,
86-
framework::BlockDesc* block) const override {
87-
auto out_var_name = op_desc.Output("RPCClient").front();
88-
auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
89-
auto var_type = framework::proto::VarType::RAW;
90-
out_var.SetType(var_type);
91-
}
92-
};
93-
9475
class FetchBarrierOpShapeInference : public framework::InferShapeBase {
9576
public:
9677
void operator()(framework::InferShapeContext* ctx) const override {}
@@ -103,5 +84,4 @@ namespace ops = paddle::operators;
10384

10485
REGISTER_OPERATOR(fetch_barrier, ops::FetchBarrierOp,
10586
paddle::framework::EmptyGradOpMaker, ops::FetchBarrierOpMaker,
106-
ops::FetchBarrierOpVarTypeInference,
10787
ops::FetchBarrierOpShapeInference);

paddle/fluid/operators/prefetch_op.cc

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,7 @@ class PrefetchOp : public framework::OperatorBase {
4141
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
4242
auto& ctx = *pool.Get(place);
4343

44-
auto client_var_name = Output("RPCClient");
45-
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
46-
"Can not find variable '%s' in the scope.",
47-
client_var_name);
48-
auto* client_var = scope.FindVar(client_var_name);
49-
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
44+
auto rpc_client = detail::RPCClient::GetInstance();
5045

5146
for (size_t i = 0; i < ins.size(); i++) {
5247
if (NeedSend(scope, ins[i])) {
@@ -66,9 +61,6 @@ class PrefetchOpMaker : public framework::OpProtoAndCheckerMaker {
6661
public:
6762
void Make() {
6863
AddInput("X", "(LoDTensor) Input Id variables to be sent").AsDuplicable();
69-
AddOutput("RPCClient",
70-
"(RPCClient) The RPC client object which will be"
71-
"initialized at most once.");
7264
AddOutput("Out",
7365
"(LoDTensor) result "
7466
"to be fetched from parameter server")
@@ -87,17 +79,6 @@ the parameter server and fetch result back.
8779
}
8880
};
8981

90-
class PrefetchOpVarTypeInference : public framework::VarTypeInference {
91-
public:
92-
void operator()(const framework::OpDesc& op_desc,
93-
framework::BlockDesc* block) const override {
94-
auto out_var_name = op_desc.Output("RPCClient").front();
95-
auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
96-
auto var_type = framework::proto::VarType::RAW;
97-
out_var.SetType(var_type);
98-
}
99-
};
100-
10182
class PrefetchOpShapeInference : public framework::InferShapeBase {
10283
public:
10384
void operator()(framework::InferShapeContext* ctx) const override {}
@@ -110,5 +91,4 @@ namespace ops = paddle::operators;
11091

11192
REGISTER_OPERATOR(prefetch, ops::PrefetchOp,
11293
paddle::framework::EmptyGradOpMaker, ops::PrefetchOpMaker,
113-
ops::PrefetchOpVarTypeInference,
11494
ops::PrefetchOpShapeInference);

0 commit comments

Comments
 (0)