Commit 6b3c2a9

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into feature/remove_trainer_api
2 parents 82b8a3c + a83a4fa

Some content is hidden: large commits have part of the diff collapsed by default, so only a subset of the changed files is shown below.

48 files changed: +1209 additions, -233 deletions

CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -69,6 +69,7 @@ option(WITH_ANAKIN "Compile with Anakin library" OFF)
 option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE})
 option(WITH_BRPC_RDMA "Use brpc rdma as the rpc protocal" OFF)
 option(WITH_INFERENCE "Compile fluid inference library" ON)
+option(WITH_INFERENCE_API_TEST "Test fluid inference high-level api interface" OFF)
 option(WITH_SYSTEM_BLAS "Use system blas library" OFF)
 option(PY_VERSION "Compile PaddlePaddle with python3 support" ${PY_VERSION})

doc/fluid/dev/releasing_process_en.md

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 # PaddlePaddle Releasing Process
 
-PaddlePaddle manages its branches using "git-flow branching model", and [Semantic Versioning](http://semver.org/) as it's version number semantics.
+PaddlePaddle manages its branches using Trunk Based Development, and [Semantic Versioning](http://semver.org/) as it's version number semantics.
 
 Each time we release a new PaddlePaddle version, we should follow the below steps:

paddle/fluid/framework/details/CMakeLists.txt

Lines changed: 11 additions & 1 deletion

@@ -28,10 +28,20 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_
 cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
 cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope)
 
+if(WITH_GPU)
+  cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle scale_loss_grad_op_handle rpc_op_handle
+      all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass)
+endif()
+
 cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
            scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle)
 
-cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto)
+if(WITH_GPU)
+  cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto reference_count_pass)
+else()
+  cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto)
+endif()
+
 cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
            simple_threadpool device_context)

paddle/fluid/framework/details/computation_op_handle.h

Lines changed: 4 additions & 0 deletions

@@ -32,6 +32,10 @@ struct ComputationOpHandle : public OpHandleBase {
 
   std::string Name() const override;
 
+  const Scope *GetScope() const { return scope_; }
+
+  const platform::Place &GetPlace() const { return place_; }
+
  protected:
   void RunImpl() override;
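The two accessors added here let a separate graph pass ask where a computation op runs and which scope it writes into; the new reference_count_pass later in this commit is the consumer. Below is a minimal consumer-side sketch that uses only the accessors shown above; the helper name and its surrounding boilerplate are illustrative, not part of the commit.

// Sketch only: assumes the same headers and namespaces used by
// reference_count_pass.cc later in this commit.
#include "paddle/fluid/framework/details/computation_op_handle.h"

namespace paddle {
namespace framework {
namespace details {

// Hypothetical helper (not in the commit): report whether an op handle is a
// GPU computation op and, if so, which CUDA device it is pinned to, relying
// only on the GetPlace() accessor introduced in this diff.
static bool GetComputeDeviceId(OpHandleBase *op, int *device_id) {
  auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
  if (compute_op == nullptr) return false;
  if (!platform::is_gpu_place(compute_op->GetPlace())) return false;
  *device_id = boost::get<platform::CUDAPlace>(compute_op->GetPlace()).device;
  return true;
}

}  // namespace details
}  // namespace framework
}  // namespace paddle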

paddle/fluid/framework/details/multi_devices_graph_pass.cc

Lines changed: 3 additions & 0 deletions

@@ -127,6 +127,9 @@ static const char kLocalScopes[] = "local_scopes";
 static const char kStrategy[] = "strategy";
 
 void MultiDevSSAGraphBuilder::Init() const {
+  all_vars_.clear();
+  balance_vars_.clear();
+
   loss_var_name_ = Get<const std::string>(kLossVarName);
   places_ = Get<const std::vector<platform::Place>>(kPlaces);
   local_scopes_ = Get<const std::vector<Scope *>>(kLocalScopes);
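Clearing all_vars_ and balance_vars_ at the top of Init() gives the builder a clean slate each time the pass is applied; both members are mutable caches filled during a previous run, and stale entries would otherwise leak into the next graph. A standalone sketch of the same pattern, with made-up names and no Paddle dependencies:

#include <string>
#include <unordered_set>
#include <vector>

// A const entry point that rebuilds its mutable cache on every call,
// mirroring why Init() now clears all_vars_ and balance_vars_ first.
struct CachingBuilder {
  void Build(const std::vector<std::string> &vars) const {
    seen_.clear();  // drop state left over from an earlier Build()
    for (const auto &v : vars) seen_.insert(v);
  }
  mutable std::unordered_set<std::string> seen_;
};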

paddle/fluid/framework/details/multi_devices_graph_pass.h

Lines changed: 8 additions & 10 deletions

@@ -40,12 +40,6 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
                  size_t device_id) const;
   void Init() const;
 
- private:
-  mutable std::string loss_var_name_;
-  mutable std::vector<platform::Place> places_;
-  mutable std::vector<Scope *> local_scopes_;
-  mutable std::unordered_set<std::string> grad_names_;
-
 #ifdef PADDLE_WITH_CUDA
   mutable platform::NCCLContextMap *nccl_ctxs_;
 #endif
@@ -95,13 +89,17 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
   size_t GetAppropriateDeviceID(
       const std::vector<std::string> &var_names) const;
 
- private:
+  void SetCommunicationContext(OpHandleBase *op_handle,
+                               const platform::Place &p) const;
+
+  mutable std::string loss_var_name_;
+  mutable std::vector<platform::Place> places_;
+  mutable std::vector<Scope *> local_scopes_;
+  mutable std::unordered_set<std::string> grad_names_;
+
   mutable BuildStrategy strategy_;
   mutable std::unordered_map<std::string, VarDesc *> all_vars_;
   mutable std::vector<int64_t> balance_vars_;
-
-  void SetCommunicationContext(OpHandleBase *op_handle,
-                               const platform::Place &p) const;
 };
 }  // namespace details
 }  // namespace framework
paddle/fluid/framework/details/reference_count_op_handle.h (new file; the filename is not shown in this view and is inferred from the class it defines)

Lines changed: 123 additions & 0 deletions

@@ -0,0 +1,123 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <atomic>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "paddle/fluid/framework/details/op_handle_base.h"
+#include "paddle/fluid/framework/garbage_collector.h"
+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/framework/tensor.h"
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+using ReferenceCountMap = std::unordered_map<std::string, int>;
+using AtomicReferenceCountMap =
+    std::unordered_map<std::string, std::atomic<int>>;
+using DeviceReferenceCountMap =
+    std::unordered_map<int, std::unique_ptr<ReferenceCountMap>>;
+using AtomicDeviceReferenceCountMap =
+    std::unordered_map<int, std::unique_ptr<AtomicReferenceCountMap>>;
+using DeviceGarbageCollectorMap =
+    std::unordered_map<int,
+                       std::unique_ptr<GarbageCollector<framework::Tensor>>>;
+
+class ReferenceCountOpHandle : public OpHandleBase {
+ public:
+  ReferenceCountOpHandle(ir::Node *node, const Scope *scope,
+                         const platform::CUDAPlace &place,
+                         const std::vector<std::string> &var_names,
+                         GarbageCollector<Tensor> *gc,
+                         AtomicReferenceCountMap *ref_cnts)
+      : OpHandleBase(node),
+        scope_(scope),
+        var_names_(var_names),
+        gc_(gc),
+        ref_cnts_(ref_cnts) {
+    dev_ctx_ = static_cast<platform::CUDADeviceContext *>(
+        platform::DeviceContextPool::Instance().Get(place));
+    if (IsStreamGarabageCollector()) {
+      PADDLE_ENFORCE(cudaSetDevice(place.device));
+      PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
+    }
+  }
+
+  ~ReferenceCountOpHandle() {
+    if (IsStreamGarabageCollector()) {
+      auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx_->GetPlace());
+      PADDLE_ENFORCE(cudaSetDevice(gpu_place.device));
+      PADDLE_ENFORCE(cudaEventDestroy(event_));
+    }
+  }
+
+  std::string Name() const override { return "reference_count"; }
+
+ protected:
+  void RunImpl() override {
+    auto *exec_scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
+    std::vector<LoDTensor *> tensors;
+    for (auto &name : var_names_) {
+      auto it = ref_cnts_->find(name);
+      if (it == ref_cnts_->end()) continue;
+
+      auto *var = exec_scope->FindVar(name);
+      if (var == nullptr || !var->IsType<LoDTensor>()) continue;
+
+      if (it->second.fetch_sub(1) <= 1) {
+        tensors.emplace_back(var->GetMutable<LoDTensor>());
+      }
+    }
+
+    if (!tensors.empty()) {
+      ClearTensors(tensors);
+    }
+  }
+
+ private:
+  void ClearTensors(const std::vector<LoDTensor *> &tensors) {
+    auto *gc = dynamic_cast<StreamGarbageCollector<Tensor> *>(gc_);
+    if (gc != nullptr) {
+      auto compute_stream = dev_ctx_->stream();
+      auto callback_stream = gc->stream();
+      auto callback_func = [=]() {
+        PADDLE_ENFORCE(cudaEventRecord(event_, compute_stream));
+        PADDLE_ENFORCE(cudaStreamWaitEvent(callback_stream, event_, 0));
+      };
+      gc_->Add(tensors, callback_func);
+    } else {
+      gc_->Add(tensors);
+    }
+  }
+
+  bool IsStreamGarabageCollector() const {
+    return dynamic_cast<const StreamGarbageCollector<Tensor> *>(gc_) != nullptr;
+  }
+
+  const Scope *scope_;
+  platform::CUDADeviceContext *dev_ctx_;
+  std::vector<std::string> var_names_;
+  GarbageCollector<Tensor> *gc_;       // not own
+  AtomicReferenceCountMap *ref_cnts_;  // not own
+  cudaEvent_t event_;
+};
+
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
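The decrement rule in RunImpl() above is the core of this handle: when fetch_sub(1) returns 1 or less, the op that just finished was the last recorded user of the variable, so its tensor can be handed to the garbage collector. A self-contained sketch of that counting rule, with no Paddle or CUDA dependencies and a made-up variable name:

#include <atomic>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  // One atomic counter per variable name, as in AtomicReferenceCountMap.
  std::unordered_map<std::string, std::atomic<int>> ref_cnts;
  ref_cnts["fc_0.tmp_0"] = 2;  // the variable has two pending users

  std::vector<std::string> reclaimable;
  for (int i = 0; i < 2; ++i) {  // the two users finish, in any order
    if (ref_cnts["fc_0.tmp_0"].fetch_sub(1) <= 1) {
      // fetch_sub returned the old value 1, so this was the last user.
      reclaimable.push_back("fc_0.tmp_0");
    }
  }
  std::cout << "reclaimable: " << reclaimable.size() << "\n";  // prints 1
  return 0;
}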
paddle/fluid/framework/details/reference_count_pass.cc (new file; the filename is not shown in this view and matches the SRCS entry added in the CMakeLists change above)

Lines changed: 150 additions & 0 deletions

@@ -0,0 +1,150 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/framework/details/computation_op_handle.h"
+#include "paddle/fluid/framework/details/multi_devices_helper.h"
+#include "paddle/fluid/framework/details/reference_count_pass.h"
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
+    std::unique_ptr<ir::Graph> graph) const {
+  auto &ref_cnts = Get<DeviceReferenceCountMap>(kGlobalReferenceCount);
+  auto &cur_ref_cnts = Get<AtomicDeviceReferenceCountMap>(kCurReferenceCount);
+  auto &gcs = Get<DeviceGarbageCollectorMap>(kGarbageCollector);
+
+  // It is not easy to find the right reference counts of varaibles in graph
+  // Step 1: Find all variables in computation ops
+  // Step 2: Find all variables in non-computation ops which refers to variables
+  // in computation ops
+  std::unordered_set<std::string> names;
+  auto get_ref_cnts_from_compute_op = [&](
+      const std::unique_ptr<OpHandleBase> &op,
+      const std::vector<VarHandleBase *> &vars) {
+    std::vector<std::string> var_names_in_op;
+    auto *compute_op = dynamic_cast<ComputationOpHandle *>(op.get());
+    if (compute_op == nullptr ||
+        !platform::is_gpu_place(compute_op->GetPlace()))
+      return var_names_in_op;
+    auto place = boost::get<platform::CUDAPlace>(compute_op->GetPlace());
+    for (VarHandleBase *var_handle_base : vars) {
+      auto *var_handle = dynamic_cast<VarHandle *>(var_handle_base);
+      if (var_handle == nullptr || !var_handle->Node()->IsVar()) continue;
+
+      if (!platform::is_gpu_place(var_handle->place_) ||
+          boost::get<platform::CUDAPlace>(var_handle->place_) != place)
+        continue;
+
+      VarDesc *var_desc = var_handle->Node()->Var();
+      auto var_name = var_handle->Node()->Name();
+
+      // This is wierd but there is really some variables without var_desc
+      // in computation_op
+      if (var_desc == nullptr) {
+        if (compute_op->Node()->Op()->Block()->FindVar(var_name) == nullptr)
+          continue;
+      } else {
+        if (var_desc->Persistable() ||
+            var_desc->Proto()->type().type() != proto::VarType::LOD_TENSOR)
+          continue;
+      }
+
+      // compute op only runs in one device
+      if (ref_cnts[place.device]->count(var_name))
+        ++(*ref_cnts[place.device])[var_name];
+      else
+        (*ref_cnts[place.device])[var_name] = 1;
+
+      names.insert(var_name);
+      var_names_in_op.push_back(var_name);
+    }
+    return var_names_in_op;
+  };
+
+  auto update_ref_cnts_from_non_compute_op = [&](
+      const std::unique_ptr<OpHandleBase> &op,
+      const std::vector<VarHandleBase *> &vars) {
+    if (dynamic_cast<ComputationOpHandle *>(op.get()) != nullptr) return;
+    for (VarHandleBase *var_handle_base : vars) {
+      auto *var_handle = dynamic_cast<VarHandle *>(var_handle_base);
+      if (var_handle == nullptr || !var_handle->Node()->IsVar()) continue;
+
+      auto var_name = var_handle->Node()->Name();
+      auto var_place = var_handle->place_;
+      if (!platform::is_gpu_place(var_place)) continue;
+      auto place = boost::get<platform::CUDAPlace>(var_place);
+      if (names.count(var_name) == 0) continue;
+      if (ref_cnts.count(place.device) &&
+          ref_cnts[place.device]->count(var_name)) {
+        ++(*ref_cnts[place.device])[var_name];
+      }
+    }
+  };
+
+  std::unordered_map<OpHandleBase *, ReferenceCountOpHandle *>
+      compute_ref_cnt_map;
+  auto &all_ops = graph->Get<GraphOps>(kGraphOps);
+  for (auto &op : all_ops) {
+    auto in_var_names = get_ref_cnts_from_compute_op(op, op->Inputs());
+    auto out_var_names = get_ref_cnts_from_compute_op(op, op->Outputs());
+    if (in_var_names.empty() && out_var_names.empty()) continue;
+    in_var_names.insert(in_var_names.end(), out_var_names.begin(),
+                        out_var_names.end());
+    auto *compute_op = dynamic_cast<ComputationOpHandle *>(op.get());
+    auto place = boost::get<platform::CUDAPlace>(compute_op->GetPlace());
+    ir::Node *ref_cnt_node =
+        graph->CreateEmptyNode("reference_count", ir::Node::Type::kOperation);
+    auto *ref_cnt_handle = new ReferenceCountOpHandle(
+        ref_cnt_node, compute_op->GetScope(), place, in_var_names,
+        gcs[place.device].get(), cur_ref_cnts[place.device].get());
+    auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
+    compute_op->AddOutput(dep_var);
+    ref_cnt_handle->AddInput(dep_var);
+    graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
+    compute_ref_cnt_map[compute_op] = ref_cnt_handle;
+  }
+
+  for (auto &op : all_ops) {
+    update_ref_cnts_from_non_compute_op(op, op->Inputs());
+    update_ref_cnts_from_non_compute_op(op, op->Outputs());
+  }
+
+  std::vector<std::unique_ptr<OpHandleBase>> new_all_ops;
+  new_all_ops.reserve(compute_ref_cnt_map.size() + all_ops.size());
+  for (auto &op : all_ops) {
+    new_all_ops.emplace_back(std::move(op));
+    auto it = compute_ref_cnt_map.find(new_all_ops.back().get());
+    if (it != compute_ref_cnt_map.end()) {
+      new_all_ops.emplace_back(it->second);
+    }
+  }
+
+  all_ops.swap(new_all_ops);
+  return graph;
+}
+
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(reference_count_pass,
+              paddle::framework::details::ReferenceCountPass)
+    .RequirePassAttr(paddle::framework::details::kGlobalReferenceCount)
+    .RequirePassAttr(paddle::framework::details::kCurReferenceCount)
+    .RequirePassAttr(paddle::framework::details::kGarbageCollector);
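REGISTER_PASS above also declares the three attributes the pass refuses to run without. The fragment below sketches how a caller might supply them before applying the pass; the PassRegistry, SetNotOwned, and Apply calls reflect my reading of the ir::Pass API at this point in the tree and should be treated as assumptions rather than part of this diff.

// Assumed wiring: the executor owns the maps, the pass only borrows them,
// and ApplyImpl() reads them back with Get<...>(kGlobalReferenceCount) etc.
// `graph` is a std::unique_ptr<ir::Graph> produced by the earlier passes.
details::DeviceReferenceCountMap global_ref_cnts;
details::AtomicDeviceReferenceCountMap cur_ref_cnts;
details::DeviceGarbageCollectorMap gcs;

auto pass = ir::PassRegistry::Instance().Get("reference_count_pass");
pass->SetNotOwned(details::kGlobalReferenceCount, &global_ref_cnts);
pass->SetNotOwned(details::kCurReferenceCount, &cur_ref_cnts);
pass->SetNotOwned(details::kGarbageCollector, &gcs);
graph = pass->Apply(std::move(graph));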
