Skip to content

Commit f855c05

Browse files
authored
Merge pull request #13520 from sneaxiy/enhance_eager_delete
Enhance eager delete and sparse Adam
2 parents 3043f51 + 192c49c commit f855c05

File tree

3 files changed

+123
-37
lines changed

3 files changed

+123
-37
lines changed

paddle/fluid/framework/details/reference_count_op_handle.h

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "paddle/fluid/framework/details/op_handle_base.h"
2323
#include "paddle/fluid/framework/garbage_collector.h"
2424
#include "paddle/fluid/framework/scope.h"
25+
#include "paddle/fluid/framework/selected_rows.h"
2526
#include "paddle/fluid/framework/tensor.h"
2627

2728
namespace paddle {
@@ -46,17 +47,15 @@ class ReferenceCountOpHandle : public OpHandleBase {
4647
const std::vector<std::string> &var_names,
4748
GarbageCollector<Tensor> *gc,
4849
AtomicReferenceCountMap *ref_cnts)
49-
: OpHandleBase(node),
50-
scope_(scope),
51-
var_names_(var_names),
52-
gc_(gc),
53-
ref_cnts_(ref_cnts) {
50+
: OpHandleBase(node), scope_(scope), gc_(gc), ref_cnts_(ref_cnts) {
5451
dev_ctx_ = static_cast<platform::CUDADeviceContext *>(
5552
platform::DeviceContextPool::Instance().Get(place));
5653
if (IsStreamGarabageCollector()) {
5754
PADDLE_ENFORCE(cudaSetDevice(place.device));
5855
PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
5956
}
57+
58+
for (auto &name : var_names) AddVar(name);
6059
}
6160

6261
~ReferenceCountOpHandle() {
@@ -69,19 +68,35 @@ class ReferenceCountOpHandle : public OpHandleBase {
6968

7069
std::string Name() const override { return "reference_count"; }
7170

71+
void AddVar(const std::string &name) {
72+
auto it = var_names_.find(name);
73+
if (it != var_names_.end())
74+
++(it->second);
75+
else
76+
var_names_[name] = 1;
77+
}
78+
7279
protected:
7380
void RunImpl() override {
7481
auto *exec_scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
75-
std::vector<LoDTensor *> tensors;
76-
for (auto &name : var_names_) {
82+
std::vector<Tensor *> tensors;
83+
for (auto &pair : var_names_) {
84+
auto &name = pair.first;
7785
auto it = ref_cnts_->find(name);
7886
if (it == ref_cnts_->end()) continue;
7987

8088
auto *var = exec_scope->FindVar(name);
81-
if (var == nullptr || !var->IsType<LoDTensor>()) continue;
82-
83-
if (it->second.fetch_sub(1) <= 1) {
84-
tensors.emplace_back(var->GetMutable<LoDTensor>());
89+
if (var == nullptr) continue;
90+
91+
if (var->IsType<LoDTensor>()) {
92+
if (it->second.fetch_sub(pair.second) <= pair.second) {
93+
tensors.emplace_back(var->GetMutable<LoDTensor>());
94+
}
95+
} else if (var->IsType<SelectedRows>()) {
96+
if (it->second.fetch_sub(pair.second) <= pair.second) {
97+
tensors.emplace_back(
98+
var->GetMutable<SelectedRows>()->mutable_value());
99+
}
85100
}
86101
}
87102

@@ -91,7 +106,7 @@ class ReferenceCountOpHandle : public OpHandleBase {
91106
}
92107

93108
private:
94-
void ClearTensors(const std::vector<LoDTensor *> &tensors) {
109+
void ClearTensors(const std::vector<Tensor *> &tensors) {
95110
auto *gc = dynamic_cast<StreamGarbageCollector<Tensor> *>(gc_);
96111
if (gc != nullptr) {
97112
auto compute_stream = dev_ctx_->stream();
@@ -112,7 +127,7 @@ class ReferenceCountOpHandle : public OpHandleBase {
112127

113128
const Scope *scope_;
114129
platform::CUDADeviceContext *dev_ctx_;
115-
std::vector<std::string> var_names_;
130+
std::unordered_map<std::string, int> var_names_;
116131
GarbageCollector<Tensor> *gc_; // not own
117132
AtomicReferenceCountMap *ref_cnts_; // not own
118133
cudaEvent_t event_;

paddle/fluid/framework/details/reference_count_pass.cc

Lines changed: 64 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
#include <queue>
1516
#include <string>
1617
#include <vector>
1718

@@ -23,6 +24,25 @@ namespace paddle {
2324
namespace framework {
2425
namespace details {
2526

27+
static ComputationOpHandle *FindNextComputationOpHandle(VarHandle *var_in) {
28+
std::queue<VarHandleBase *> queue;
29+
queue.push(var_in);
30+
do {
31+
auto *var = queue.front();
32+
queue.pop();
33+
for (auto *op : var->PendingOps()) {
34+
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
35+
if (compute_op != nullptr && compute_op->GetPlace() == var_in->place_) {
36+
return compute_op;
37+
}
38+
for (auto *out_var : op->Outputs()) {
39+
queue.push(out_var);
40+
}
41+
}
42+
} while (!queue.empty());
43+
return nullptr;
44+
}
45+
2646
std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
2747
std::unique_ptr<ir::Graph> graph) const {
2848
auto &ref_cnts = Get<DeviceReferenceCountMap>(kGlobalReferenceCount);
@@ -34,6 +54,9 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
3454
// Step 2: Find all variables in non-computation ops which refers to variables
3555
// in computation ops
3656
std::unordered_set<std::string> names;
57+
std::unordered_map<OpHandleBase *, std::unique_ptr<ReferenceCountOpHandle>>
58+
compute_ref_cnt_map;
59+
3760
auto get_ref_cnts_from_compute_op = [&](
3861
const std::unique_ptr<OpHandleBase> &op,
3962
const std::vector<VarHandleBase *> &vars) {
@@ -54,15 +77,18 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
5477
VarDesc *var_desc = var_handle->Node()->Var();
5578
auto var_name = var_handle->Node()->Name();
5679

57-
// This is wierd but there is really some variables without var_desc
80+
// This is weird but there is really some variables without var_desc
5881
// in computation_op
5982
if (var_desc == nullptr) {
6083
if (compute_op->Node()->Op()->Block()->FindVar(var_name) == nullptr)
6184
continue;
6285
} else {
63-
if (var_desc->Persistable() ||
64-
var_desc->Proto()->type().type() != proto::VarType::LOD_TENSOR)
86+
if (var_desc->Persistable()) continue;
87+
auto var_type = var_desc->Proto()->type().type();
88+
if (var_type != proto::VarType::LOD_TENSOR &&
89+
var_type != proto::VarType::SELECTED_ROWS) {
6590
continue;
91+
}
6692
}
6793

6894
// compute op only runs in one device
@@ -93,12 +119,33 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
93119
if (ref_cnts.count(place.device) &&
94120
ref_cnts[place.device]->count(var_name)) {
95121
++(*ref_cnts[place.device])[var_name];
122+
123+
auto *next_compute_op = FindNextComputationOpHandle(var_handle);
124+
if (next_compute_op != nullptr) {
125+
if (compute_ref_cnt_map.count(next_compute_op)) {
126+
compute_ref_cnt_map[next_compute_op]->AddVar(var_name);
127+
VLOG(5) << "Add reference count of " << var_name << " to Operator "
128+
<< next_compute_op->Name();
129+
} else {
130+
// Create new reference_count_op_handle
131+
ir::Node *ref_cnt_node = graph->CreateEmptyNode(
132+
"reference_count", ir::Node::Type::kOperation);
133+
auto *ref_cnt_handle = new ReferenceCountOpHandle(
134+
ref_cnt_node, next_compute_op->GetScope(), place, {var_name},
135+
gcs[place.device].get(), cur_ref_cnts[place.device].get());
136+
if (next_compute_op->Outputs().empty()) {
137+
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
138+
next_compute_op->AddOutput(dep_var);
139+
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
140+
}
141+
ref_cnt_handle->AddInput(next_compute_op->Outputs().front());
142+
compute_ref_cnt_map[next_compute_op].reset(ref_cnt_handle);
143+
}
144+
}
96145
}
97146
}
98147
};
99148

100-
std::unordered_map<OpHandleBase *, ReferenceCountOpHandle *>
101-
compute_ref_cnt_map;
102149
auto &all_ops = graph->Get<GraphOps>(kGraphOps);
103150
for (auto &op : all_ops) {
104151
auto in_var_names = get_ref_cnts_from_compute_op(op, op->Inputs());
@@ -113,11 +160,13 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
113160
auto *ref_cnt_handle = new ReferenceCountOpHandle(
114161
ref_cnt_node, compute_op->GetScope(), place, in_var_names,
115162
gcs[place.device].get(), cur_ref_cnts[place.device].get());
116-
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
117-
compute_op->AddOutput(dep_var);
118-
ref_cnt_handle->AddInput(dep_var);
119-
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
120-
compute_ref_cnt_map[compute_op] = ref_cnt_handle;
163+
if (compute_op->Outputs().empty()) {
164+
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
165+
compute_op->AddOutput(dep_var);
166+
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
167+
}
168+
ref_cnt_handle->AddInput(compute_op->Outputs().front());
169+
compute_ref_cnt_map[compute_op].reset(ref_cnt_handle);
121170
}
122171

123172
for (auto &op : all_ops) {
@@ -131,7 +180,11 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
131180
new_all_ops.emplace_back(std::move(op));
132181
auto it = compute_ref_cnt_map.find(new_all_ops.back().get());
133182
if (it != compute_ref_cnt_map.end()) {
134-
new_all_ops.emplace_back(it->second);
183+
// Add LeafNode to ReferenceCountOpHandle
184+
auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar());
185+
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dummy_leaf);
186+
it->second->AddOutput(dummy_leaf);
187+
new_all_ops.emplace_back(std::move(it->second));
135188
}
136189
}
137190

paddle/fluid/operators/adam_op.h

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License. */
1515
#pragma once
1616
#include <math.h> // for sqrt in CPU and CUDA
1717
#include <Eigen/Dense>
18+
#include <vector>
1819
#include "paddle/fluid/framework/op_registry.h"
1920
#include "paddle/fluid/operators/detail/safe_ref.h"
2021
#include "paddle/fluid/operators/math/selected_rows_functor.h"
@@ -306,26 +307,43 @@ class AdamOpKernel : public framework::OpKernel<T> {
306307
VLOG(3) << "grad row size is 0!!";
307308
return;
308309
}
309-
// merge duplicated rows if any.
310-
// The rows of grad_merge have been sorted inside MergeAdd functor
311-
scatter::MergeAdd<DeviceContext, T> merge_func;
312-
auto& grad_merge = *(ctx.scope()
313-
.NewScope()
314-
.Var("sparse_adam_grad_merge")
315-
->GetMutable<framework::SelectedRows>());
316-
merge_func(ctx.template device_context<DeviceContext>(), grad,
317-
&grad_merge);
310+
311+
std::vector<int64_t> cpu_rows(grad.rows().begin(), grad.rows().end());
312+
bool is_strict_sorted = true;
313+
for (size_t i = 1; i < cpu_rows.size(); ++i) {
314+
if (cpu_rows[i - 1] >= cpu_rows[i]) {
315+
is_strict_sorted = false;
316+
break;
317+
}
318+
}
319+
320+
const framework::SelectedRows* grad_merge_ptr;
321+
if (is_strict_sorted) {
322+
grad_merge_ptr = &grad;
323+
} else {
324+
// merge duplicated rows if any.
325+
// The rows of grad_merge have been sorted inside MergeAdd functor
326+
scatter::MergeAdd<DeviceContext, T> merge_func;
327+
auto* grad_merge_var = const_cast<framework::Scope&>(ctx.scope())
328+
.Var()
329+
->GetMutable<framework::SelectedRows>();
330+
merge_func(ctx.template device_context<DeviceContext>(), grad,
331+
grad_merge_var);
332+
grad_merge_ptr = grad_merge_var;
333+
}
334+
335+
auto& grad_merge = *grad_merge_ptr;
318336
auto& grad_tensor = grad_merge.value();
319337
const T* grad_data = grad_tensor.template data<T>();
320-
int64_t* rows = nullptr;
321-
// When compiled without CUDA, the CUDAMutableData() interface should not be
338+
const int64_t* rows = nullptr;
339+
// When compiled without CUDA, the CUDAData() interface should not be
322340
// provided.
323341
#if defined(PADDLE_WITH_CUDA)
324342
if (platform::is_gpu_place(ctx.GetPlace())) {
325-
rows = grad_merge.mutable_rows()->CUDAMutableData(ctx.GetPlace());
343+
rows = grad_merge.rows().CUDAData(ctx.GetPlace());
326344
} else {
327345
#endif
328-
rows = grad_merge.mutable_rows()->data();
346+
rows = grad_merge.rows().data();
329347

330348
#if defined(PADDLE_WITH_CUDA)
331349
}

0 commit comments

Comments
 (0)