Skip to content

Commit 59d75bd

Browse files
committed
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into feature/python_doc
2 parents df681fd + 50104f1 commit 59d75bd

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

65 files changed

+1421
-425
lines changed

benchmark/fluid/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,11 @@ Currently supported `--model` argument include:
2929
You can choose to use GPU/CPU training. With GPU training, you can specify
3030
`--gpus <gpu_num>` to run multi GPU training.
3131
* Run distributed training with parameter servers:
32+
* See [run_fluid_benchmark.sh](https://github.com/PaddlePaddle/Paddle/blob/develop/benchmark/fluid/run_fluid_benchmark.sh) for a complete example.
3233
* start parameter servers:
3334
```bash
3435
PADDLE_TRAINING_ROLE=PSERVER PADDLE_PSERVER_PORT=7164 PADDLE_PSERVER_IPS=127.0.0.1 PADDLE_TRAINERS=1 PADDLE_CURRENT_IP=127.0.0.1 PADDLE_TRAINER_ID=0 python fluid_benchmark.py --model mnist --device GPU --update_method pserver
36+
sleep 15
3537
```
3638
* start trainers:
3739
```bash
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
#!/bin/bash
# Starts a single-machine parameter-server benchmark of the resnet model:
# one parameter server on CPU plus two GPU trainers, each using two GPUs.
# Every process is launched in the background.

# Cluster settings that are identical for the parameter server and trainers.
export PADDLE_PSERVER_PORT=7164
export PADDLE_PSERVER_IPS=127.0.0.1
export PADDLE_TRAINERS=2
export PADDLE_CURRENT_IP=127.0.0.1

# Launch one trainer.
#   $1 - trainer id
#   $2 - comma-separated GPU ids exposed to this trainer
start_trainer() {
  CUDA_VISIBLE_DEVICES="$2" PADDLE_TRAINING_ROLE=TRAINER PADDLE_TRAINER_ID="$1" \
    python fluid_benchmark.py --model resnet --device GPU --update_method pserver --iterations=10000 --gpus 2
}

# Parameter server.
PADDLE_TRAINING_ROLE=PSERVER PADDLE_TRAINER_ID=0 \
  python fluid_benchmark.py --model resnet --device CPU --update_method pserver --iterations=10000 &

# Presumably gives the parameter server time to come up before the trainers
# try to connect.
sleep 15

start_trainer 0 0,1 &
start_trainer 1 2,3 &

paddle/fluid/framework/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ cc_library(executor SRCS executor.cc DEPS op_registry device_context scope
8787
framework_proto glog lod_rank_table feed_fetch_method)
8888

8989

90-
cc_library(parallel_executor SRCS parallel_executor.cc DEPS multi_devices_graph_builder threaded_ssa_graph_executor scope_buffered_ssa_graph_executor)
90+
cc_library(parallel_executor SRCS parallel_executor.cc DEPS graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor)
9191

9292
cc_library(prune SRCS prune.cc DEPS framework_proto)
9393
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)

paddle/fluid/framework/details/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place
77

88
cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
99
cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
10+
cc_library(ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_builder)
1011

1112
cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
1213

@@ -28,6 +29,9 @@ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope d
2829
cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
2930
scale_loss_grad_op_handle rpc_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle)
3031

32+
33+
cc_library(graph_builder_factory SRCS graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer)
34+
3135
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
3236
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
3337
simple_threadpool device_context)

paddle/fluid/framework/details/broadcast_op_handle.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ struct BroadcastOpHandle : public OpHandleBase {
5959
void RunImpl() override;
6060

6161
private:
62-
const std::vector<Scope *> &local_scopes_;
63-
const std::vector<platform::Place> &places_;
62+
std::vector<Scope *> local_scopes_;
63+
std::vector<platform::Place> places_;
6464
#ifdef PADDLE_WITH_CUDA
6565
const platform::NCCLContextMap *nccl_ctxs_;
6666
#endif

paddle/fluid/framework/details/build_strategy.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
#pragma once
1616

17+
#include <string>
18+
1719
namespace paddle {
1820
namespace framework {
1921
namespace details {
@@ -29,6 +31,8 @@ struct BuildStrategy {
2931

3032
ReduceStrategy reduce_{ReduceStrategy::kAllReduce};
3133
GradientScaleStrategy gradient_scale_{GradientScaleStrategy::kCoeffNumDevice};
34+
35+
std::string debug_graphviz_path_{""};
3236
};
3337

3438
} // namespace details
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/details/graph_builder_factory.h"
16+
#include <fstream>
17+
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
18+
#include "paddle/fluid/framework/details/ssa_graph_printer.h"
19+
20+
namespace paddle {
21+
namespace framework {
22+
namespace details {
23+
std::unique_ptr<SSAGraphBuilder> SSAGraphBuilderFactory::Create() {
24+
std::unique_ptr<SSAGraphBuilder> res(
25+
#ifdef PADDLE_WITH_CUDA
26+
new MultiDevSSAGraphBuilder(places_, loss_var_name_, param_names_,
27+
local_scopes_, nccl_ctxs_, strategy_)
28+
#else
29+
new MultiDevSSAGraphBuilder(places_, loss_var_name_, param_names_,
30+
local_scopes_, strategy_)
31+
#endif
32+
); // NOLINT
33+
34+
if (!strategy_.debug_graphviz_path_.empty()) {
35+
std::unique_ptr<std::ostream> fout(
36+
new std::ofstream(strategy_.debug_graphviz_path_));
37+
PADDLE_ENFORCE(fout->good());
38+
std::unique_ptr<GraphvizSSAGraphPrinter> graphviz_printer(
39+
new GraphvizSSAGraphPrinter());
40+
res.reset(new SSAGraghBuilderWithPrinter(
41+
std::move(fout), std::move(graphviz_printer), std::move(res)));
42+
}
43+
return res;
44+
}
45+
} // namespace details
46+
} // namespace framework
47+
} // namespace paddle
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
19+
#include "paddle/fluid/framework/details/build_strategy.h"
20+
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
21+
#include "paddle/fluid/platform/place.h"
22+
23+
#ifdef PADDLE_WITH_CUDA
24+
#include "paddle/fluid/platform/nccl_helper.h"
25+
#endif
26+
27+
namespace paddle {
28+
namespace framework {
29+
class Scope;
30+
namespace details {
31+
32+
class SSAGraphBuilderFactory {
33+
public:
34+
SSAGraphBuilderFactory(const std::vector<platform::Place>& places,
35+
const std::string& loss_var_name,
36+
const std::unordered_set<std::string>& param_names,
37+
const std::vector<Scope*>& local_scopes,
38+
const BuildStrategy& strategy)
39+
: places_(places),
40+
loss_var_name_(loss_var_name),
41+
param_names_(param_names),
42+
local_scopes_(local_scopes),
43+
strategy_(strategy) {}
44+
45+
#ifdef PADDLE_WITH_CUDA
46+
void SetNCCLContextMap(platform::NCCLContextMap* nccl_ctxs) {
47+
nccl_ctxs_ = nccl_ctxs;
48+
}
49+
#endif
50+
51+
std::unique_ptr<SSAGraphBuilder> Create();
52+
53+
private:
54+
std::vector<platform::Place> places_;
55+
std::string loss_var_name_;
56+
std::unordered_set<std::string> param_names_;
57+
std::vector<Scope*> local_scopes_;
58+
BuildStrategy strategy_;
59+
60+
#ifdef PADDLE_WITH_CUDA
61+
platform::NCCLContextMap* nccl_ctxs_;
62+
#endif
63+
};
64+
65+
} // namespace details
66+
} // namespace framework
67+
} // namespace paddle

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,6 @@
3030
#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
3131
#endif
3232

33-
DEFINE_string(ssa_graph_path, "/tmp/ssa_graph.dot",
34-
"the ssa graph path only print with GLOG_v=10,"
35-
"default /tmp/graph.dot");
36-
3733
namespace paddle {
3834
namespace framework {
3935
namespace details {
@@ -277,11 +273,6 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
277273
*/
278274
AddOutputToLeafOps(&result);
279275

280-
if (VLOG_IS_ON(10)) {
281-
std::ofstream fout(FLAGS_ssa_graph_path);
282-
PrintGraphviz(*graph, fout);
283-
}
284-
285276
return std::unique_ptr<SSAGraph>(graph);
286277
}
287278

paddle/fluid/framework/details/nccl_all_reduce_op_handle.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ struct NCCLAllReduceOpHandle : public OpHandleBase {
4141
void RunImpl() override;
4242

4343
private:
44-
const std::vector<Scope *> &local_scopes_;
45-
const std::vector<platform::Place> &places_;
44+
std::vector<Scope *> local_scopes_;
45+
std::vector<platform::Place> places_;
4646
const platform::NCCLContextMap &nccl_ctxs_;
4747
};
4848

0 commit comments

Comments
 (0)