Commit 87648f8

merge develop, test=develop

2 parents: c3c3c0b + db9284e
25 files changed: +665 −173 lines

cmake/inference_lib.cmake (1 addition, 2 deletions)

@@ -186,8 +186,7 @@ set(module "inference")
 copy(inference_lib DEPS ${inference_deps}
   SRCS ${src_dir}/${module}/*.h ${PADDLE_BINARY_DIR}/paddle/fluid/inference/libpaddle_fluid.*
        ${src_dir}/${module}/api/paddle_*.h
-       ${PADDLE_BINARY_DIR}/paddle/fluid/inference/api/paddle_inference_pass.h
- DSTS ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module}
+ DSTS ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module}
 )

 set(module "platform")

paddle/fluid/API.spec (1 addition, 1 deletion)

@@ -97,8 +97,8 @@ paddle.fluid.layers.warpctc ArgSpec(args=['input', 'label', 'blank', 'norm_by_ti
 paddle.fluid.layers.sequence_reshape ArgSpec(args=['input', 'new_dim'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.layers.transpose ArgSpec(args=['x', 'perm', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.im2sequence ArgSpec(args=['input', 'filter_size', 'stride', 'padding', 'input_image_size', 'out_stride', 'name'], varargs=None, keywords=None, defaults=(1, 1, 0, None, 1, None))
-paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'name', 'sampler', 'custom_dist', 'seed'], varargs=None, keywords=None, defaults=(None, None, None, None, None, 'uniform', None, 0))
 paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr', 'name', 'non_leaf_num', 'ptable', 'pcode', 'is_costum', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None, None, False, False))
+paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'name', 'sampler', 'custom_dist', 'seed', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, None, 'uniform', None, 0, False))
 paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'name'], varargs=None, keywords=None, defaults=(0, None))
 paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None))
 paddle.fluid.layers.multiplex ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None)

paddle/fluid/framework/details/CMakeLists.txt (2 additions, 1 deletion)

@@ -39,11 +39,12 @@ if (WITH_GPU)
 endif()

 cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass)
+cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_helper pass)

 cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
     scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle)

-set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass)
+set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass)
 if (WITH_GPU)
   list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass)
 endif()
paddle/fluid/framework/details/all_reduce_deps_pass.cc (new file, 125 additions)

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "paddle/fluid/framework/details/all_reduce_deps_pass.h"
#include "paddle/fluid/framework/details/all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/op_graph_view.h"
#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_proto_maker.h"

namespace paddle {
namespace framework {
namespace details {

static constexpr char kAllOpDescs[] = "all_op_descs";

VarHandle* GetValidInput(const OpHandleBase* a) {
  for (auto p : a->Inputs()) {
    VarHandle* b = dynamic_cast<VarHandle*>(p);
    if (b) {
      return b;
    }
  }

  return nullptr;
}

std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  auto graph_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);

  // get vars order
  int order = 0;
  std::unordered_map<std::string, int> vars;
  // TODO(gongwb): use graph topology sort to find the order of operators.
  // Note that must assert topology sort is stable
  auto& ops = Get<const std::vector<OpDesc*>>(kAllOpDescs);
  for (auto* op_desc : ops) {
    auto outputs = op_desc->Outputs();
    for (auto& o_it : outputs) {
      for (auto& v : o_it.second) {  // values
        vars[v] = order;
      }
    }
    order++;
  }

  std::vector<OpHandleBase*> dist_ops;
  // get allreduce ops.
  for (auto& op : graph_ops) {
    // FIXME(gongwb): add broadcast.
    if (op->Name() == "all_reduce" || op->Name() == "reduce") {
      dist_ops.push_back(op);
    }
  }

  VLOG(10) << "dist_ops size:" << dist_ops.size() << std::endl;

  std::sort(dist_ops.begin(), dist_ops.end(), [&](OpHandleBase* op1,
                                                  OpHandleBase* op2) {
    VarHandle* i0 = dynamic_cast<VarHandle*>(GetValidInput(op1));
    VarHandle* i1 = dynamic_cast<VarHandle*>(GetValidInput(op2));

    PADDLE_ENFORCE(i0 != nullptr && i1 != nullptr, "%s convert to %s error",
                   op1->DebugString(), op2->DebugString());

    auto l_it = vars.find(i0->name_);
    auto r_it = vars.find(i1->name_);

    if (l_it->second < r_it->second) return true;

    if (l_it->second == r_it->second) {
      return i0->name_ < i1->name_;
    }

    return false;
  });

  // add dependency.
  auto& sorted_ops = dist_ops;
  for (size_t i = 1; i < sorted_ops.size(); ++i) {
    auto* dep_var = new DummyVarHandle(graph->CreateControlDepVar());

    auto* pre_op = sorted_ops[i - 1];
    auto* op = sorted_ops[i];

    pre_op->AddOutput(dep_var);
    op->AddInput(dep_var);
    graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);

    VLOG(10) << "add all_reduce sequential dependencies between " << pre_op
             << " and " << op;

    VLOG(10) << "pre_op:" << pre_op->DebugString()
             << ", op:" << op->DebugString();
  }

  return graph;
}

}  // namespace details
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(all_reduce_deps_pass,
              paddle::framework::details::AllReduceDepsPass)
    .RequirePassAttr(paddle::framework::details::kAllOpDescs);
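
To make the ordering idea easier to follow outside of Paddle's graph types, here is a minimal standalone C++ sketch of what the pass does: record the program order in which variables are produced, sort the collective ops by the order of their first input, and chain each op to its predecessor with a control dependency. The Op struct, the gradient variable names, and the "dep_on_" labels are placeholders, not Paddle APIs.

// Standalone sketch (hypothetical types, not Paddle code) of the
// order-then-chain technique used by AllReduceDepsPass.
#include <algorithm>
#include <cstddef>
#include <string>
#include <tuple>
#include <unordered_map>
#include <vector>

struct Op {
  std::string name;                       // e.g. "all_reduce"
  std::string first_input;                // first valid input variable
  std::vector<std::string> control_deps;  // inserted sequential dependencies
};

int main() {
  // Program order of variables, as if recorded from the OpDesc outputs.
  std::unordered_map<std::string, int> var_order = {
      {"w1@GRAD", 0}, {"w2@GRAD", 1}, {"w3@GRAD", 2}};

  std::vector<Op> dist_ops = {{"all_reduce", "w3@GRAD", {}},
                              {"all_reduce", "w1@GRAD", {}},
                              {"all_reduce", "w2@GRAD", {}}};

  // Sort by recorded order; break ties by variable name for determinism.
  std::sort(dist_ops.begin(), dist_ops.end(), [&](const Op& a, const Op& b) {
    return std::make_tuple(var_order.at(a.first_input), a.first_input) <
           std::make_tuple(var_order.at(b.first_input), b.first_input);
  });

  // Chain: op[i] may not start before op[i-1] finishes, mirroring the
  // DummyVarHandle control-dependency edges added by the pass.
  for (std::size_t i = 1; i < dist_ops.size(); ++i) {
    dist_ops[i].control_deps.push_back("dep_on_" + dist_ops[i - 1].first_input);
  }
  return 0;
}

Since every trainer builds the same chain, the collectives are launched in the same order on all ranks, which appears to be why build_strategy.cc below only enables the pass when num_trainers_ > 1.
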
paddle/fluid/framework/details/all_reduce_deps_pass.h (new file, 33 additions)

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"

namespace paddle {
namespace framework {
namespace details {

// TODO(gongwb): overlap allreduce with backward computation.
class AllReduceDepsPass : public ir::Pass {
 protected:
  std::unique_ptr<ir::Graph> ApplyImpl(
      std::unique_ptr<ir::Graph> graph) const override;
};

}  // namespace details
}  // namespace framework
}  // namespace paddle

paddle/fluid/framework/details/build_strategy.cc (21 additions, 0 deletions)

@@ -16,6 +16,7 @@ limitations under the License. */

 #include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h"
 #include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
+#include "paddle/fluid/framework/details/reduce_op_handle.h"
 #include "paddle/fluid/framework/details/sequential_execution_pass.h"
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/graph_viz_pass.h"
@@ -24,6 +25,10 @@ namespace paddle {
 namespace framework {
 namespace details {

+static inline bool SeqOnlyAllReduceOps(const BuildStrategy &strategy) {
+  return (!strategy.enable_sequential_execution_ && strategy.num_trainers_ > 1);
+}
+
 class ParallelExecutorPassBuilder : public ir::PassBuilder {
  public:
  explicit ParallelExecutorPassBuilder(const BuildStrategy &strategy)
@@ -70,6 +75,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
    // Verify that the graph is correct for multi-device executor.
    AppendPass("multi_devices_check_pass");

+    if (SeqOnlyAllReduceOps(strategy)) {
+      AppendPass("all_reduce_deps_pass");
+    }
+
    if (strategy_.remove_unnecessary_lock_) {
      AppendPass("modify_op_lock_and_record_event_pass");
    }
@@ -124,6 +133,17 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
      pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx);
 #endif
    } else if (pass->Type() == "sequential_execution_pass") {
+      VLOG(1) << "set enable_sequential_execution:"
+              << enable_sequential_execution_;
+
+      pass->Erase(kAllOpDescs);
+      pass->Set<const std::vector<OpDesc *>>(
+          kAllOpDescs,
+          new std::vector<OpDesc *>(main_program.Block(0).AllOps()));
+    } else if (pass->Type() == "all_reduce_deps_pass") {
+      VLOG(1) << "SeqOnlyAllReduceOps:" << SeqOnlyAllReduceOps(*this)
+              << ", num_trainers:" << num_trainers_;
+
      pass->Erase(kAllOpDescs);
      pass->Set<const std::vector<OpDesc *>>(
          kAllOpDescs,
@@ -144,4 +164,5 @@ USE_PASS(multi_devices_pass);
 USE_PASS(multi_devices_check_pass);
 USE_PASS(multi_devices_print_pass);
 USE_PASS(sequential_execution_pass);
+USE_PASS(all_reduce_deps_pass);
 USE_PASS(modify_op_lock_and_record_event_pass);
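
Below is a hedged standalone sketch of the gating logic added above, with plain types standing in for BuildStrategy and ir::PassBuilder: the dependency pass is appended only when full sequential execution is off and more than one trainer is involved. Everything here is illustrative; only the SeqOnlyAllReduceOps predicate mirrors the committed code.

#include <iostream>
#include <string>
#include <vector>

// Stand-in for BuildStrategy (illustrative only).
struct Strategy {
  bool enable_sequential_execution = false;
  int num_trainers = 1;
};

// Mirrors the committed predicate: serialize only the all-reduce ops when the
// whole program is not already executed sequentially and multiple trainers
// must agree on the collective order.
static bool SeqOnlyAllReduceOps(const Strategy& s) {
  return !s.enable_sequential_execution && s.num_trainers > 1;
}

int main() {
  std::vector<std::string> passes = {"multi_devices_pass",
                                     "multi_devices_check_pass"};
  Strategy s;
  s.num_trainers = 2;  // e.g. a two-trainer job
  if (SeqOnlyAllReduceOps(s)) {
    passes.push_back("all_reduce_deps_pass");
  }
  for (const auto& p : passes) {
    std::cout << p << "\n";
  }
  return 0;
}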

paddle/fluid/framework/details/build_strategy.h (1 addition, 0 deletions)

@@ -73,6 +73,7 @@ struct BuildStrategy {

   bool fuse_broadcast_op_{false};

+  int num_trainers_{1};
   bool remove_unnecessary_lock_{false};

   // NOTE:

paddle/fluid/framework/parallel_executor.cc (5 additions, 5 deletions)

@@ -20,7 +20,7 @@ limitations under the License. */

 #include "paddle/fluid/framework/ir/graph.h"

-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
 #include "paddle/fluid/platform/nccl_helper.h"
 #endif

@@ -54,7 +54,7 @@ class ParallelExecutorPrivate {
   Scope *global_scope_;  // not owned
   std::unique_ptr<details::SSAGraphExecutor> executor_;

-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
   std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
 #endif
   bool own_local_scope_;
@@ -104,7 +104,7 @@ ParallelExecutor::ParallelExecutor(

   if (member_->use_cuda_) {
 // Bcast Parameters to all GPUs
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
     auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME);
     ncclUniqueId *nccl_id = nullptr;
     if (nccl_id_var != nullptr) {
@@ -124,7 +124,7 @@ ParallelExecutor::ParallelExecutor(

 // Step 2. Convert main_program to SSA form and dependency graph. Also, insert
 // ncclOp
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
   std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
       main_program, member_->places_, loss_var_name, params,
       member_->local_scopes_, member_->use_cuda_, member_->nccl_ctxs_.get());
@@ -213,7 +213,7 @@ void ParallelExecutor::BCastParamsToDevices(
     }
     auto &dims = main_tensor.dims();
     if (paddle::platform::is_gpu_place(main_tensor.place())) {
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
       std::vector<void *> buffers;
       size_t numel = main_tensor.numel();
       ncclDataType_t data_type = platform::ToNCCLDataType(main_tensor.type());
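
The repeated edit above narrows the guard so that NCCL-dependent code is compiled only when CUDA is on and the target is not Windows, presumably because NCCL is not available there. A minimal sketch of the same guard pattern, with a placeholder function name:

#include <cstdio>

#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
// NCCL-backed path: only compiled when CUDA is enabled and not on Windows.
static void BroadcastParameters() { std::puts("broadcast via NCCL"); }
#else
// Fallback path for CPU-only builds and Windows builds.
static void BroadcastParameters() { std::puts("NCCL unavailable; skipping"); }
#endif

int main() {
  BroadcastParameters();
  return 0;
}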

paddle/fluid/inference/analysis/analyzer_tester.cc (2 additions, 1 deletion)

@@ -76,7 +76,8 @@ void TestWord2vecPrediction(const std::string& model_path) {
                              0.000932706};
   const size_t num_elements = outputs.front().data.length() / sizeof(float);
   // The outputs' buffers are in CPU memory.
-  for (size_t i = 0; i < std::min((size_t)5UL, num_elements); i++) {
+  for (size_t i = 0; i < std::min(static_cast<size_t>(5UL), num_elements);
+       i++) {
     LOG(INFO) << "data: "
               << static_cast<float*>(outputs.front().data.data())[i];
     PADDLE_ENFORCE(static_cast<float*>(outputs.front().data.data())[i],
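
For reference, a tiny standalone example of the cast style adopted above: static_cast keeps the conversion explicit and compiler-checked, where the old C-style cast would silently allow more dangerous conversions.

#include <algorithm>
#include <cstddef>

int main() {
  std::size_t num_elements = 3;
  // Before: std::min((size_t)5UL, num_elements);
  std::size_t n = std::min(static_cast<std::size_t>(5UL), num_elements);
  return static_cast<int>(n);
}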

paddle/fluid/memory/allocation/best_fit_allocator_test.cc (4 additions, 4 deletions)

@@ -99,9 +99,8 @@ TEST(BestFitAllocator, test_concurrent_cpu_allocation) {

   LockedAllocator locked_allocator(std::move(best_fit_allocator));

-  auto th_main = [&] {
-    std::random_device dev;
-    std::default_random_engine engine(dev());
+  auto th_main = [&](std::random_device::result_type seed) {
+    std::default_random_engine engine(seed);
     std::uniform_int_distribution<size_t> dist(1U, 1024U);

     for (size_t i = 0; i < 128; ++i) {
@@ -125,7 +124,8 @@ TEST(BestFitAllocator, test_concurrent_cpu_allocation) {
   {
     std::vector<std::thread> threads;
     for (size_t i = 0; i < 1024; ++i) {
-      threads.emplace_back(th_main);
+      std::random_device dev;
+      threads.emplace_back(th_main, dev());
     }
     for (auto& th : threads) {
       th.join();
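
A standalone sketch of the seeding pattern the test switches to: seeds are drawn from std::random_device on the spawning thread and passed to each worker, so every worker constructs only a cheap std::default_random_engine instead of its own std::random_device. The worker body here is a placeholder, not the allocator test itself.

#include <cstddef>
#include <cstdio>
#include <random>
#include <thread>
#include <vector>

int main() {
  auto worker = [](std::random_device::result_type seed) {
    std::default_random_engine engine(seed);
    std::uniform_int_distribution<std::size_t> dist(1U, 1024U);
    std::printf("sampled size: %zu\n", dist(engine));
  };

  std::vector<std::thread> threads;
  for (int i = 0; i < 4; ++i) {
    std::random_device dev;  // seed source stays on the spawning thread
    threads.emplace_back(worker, dev());
  }
  for (auto& th : threads) {
    th.join();
  }
  return 0;
}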
