
Commit 62af10d: support multiple devices
Parent: 274df85
13 files changed, +208 -26 lines

paddle/fluid/framework/details/CMakeLists.txt

Lines changed: 2 additions & 1 deletion

@@ -4,6 +4,7 @@ cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_h
 cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
 cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
 cc_library(send_op_handle SRCS send_op_handle.cc DEPS framework_proto scope place operator op_registry)
+cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry)

 cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
 cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
@@ -26,7 +27,7 @@ endif()
 cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)

 cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
-        scale_loss_grad_op_handle send_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle)
+        scale_loss_grad_op_handle send_op_handle rpc_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle)

 cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
 cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 49 additions & 11 deletions

@@ -12,10 +12,12 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 #include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
+#include <fstream>
 #include <utility>
 #include "paddle/fluid/framework/details/broadcast_op_handle.h"
 #include "paddle/fluid/framework/details/computation_op_handle.h"
 #include "paddle/fluid/framework/details/reduce_op_handle.h"
+#include "paddle/fluid/framework/details/rpc_op_handle.h"
 #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
 #include "paddle/fluid/framework/details/send_op_handle.h"
 #include "paddle/fluid/framework/scope.h"
@@ -77,7 +79,6 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
     CreateOpOutput(result, op_handle, each_var_name, p, place_id);
   }
 }
-
 bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
                                             OpDesc *send_op) const {
   if (send_op == nullptr) {
@@ -98,14 +99,23 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
     return false;
   };

-  if (op.Type() == "split") {
+  if (op.Type() == "split" || op.Type() == "split_byref") {
     return checker(op.OutputArgumentNames(), send_op->InputArgumentNames());
   } else if (op.Type() == "concat") {
     return checker(op.InputArgumentNames(), send_op->OutputArgumentNames());
   }
   return false;
 }

+bool MultiDevSSAGraphBuilder::IsRPCOp(const OpDesc &op) const {
+  for (auto &name : op.OutputNames()) {
+    if (name == "RPCClient") {
+      return true;
+    }
+  }
+  return false;
+}
+
 std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
     const ProgramDesc &program) const {
   std::unordered_map<std::string, proto::VarType::Type> var_types;
@@ -133,10 +143,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(

   bool is_forwarding = true;
   for (auto *op : program.Block(0).AllOps()) {
-    if (op->Type() == "send") {
-      // append send op if program is distributed trainer main program.
+    if (IsRPCOp(*op)) {
+      // append rpc op if program is distributed trainer main program.
       // always use the first device
-      CreateSendOp(&result, *op);
+      CreateRPCOp(&result, *op);
     } else if (IsDistTrainOp(*op, send_op)) {
       CreateComputationalOps(&result, *op, 1);
     } else if (IsScaleLossOp(*op)) {
@@ -203,9 +213,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
   AddOutputToLeafOps(&result);

   if (VLOG_IS_ON(10)) {
-    std::ostringstream sout;
-    PrintGraphviz(*graph, sout);
-    VLOG(10) << sout.str();
+    std::string filename = "/tmp/graph";
+    std::ofstream fout(filename);
+    PrintGraphviz(*graph, fout);
   }

   return std::unique_ptr<SSAGraph>(graph);
@@ -386,12 +396,40 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
   return var;
 }

-void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result,
-                                           const OpDesc &op) const {
+void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result,
+                                        std::string op_name) const {
+  for (auto &prev_op : result->ops_) {
+    if (prev_op->Name() == op_name) {
+      auto *dep_var = new DummyVarHandle();
+      prev_op->AddOutput(dep_var);
+      result->dep_vars_.emplace(dep_var);
+      result->ops_.back().get()->AddInput(dep_var);
+    }
+  }
+}
+
+void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
+                                          const OpDesc &op) const {
   auto &p = places_[0];
   auto *s = local_scopes_[0];
+  VLOG(3) << "create rpc op: " << op.Type();
+  result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type()));
+  if (op.Type() == "send_barrier") {
+    ConnectOp(result, "send_vars");
+  } else if (op.Type() == "recv") {
+    ConnectOp(result, "send_barrier");
+  } else if (op.Type() == "fetch_barrier") {
+    ConnectOp(result, "recv");
+  } else if (op.Type() == "send" || op.Type() == "send_vars") {
+    // do nothing
+  } else {
+    PADDLE_THROW(
+        "rpc op should be in [send,"
+        "send_vars, send_barrier. recv, fetch_barrier]");
+  }
+
   // FIXME(wuyi): send op always copy from GPU 0
-  result->ops_.emplace_back(new SendOpHandle(op, s, p));
+  // result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type()));
   // Create inputs for output on original place and no ssa output
   // is created for send op.
   CreateOpHandleIOs(result, op, 0);
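Note on the change above: CreateRPCOp chains the distributed control flow (send_barrier waits on every send_vars, recv on send_barrier, fetch_barrier on recv) by having ConnectOp attach a DummyVarHandle that carries no tensor and exists only as a dependency edge for the scheduler. Below is a minimal, self-contained sketch of that ordering trick; the Graph, Op, Var, and ConnectToPrev names are simplified stand-ins for illustration, not Paddle's real OpHandleBase/DummyVarHandle classes.

// Minimal sketch: ordering ops through data-free dependency variables.
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Var {
  bool is_dummy = true;  // carries no data, only an ordering edge
};

struct Op {
  std::string name;
  std::vector<Var*> inputs, outputs;
  explicit Op(std::string n) : name(std::move(n)) {}
};

struct Graph {
  std::vector<std::unique_ptr<Op>> ops;
  std::vector<std::unique_ptr<Var>> dep_vars;

  Op* Append(const std::string& name) {
    ops.emplace_back(new Op(name));
    return ops.back().get();
  }

  // Make the most recently appended op depend on every earlier op named
  // `prev_name`, mirroring what MultiDevSSAGraphBuilder::ConnectOp does.
  void ConnectToPrev(const std::string& prev_name) {
    Op* cur = ops.back().get();
    for (auto& prev : ops) {
      if (prev.get() != cur && prev->name == prev_name) {
        dep_vars.emplace_back(new Var());
        Var* dep = dep_vars.back().get();
        prev->outputs.push_back(dep);  // produced by the earlier op
        cur->inputs.push_back(dep);    // consumed by the new op
      }
    }
  }
};

int main() {
  Graph g;
  g.Append("send_vars");
  g.Append("send_barrier");
  g.ConnectToPrev("send_vars");    // barrier waits for all send_vars
  g.Append("recv");
  g.ConnectToPrev("send_barrier");
  g.Append("fetch_barrier");
  g.ConnectToPrev("recv");
  for (auto& op : g.ops) {
    std::cout << op->name << " waits on " << op->inputs.size()
              << " dependency var(s)\n";
  }
  return 0;
}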

paddle/fluid/framework/details/multi_devices_graph_builder.h

Lines changed: 5 additions & 0 deletions

@@ -65,12 +65,17 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
   bool IsScaleLossOp(const OpDesc &op) const;

   void CreateSendOp(SSAGraph *result, const OpDesc &op) const;
+  void CreateRPCOp(SSAGraph *result, const OpDesc &op) const;

   /**
    * Is this operator as the end-point operator before/after send operator.
    */
   bool IsDistTrainOp(const OpDesc &op, OpDesc *send_op) const;

+  bool IsRPCOp(const OpDesc &op) const;
+
+  void ConnectOp(SSAGraph *result, std::string op_name) const;
+
   void CreateComputationalOps(SSAGraph *result, const OpDesc &op,
                               size_t num_places) const;

paddle/fluid/framework/details/rpc_op_handle.cc

Lines changed: 50 additions & 0 deletions

@@ -0,0 +1,50 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/details/rpc_op_handle.h"
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+RPCOpHandle::RPCOpHandle(const framework::OpDesc &op_desc,
+                         const Scope *local_scope, const platform::Place &place,
+                         const std::string &name)
+    : op_(framework::OpRegistry::CreateOp(op_desc)),
+      local_scope_(local_scope),
+      place_(place),
+      name_(name) {}
+
+void RPCOpHandle::RunImpl() {
+  // TODO(wuyi): need further analysis whether wait VarDummyHandle.
+  // Wait input done
+  for (auto *in : inputs_) {
+    auto &p = static_cast<VarHandle *>(in)->place_;
+    if (in->DebugString() == "dummy") {  // HACK
+      continue;
+    }
+    if (in->generated_op_) {
+      in->generated_op_->RecordWaitEventOnCtx(dev_ctxes_[p]);
+    }
+  }
+  auto &tmp_scope = local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
+  // FIXME(wuyi): can not use RunAndRecordEvent here, for it will cause dead
+  // lock.
+  op_->Run(*tmp_scope, place_);
+}
+
+std::string RPCOpHandle::Name() const { return name_; }
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle

paddle/fluid/framework/details/rpc_op_handle.h

Lines changed: 52 additions & 0 deletions

@@ -0,0 +1,52 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/framework/details/op_handle_base.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/scope.h"
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+struct RPCOpHandle : public OpHandleBase {
+  RPCOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope,
+              const platform::Place& place, const std::string& name);
+
+  std::string Name() const override;
+
+  // Delay and buffer nccl_all_reduce together can significantly increase
+  // performance. Disable this feature by returning false.
+  bool IsMultiDeviceTransfer() override { return false; };
+
+ protected:
+  void RunImpl() override;
+
+ private:
+  std::unique_ptr<OperatorBase> op_;
+  const Scope* local_scope_;
+  const platform::Place& place_;
+  const std::string name_;
+};
+
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle

paddle/fluid/framework/variable.h

Lines changed: 3 additions & 0 deletions

@@ -14,6 +14,7 @@
 #pragma once

 #include <memory>
+#include <mutex>  // NOLINT
 #include <string>
 #include <typeindex>
 #include <typeinfo>
@@ -38,6 +39,7 @@ class Variable {

   template <typename T>
   T* GetMutable() {
+    std::unique_lock<std::mutex> lock(mutex_);
     if (!IsType<T>()) {
       holder_.reset(new PlaceholderImpl<T>(new T()));
     }
@@ -90,6 +92,7 @@ class Variable {
   // by its address but not the unreadable name.
   friend class Scope;
   const std::string* name_;
+  std::mutex mutex_;
 };

 }  // namespace framework
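Note on the change above: the lock added to Variable::GetMutable() serializes the lazy creation of the holder, presumably because several op handles can now call GetMutable() on the same variable from different executor threads; without the lock, two threads could both take the "reset holder_" branch and one would keep a pointer to a destroyed object. Below is a minimal sketch of the guarded lazy-initialization pattern; ToyVariable is a simplified stand-in, not Paddle's Variable.

// Minimal sketch: mutex-guarded lazy initialization shared across threads.
#include <iostream>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>

class ToyVariable {
 public:
  template <typename T>
  T* GetMutable() {
    std::unique_lock<std::mutex> lock(mutex_);  // serialize lazy init
    if (holder_ == nullptr) {
      holder_.reset(new T());
    }
    return static_cast<T*>(holder_.get());
  }

 private:
  std::shared_ptr<void> holder_;
  std::mutex mutex_;
};

int main() {
  ToyVariable var;
  std::vector<std::thread> workers;
  int* first = nullptr;
  std::mutex check_mu;
  for (int i = 0; i < 8; ++i) {
    workers.emplace_back([&] {
      int* p = var.GetMutable<int>();
      std::lock_guard<std::mutex> g(check_mu);
      if (first == nullptr) first = p;
      // Every thread must observe the same underlying object.
      if (p != first) std::cout << "race detected\n";
    });
  }
  for (auto& t : workers) t.join();
  std::cout << "done\n";
  return 0;
}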

paddle/fluid/operators/detail/grpc_client.cc

Lines changed: 16 additions & 10 deletions

@@ -33,7 +33,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
   const std::string ep_val = ep;
   const std::string var_name_val = var_name;
   const framework::Scope* p_scope = &scope;
-  const auto ch = GetChannel(ep_val);
+  const auto ch = GetChannel(ep_val, ep_val + ":" + var_name_val);

   framework::AsyncIO([var_name_val, p_ctx, ep_val, p_scope, time_out, ch,
                       this] {
@@ -88,7 +88,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
   const std::string ep_val = ep;
   const std::string var_name_val = var_name;
   const framework::Scope* p_scope = &scope;
-  const auto ch = GetChannel(ep_val);
+  const auto ch = GetChannel(ep_val, ep_val + ":" + var_name_val);

   framework::AsyncIO([var_name_val, ep_val, p_scope, p_ctx, time_out, ch,
                       this] {
@@ -132,7 +132,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
   const std::string in_var_name_val = in_var_name;
   const std::string out_var_name_val = out_var_name;
   const framework::Scope* p_scope = &scope;
-  const auto ch = GetChannel(ep_val);
+  const auto ch = GetChannel(ep_val, ep_val + ":" + in_var_name_val);

   framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
                       time_out, ch, this] {
@@ -165,7 +165,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
 }

 void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
-  const auto ch = GetChannel(ep);
+  const auto ch = GetChannel(ep, ep);

   BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
   s->Prepare(time_out);
@@ -178,7 +178,7 @@ void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
 }

 void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
-  const auto ch = GetChannel(ep);
+  const auto ch = GetChannel(ep, ep);
   FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
   s->Prepare(time_out);

@@ -243,12 +243,19 @@ bool RPCClient::Proceed() {
   delete c;
   return true;
 }
-
-std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
-  auto it = channels_.find(ep);
+std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep,
+                                                     const std::string& key) {
+  VLOG(3) << "this addr: " << this;
+  std::unique_lock<std::mutex> lock(mutex_);
+  auto it = channels_.find(key);
   if (it != channels_.end()) {
+    VLOG(3) << "find ep: " << ep;
     return it->second;
   }
+  VLOG(3) << "can not find ep: " << ep;
+  for (auto it = channels_.begin(); it != channels_.end(); ++it) {
+    VLOG(3) << "ep: " << it->first;
+  }

   grpc::ChannelArguments args;
   args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
@@ -257,8 +264,7 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {

   auto ch =
       grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args);
-
-  channels_[ep] = ch;
+  channels_[key] = ch;
   return ch;
 }

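Note on the change above: GetChannel() now looks channels up by a separate key. Data-path RPCs pass "endpoint:var_name", so each variable gets its own channel to the same server, while the barrier calls pass the bare endpoint, and the new mutex protects the map because the client is shared across threads. The sketch below shows the same keyed, mutex-guarded cache pattern with a placeholder FakeChannel type in place of grpc::Channel; class and variable names are illustrative only.

// Minimal sketch: channel cache keyed independently of the connect target.
#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <string>

struct FakeChannel {
  explicit FakeChannel(std::string t) : target(std::move(t)) {}
  std::string target;
};

class ChannelCache {
 public:
  // `ep` is what the channel connects to; `key` decides how channels are shared.
  std::shared_ptr<FakeChannel> Get(const std::string& ep,
                                   const std::string& key) {
    std::unique_lock<std::mutex> lock(mutex_);  // map is shared across threads
    auto it = channels_.find(key);
    if (it != channels_.end()) {
      return it->second;
    }
    auto ch = std::make_shared<FakeChannel>(ep);
    channels_[key] = ch;
    return ch;
  }

 private:
  std::map<std::string, std::shared_ptr<FakeChannel>> channels_;
  std::mutex mutex_;
};

int main() {
  ChannelCache cache;
  const std::string ep = "127.0.0.1:6174";
  auto w1 = cache.Get(ep, ep + ":w1");  // per-variable channel
  auto w2 = cache.Get(ep, ep + ":w2");  // separate channel, same endpoint
  auto barrier = cache.Get(ep, ep);     // barriers share one channel per ep
  std::cout << std::boolalpha << "w1 and w2 share a channel: " << (w1 == w2)
            << "\n";
  std::cout << "both connect to: " << w1->target << "\n";
  return 0;
}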

paddle/fluid/operators/detail/grpc_client.h

Lines changed: 4 additions & 1 deletion

@@ -21,6 +21,7 @@ limitations under the License. */
 #include <functional>
 #include <iostream>
 #include <map>
+#include <mutex>  // NOLINT
 #include <string>
 #include <vector>

@@ -190,12 +191,14 @@ class RPCClient {

  private:
  bool Proceed();
-  std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
+  std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep,
+                                            const std::string& key);

  private:
  grpc::CompletionQueue cq_;
  std::map<std::string, std::shared_ptr<grpc::Channel>> channels_;
  int64_t req_count_ = 0;
+  std::mutex mutex_;
 };

 }  // namespace detail
