
Commit cb9c59b

cherry-pick PR 16547,16736,16739 test=release/1.4 (#16748)
* fix the bug of reusing different types of variables in memory_optimiz… (#16547)
  * fix the bug of reusing different types of variables in memory_optimize_pass, test=develop
  * disable SELECTED_ROWS and LOD_TENSOR_ARRAY reuse, test=develop
* only use the latest version variable for inplace strategy (#16736)
  * bug-fix, test=develop
  * tweak code, test=develop
* cherry-pick PR 16547,16736,16739 test=release/1.4
1 parent 44f50cf commit cb9c59b

File tree (3 files changed: +30 -29 lines changed)

* paddle/fluid/framework/details/inplace_op_pass.cc
* paddle/fluid/framework/details/memory_optimize_helper.cc
* paddle/fluid/framework/details/memory_optimize_helper.h

paddle/fluid/framework/details/inplace_op_pass.cc

Lines changed: 9 additions & 0 deletions

```diff
@@ -305,6 +305,12 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
 
     VLOG(4) << "Try to inplace " << in_var_name << " with " << out_var_name;
 
+    if (var_nodes_[in_var_name].back() != in_node) {
+      VLOG(4) << "SKIP since " << in_var_name
+              << " is also used as output by other ops";
+      continue;
+    }
+
     bool can_replace = true;
     if (in_var_name == out_var_name) {
       can_replace = false;
```

```diff
@@ -527,6 +533,9 @@ void GraphView::Build(ir::Graph* g) {
   };
   for (auto& node : g->Nodes()) {
     if (!node->IsOp()) continue;
+    // avoid optimize the variable used in sub-blocks
+    if (OpHasSubBlock(node->Op())) update_skip_set(node);
+
     if (node->Name() == "send") update_skip_set(node);
     if (node->Name() == "recv") update_skip_set(node);
     if (node->Name() == "prefetch") update_skip_set(node);
```
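
For context on the first hunk: below is a minimal standalone C++ sketch of the latest-version guard, not the real pass — `Node`, `VarNodes`, and `IsLatestVersion` are hypothetical stand-ins for `ir::Node` and the pass's `var_nodes_` member. A variable name maps to all of its SSA-style version nodes, oldest first, and only the newest node is a safe inplace candidate, since an older node means a later op still redefines the same name.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for ir::Node; only object identity matters here.
struct Node {
  int version;
};

// Sketch of the pass's var_nodes_ member: name -> every version of that
// variable in the graph, oldest first.
using VarNodes = std::unordered_map<std::string, std::vector<Node*>>;

// Mirrors the added guard: inplace is only attempted when in_node is the
// newest version, i.e. no later op writes the same variable name.
bool IsLatestVersion(const VarNodes& var_nodes, const std::string& name,
                     const Node* in_node) {
  return var_nodes.at(name).back() == in_node;
}

int main() {
  Node v0{0}, v1{1};
  VarNodes var_nodes{{"x", {&v0, &v1}}};
  assert(!IsLatestVersion(var_nodes, "x", &v0));  // stale version: skip
  assert(IsLatestVersion(var_nodes, "x", &v1));   // latest version: proceed
  return 0;
}
```

The second hunk applies the same conservatism at graph-build time: any op that owns a sub-block gets its variables added to the skip set, so neither pass touches memory that a nested block may still read.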

paddle/fluid/framework/details/memory_optimize_helper.cc

Lines changed: 19 additions & 24 deletions

```diff
@@ -131,16 +131,7 @@ size_t NodeSize(const VarDesc& node) {
   return type_size * std::abs(size);
 }
 
-size_t NodeSize(ir::Node* n) {
-  VarDesc* desc = nullptr;
-  // some op do not have block pointer
-  if (n->inputs[0]->Op() != nullptr) {
-    desc = FindVarDescInBlock(n);
-  } else {
-    desc = n->Var();
-  }
-  return NodeSize(*desc);
-}
+size_t NodeSize(ir::Node* n) { return NodeSize(*(n->Var())); }
 
 std::string DebugStringImpl(VarDesc* var) {
   std::stringstream ss;
```
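
The surviving `NodeSize(const VarDesc&)` overload above computes bytes as element size times the absolute element count, which is why a `-1` batch dimension is harmless. A self-contained sketch of that pattern — `NodeBytes` is a hypothetical helper, not the Paddle API:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical helper mirroring the NodeSize(const VarDesc&) pattern above:
// bytes = element size * |element count|; a -1 batch dimension only flips
// the sign of the product, which std::labs removes, matching the
// std::abs(size) seen in the context line.
std::size_t NodeBytes(const std::vector<long>& shape, std::size_t type_size) {
  long numel = std::accumulate(shape.begin(), shape.end(), 1L,
                               std::multiplies<long>());
  return type_size * static_cast<std::size_t>(std::labs(numel));
}

int main() {
  assert(NodeBytes({-1, 4, 8}, 4) == 128);  // fp32, unknown batch size
  assert(NodeBytes({2, 3}, 8) == 48);       // int64
  return 0;
}
```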

```diff
@@ -163,24 +154,22 @@ std::string DebugStringImpl(VarDesc* var) {
 }
 
 std::string DebugString(ir::Node* var) {
-  return DebugStringImpl(FindVarDescInBlock(var));
+  return DebugStringImpl(GetVarDesc(var));
 }
 
 // NOTE(dzh): based ir node, if a large node has been reused
 // by a small size node, then next time it appear in pool, it will
 // have the small size. Find the original node shap from blockdesc.
-VarDesc* FindVarDescInBlock(ir::Node* n) {
+VarDesc* GetVarDesc(ir::Node* n) {
   PADDLE_ENFORCE(n->IsVar() && !n->IsCtrlVar() && n->inputs.size() == 1);
-  BlockDesc* block = n->inputs[0]->Op()->Block();
-  PADDLE_ENFORCE(block->HasVar(n->Name()),
-                 string::Sprintf("Block do not has var %s", n->Name()));
-  return block->FindVar(n->Name());
+  return n->Var();
 }
 
 struct NodeComparator {
   bool operator()(ir::Node* lhs, ir::Node* rhs) const {
-    auto* lhs_desc = FindVarDescInBlock(lhs);
-    auto* rhs_desc = FindVarDescInBlock(rhs);
+    if (lhs->Var()->GetType() != rhs->Var()->GetType()) return false;
+    auto* lhs_desc = GetVarDesc(lhs);
+    auto* rhs_desc = GetVarDesc(rhs);
     // match data type
     if (lhs_desc->GetDataType() != rhs_desc->GetDataType()) {
       return false;
```
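
A standalone sketch of the comparator behavior after this hunk, assuming a hypothetical miniature `Desc` in place of `VarDesc` and a simple fits-inside rule for the shape check (the real comparator's shape logic continues past the excerpt): candidates whose variable type or data type differ are rejected before any shape comparison, which is exactly the cross-type reuse the fix forbids.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical miniature of VarDesc, just enough for the comparison.
enum class VarType { LOD_TENSOR, SELECTED_ROWS };
enum class DataType { FP32, INT64 };

struct Desc {
  VarType type;
  DataType dtype;
  std::vector<long> shape;
};

// Mirrors the patched NodeComparator: reject on variable type first (the
// added gate), then data type, then require lhs to fit inside rhs.
bool FitsInside(const Desc& lhs, const Desc& rhs) {
  if (lhs.type != rhs.type) return false;    // the added type gate
  if (lhs.dtype != rhs.dtype) return false;  // match data type
  if (lhs.shape.size() != rhs.shape.size()) return false;
  for (std::size_t i = 0; i < lhs.shape.size(); ++i) {
    if (lhs.shape[i] > rhs.shape[i]) return false;
  }
  return true;
}

int main() {
  Desc a{VarType::LOD_TENSOR, DataType::FP32, {4, 8}};
  Desc b{VarType::SELECTED_ROWS, DataType::FP32, {4, 8}};
  Desc c{VarType::LOD_TENSOR, DataType::FP32, {8, 8}};
  assert(!FitsInside(a, b));  // same shape, different type: never reused
  assert(FitsInside(a, c));   // same types, smaller shape: reusable
  return 0;
}
```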

```diff
@@ -204,15 +193,15 @@ void OrderedSet::Insert(ir::Node* var) {
     return;
   }
 
-  auto* var_desc = FindVarDescInBlock(var);
+  auto* var_desc = var->Var();
   auto var_shape = var_desc->GetShape();
   int batch_size = static_cast<int>(var_shape[0]);
 
   NodeComparator functor;
   Iter it = nodes_.begin();
   while (it != nodes_.end()) {
     auto& prev = it->front();
-    auto* cache_desc = FindVarDescInBlock(prev);
+    auto* cache_desc = GetVarDesc(prev);
     int cache_batch_size = cache_desc->GetShape()[0];
     if ((cache_batch_size == -1 && batch_size == -1) ||
         (cache_batch_size != -1 && batch_size != -1)) {
```

```diff
@@ -336,10 +325,16 @@ int MinChunkSize() {
 bool NodeCanReused(const VarDesc& node) {
   auto type = node.GetType();
   // only these types holds bulk of gpu memory
-  if (!(type == proto::VarType::LOD_TENSOR ||
-        type == proto::VarType::LOD_TENSOR_ARRAY)) {
-    return false;
-  }
+  // FIXME(liuwei1031) did not find good ways to test SELECTED_ROWS and
+  // LOD_TENSOR_ARRAY re-use logic,
+  // disable them in version 1.4
+  // if (!(type == proto::VarType::LOD_TENSOR ||
+  //       type == proto::VarType::SELECTED_ROWS ||
+  //       type == proto::VarType::LOD_TENSOR_ARRAY)) {
+  //   return false;
+  // }
+  if (type != proto::VarType::LOD_TENSOR) return false;
+
   // persistable variable is parameter
   if (node.Persistable()) {
     return false;
```
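
Likewise, a sketch of the reuse predicate after this hunk, again with a hypothetical miniature `Desc` standing in for `VarDesc`: on the 1.4 branch only non-persistable `LOD_TENSOR` variables may donate their memory, with `SELECTED_ROWS` and `LOD_TENSOR_ARRAY` reuse disabled per the FIXME above.

```cpp
#include <cassert>

enum class VarType { LOD_TENSOR, SELECTED_ROWS, LOD_TENSOR_ARRAY };

// Hypothetical miniature of VarDesc for this sketch.
struct Desc {
  VarType type;
  bool persistable;  // persistable variables are parameters
};

// Mirrors NodeCanReused after the patch: only a plain, non-persistable
// LOD_TENSOR may donate its memory on the 1.4 branch.
bool NodeCanReused(const Desc& node) {
  if (node.type != VarType::LOD_TENSOR) return false;
  if (node.persistable) return false;  // parameters are never reused
  return true;
}

int main() {
  assert(NodeCanReused({VarType::LOD_TENSOR, false}));
  assert(!NodeCanReused({VarType::SELECTED_ROWS, false}));     // disabled in 1.4
  assert(!NodeCanReused({VarType::LOD_TENSOR_ARRAY, false}));  // disabled in 1.4
  assert(!NodeCanReused({VarType::LOD_TENSOR, true}));         // parameter
  return 0;
}
```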

paddle/fluid/framework/details/memory_optimize_helper.h

Lines changed: 2 additions & 5 deletions

```diff
@@ -20,6 +20,7 @@
 #include <map>
 #include <set>
 #include <string>
+#include <unordered_map>
 #include <utility>
 #include <vector>
 #include "paddle/fluid/framework/data_type.h"
```

```diff
@@ -140,11 +141,7 @@ size_t NodeSize(const VarDesc&);
 
 std::string DebugString(ir::Node* var);
 
-// NOTE(dzhwinter)
-// after node reuse, the replaced node shape is
-// different with its VarDesc. So need to find the
-// correct VarDesc in Block.
-VarDesc* FindVarDescInBlock(ir::Node* n);
+VarDesc* GetVarDesc(ir::Node* n);
 
 static inline bool IsSameDesc(OpDesc* op1, OpDesc* op2) {
   return op1->Type() == op2->Type() && op1->Inputs() == op2->Inputs() &&
```
