Skip to content

Commit 23433de

Browse files
committed
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into overlap_memcpy_with_dist
2 parents 15913d9 + ff9b1a0 commit 23433de

21 files changed

+422
-221
lines changed

paddle/fluid/framework/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ cc_library(executor SRCS executor.cc DEPS op_registry device_context scope
8787
framework_proto glog lod_rank_table feed_fetch_method)
8888

8989

90-
cc_library(parallel_executor SRCS parallel_executor.cc DEPS multi_devices_graph_builder threaded_ssa_graph_executor scope_buffered_ssa_graph_executor)
90+
cc_library(parallel_executor SRCS parallel_executor.cc DEPS graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor)
9191

9292
cc_library(prune SRCS prune.cc DEPS framework_proto)
9393
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)

paddle/fluid/framework/details/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place
77

88
cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
99
cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
10+
cc_library(ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_builder)
1011

1112
cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
1213

@@ -28,6 +29,9 @@ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope d
2829
cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
2930
scale_loss_grad_op_handle rpc_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle)
3031

32+
33+
cc_library(graph_builder_factory SRCS graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer)
34+
3135
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
3236
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
3337
simple_threadpool device_context)

paddle/fluid/framework/details/broadcast_op_handle.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ struct BroadcastOpHandle : public OpHandleBase {
5959
void RunImpl() override;
6060

6161
private:
62-
const std::vector<Scope *> &local_scopes_;
63-
const std::vector<platform::Place> &places_;
62+
std::vector<Scope *> local_scopes_;
63+
std::vector<platform::Place> places_;
6464
#ifdef PADDLE_WITH_CUDA
6565
const platform::NCCLContextMap *nccl_ctxs_;
6666
#endif

paddle/fluid/framework/details/build_strategy.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
#pragma once
1616

17+
#include <string>
18+
1719
namespace paddle {
1820
namespace framework {
1921
namespace details {
@@ -29,6 +31,8 @@ struct BuildStrategy {
2931

3032
ReduceStrategy reduce_{ReduceStrategy::kAllReduce};
3133
GradientScaleStrategy gradient_scale_{GradientScaleStrategy::kCoeffNumDevice};
34+
35+
std::string debug_graphviz_path_{""};
3236
};
3337

3438
} // namespace details
paddle/fluid/framework/details/graph_builder_factory.cc

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/details/graph_builder_factory.h"
16+
#include <fstream>
17+
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
18+
#include "paddle/fluid/framework/details/ssa_graph_printer.h"
19+
20+
namespace paddle {
21+
namespace framework {
22+
namespace details {
23+
std::unique_ptr<SSAGraphBuilder> SSAGraphBuilderFactory::Create() {
24+
std::unique_ptr<SSAGraphBuilder> res(
25+
#ifdef PADDLE_WITH_CUDA
26+
new MultiDevSSAGraphBuilder(places_, loss_var_name_, param_names_,
27+
local_scopes_, nccl_ctxs_, strategy_)
28+
#else
29+
new MultiDevSSAGraphBuilder(places_, loss_var_name_, param_names_,
30+
local_scopes_, strategy_)
31+
#endif
32+
); // NOLINT
33+
34+
if (!strategy_.debug_graphviz_path_.empty()) {
35+
std::unique_ptr<std::ostream> fout(
36+
new std::ofstream(strategy_.debug_graphviz_path_));
37+
PADDLE_ENFORCE(fout->good());
38+
std::unique_ptr<GraphvizSSAGraphPrinter> graphviz_printer(
39+
new GraphvizSSAGraphPrinter());
40+
res.reset(new SSAGraghBuilderWithPrinter(
41+
std::move(fout), std::move(graphviz_printer), std::move(res)));
42+
}
43+
return res;
44+
}
45+
} // namespace details
46+
} // namespace framework
47+
} // namespace paddle
paddle/fluid/framework/details/graph_builder_factory.h

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
#include <memory>
17+
#include <string>
18+
#include <vector>
19+
#include "paddle/fluid/framework/details/build_strategy.h"
20+
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
21+
#include "paddle/fluid/platform/place.h"
22+
23+
#ifdef PADDLE_WITH_CUDA
24+
#include "paddle/fluid/platform/nccl_helper.h"
25+
#endif
26+
27+
namespace paddle {
28+
namespace framework {
29+
class Scope;
30+
namespace details {
31+
32+
class SSAGraphBuilderFactory {
33+
public:
34+
SSAGraphBuilderFactory(const std::vector<platform::Place>& places,
35+
const std::string& loss_var_name,
36+
const std::unordered_set<std::string>& param_names,
37+
const std::vector<Scope*>& local_scopes,
38+
const BuildStrategy& strategy)
39+
: places_(places),
40+
loss_var_name_(loss_var_name),
41+
param_names_(param_names),
42+
local_scopes_(local_scopes),
43+
strategy_(strategy) {}
44+
45+
#ifdef PADDLE_WITH_CUDA
46+
void SetNCCLContextMap(platform::NCCLContextMap* nccl_ctxs) {
47+
nccl_ctxs_ = nccl_ctxs;
48+
}
49+
#endif
50+
51+
std::unique_ptr<SSAGraphBuilder> Create();
52+
53+
private:
54+
std::vector<platform::Place> places_;
55+
std::string loss_var_name_;
56+
std::unordered_set<std::string> param_names_;
57+
std::vector<Scope*> local_scopes_;
58+
BuildStrategy strategy_;
59+
60+
#ifdef PADDLE_WITH_CUDA
61+
platform::NCCLContextMap* nccl_ctxs_;
62+
#endif
63+
};
64+
65+
} // namespace details
66+
} // namespace framework
67+
} // namespace paddle

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,6 @@
3030
#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
3131
#endif
3232

33-
DEFINE_string(ssa_graph_path, "/tmp/ssa_graph.dot",
34-
"the ssa graph path only print with GLOG_v=10,"
35-
"default /tmp/graph.dot");
36-
3733
namespace paddle {
3834
namespace framework {
3935
namespace details {
@@ -149,6 +145,7 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
149145

150146
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
151147
const ProgramDesc &program) const {
148+
VLOG(3) << "Building ....";
152149
std::unordered_map<std::string, VarDesc *> all_vars;
153150
for (auto *var : program.Block(0).AllVars()) {
154151
all_vars[var->Name()] = var;
@@ -315,11 +312,6 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
315312
*/
316313
AddOutputToLeafOps(&result);
317314

318-
if (VLOG_IS_ON(10)) {
319-
std::ofstream fout(FLAGS_ssa_graph_path);
320-
PrintGraphviz(*graph, fout);
321-
}
322-
323315
return std::unique_ptr<SSAGraph>(graph);
324316
}
325317

paddle/fluid/framework/details/multi_devices_graph_builder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
4848

4949
std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
5050

51-
int GetRemoteVarDevice(const std::string &var_name) const {
51+
int GetRemoteVarDeviceId(const std::string &var_name) const override {
5252
auto got = remote_vars_devices_.find(var_name);
5353
if (got != remote_vars_devices_.end()) {
5454
return got->second;

paddle/fluid/framework/details/nccl_all_reduce_op_handle.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ struct NCCLAllReduceOpHandle : public OpHandleBase {
4141
void RunImpl() override;
4242

4343
private:
44-
const std::vector<Scope *> &local_scopes_;
45-
const std::vector<platform::Place> &places_;
44+
std::vector<Scope *> local_scopes_;
45+
std::vector<platform::Place> places_;
4646
const platform::NCCLContextMap &nccl_ctxs_;
4747
};
4848

paddle/fluid/framework/details/reduce_op_handle.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ namespace framework {
3232
namespace details {
3333

3434
struct ReduceOpHandle : public OpHandleBase {
35-
const std::vector<Scope *> &local_scopes_;
36-
const std::vector<platform::Place> &places_;
35+
std::vector<Scope *> local_scopes_;
36+
std::vector<platform::Place> places_;
3737

3838
#ifdef PADDLE_WITH_CUDA
3939
const platform::NCCLContextMap *nccl_ctxs_;

0 commit comments

Comments
 (0)