@@ -131,16 +131,7 @@ size_t NodeSize(const VarDesc& node) {
   return type_size * std::abs(size);
 }
 
-size_t NodeSize(ir::Node* n) {
-  VarDesc* desc = nullptr;
-  // some op do not have block pointer
-  if (n->inputs[0]->Op() != nullptr) {
-    desc = FindVarDescInBlock(n);
-  } else {
-    desc = n->Var();
-  }
-  return NodeSize(*desc);
-}
+size_t NodeSize(ir::Node* n) { return NodeSize(*(n->Var())); }
 
 std::string DebugStringImpl(VarDesc* var) {
   std::stringstream ss;
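
For reference, the `NodeSize(const VarDesc&)` overload that the new one-liner delegates to (its tail is visible at the top of this hunk) amounts to multiplying the dimensions and scaling by the element size. A minimal standalone sketch of that idea, using a hypothetical ShapeInfo struct rather than Paddle's VarDesc/ir::Node:

#include <cstdint>
#include <cstdlib>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical stand-in for the shape/dtype info a VarDesc carries.
struct ShapeInfo {
  std::vector<int64_t> shape;  // -1 marks an unknown (batch) dimension
  size_t type_size;            // element size in bytes, e.g. 4 for float
};

// Sketch of the size computation: product of dims times element size.
// std::abs keeps the result positive when a -1 batch dim flips the sign.
size_t NodeSizeSketch(const ShapeInfo& desc) {
  int64_t size = std::accumulate(desc.shape.begin(), desc.shape.end(),
                                 int64_t{1}, std::multiplies<int64_t>());
  return desc.type_size * static_cast<size_t>(std::abs(size));
}
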
@@ -163,24 +154,22 @@ std::string DebugStringImpl(VarDesc* var) {
 }
 
 std::string DebugString(ir::Node* var) {
-  return DebugStringImpl(FindVarDescInBlock(var));
+  return DebugStringImpl(GetVarDesc(var));
 }
 
 // NOTE(dzh): based ir node, if a large node has been reused
 // by a small size node, then next time it appear in pool, it will
 // have the small size. Find the original node shap from blockdesc.
-VarDesc* FindVarDescInBlock(ir::Node* n) {
+VarDesc* GetVarDesc(ir::Node* n) {
   PADDLE_ENFORCE(n->IsVar() && !n->IsCtrlVar() && n->inputs.size() == 1);
-  BlockDesc* block = n->inputs[0]->Op()->Block();
-  PADDLE_ENFORCE(block->HasVar(n->Name()),
-                 string::Sprintf("Block do not has var %s", n->Name()));
-  return block->FindVar(n->Name());
+  return n->Var();
 }
 
 struct NodeComparator {
   bool operator()(ir::Node* lhs, ir::Node* rhs) const {
-    auto* lhs_desc = FindVarDescInBlock(lhs);
-    auto* rhs_desc = FindVarDescInBlock(rhs);
+    if (lhs->Var()->GetType() != rhs->Var()->GetType()) return false;
+    auto* lhs_desc = GetVarDesc(lhs);
+    auto* rhs_desc = GetVarDesc(rhs);
     // match data type
     if (lhs_desc->GetDataType() != rhs_desc->GetDataType()) {
       return false;
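
The reworked NodeComparator rejects candidates whose variable type differs before it ever looks at data type or shape. A minimal standalone sketch of that checking order, with hypothetical stand-in types (VarKind, DataKind, CandidateDesc) rather than Paddle's real classes, and a deliberately simplified shape check:

#include <cstdint>
#include <vector>

enum class VarKind { LOD_TENSOR, LOD_TENSOR_ARRAY, SELECTED_ROWS };
enum class DataKind { FP32, FP64, INT64 };

struct CandidateDesc {
  VarKind kind;
  DataKind data_type;
  std::vector<int64_t> shape;
};

// Sketch of the comparison order implied by the hunk above: variable type
// first (the new check), then data type, then shape. Two nodes are only
// interchangeable when every check passes.
bool Matches(const CandidateDesc& lhs, const CandidateDesc& rhs) {
  if (lhs.kind != rhs.kind) return false;            // new type check
  if (lhs.data_type != rhs.data_type) return false;  // "match data type"
  return lhs.shape == rhs.shape;                     // simplified shape check
}
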
@@ -204,15 +193,15 @@ void OrderedSet::Insert(ir::Node* var) {
     return;
   }
 
-  auto* var_desc = FindVarDescInBlock(var);
+  auto* var_desc = var->Var();
   auto var_shape = var_desc->GetShape();
   int batch_size = static_cast<int>(var_shape[0]);
 
   NodeComparator functor;
   Iter it = nodes_.begin();
   while (it != nodes_.end()) {
     auto& prev = it->front();
-    auto* cache_desc = FindVarDescInBlock(prev);
+    auto* cache_desc = GetVarDesc(prev);
     int cache_batch_size = cache_desc->GetShape()[0];
     if ((cache_batch_size == -1 && batch_size == -1) ||
         (cache_batch_size != -1 && batch_size != -1)) {
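
The surrounding loop (all context lines here) places the incoming node next to cached nodes whose leading batch dimension is compatible: both dynamic (-1) or both fixed. A small sketch of that predicate; the name is illustrative, not Paddle API:

#include <cstdint>

// Two nodes can share a pool bucket only when both have a dynamic batch
// dimension (-1) or both have a concrete one.
bool BatchSizeCompatible(int64_t cache_batch_size, int64_t batch_size) {
  return (cache_batch_size == -1 && batch_size == -1) ||
         (cache_batch_size != -1 && batch_size != -1);
}
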
@@ -336,10 +325,16 @@ int MinChunkSize() {
 bool NodeCanReused(const VarDesc& node) {
   auto type = node.GetType();
   // only these types holds bulk of gpu memory
-  if (!(type == proto::VarType::LOD_TENSOR ||
-        type == proto::VarType::LOD_TENSOR_ARRAY)) {
-    return false;
-  }
+  // FIXME(liuwei1031) did not find good ways to test SELECTED_ROWS and
+  // LOD_TENSOR_ARRAY re-use logic,
+  // disable them in version 1.4
+  // if (!(type == proto::VarType::LOD_TENSOR ||
+  //       type == proto::VarType::SELECTED_ROWS ||
+  //       type == proto::VarType::LOD_TENSOR_ARRAY)) {
+  //   return false;
+  // }
+  if (type != proto::VarType::LOD_TENSOR) return false;
+
   // persistable variable is parameter
   if (node.Persistable()) {
     return false;
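
Condensed, the reuse filter after this change keeps only LOD_TENSOR variables that are not persistable (i.e. not parameters); SELECTED_ROWS and LOD_TENSOR_ARRAY are disabled for the 1.4 release per the FIXME above. A hedged standalone sketch with hypothetical stand-in types, not Paddle's:

enum class ReuseKind { LOD_TENSOR, LOD_TENSOR_ARRAY, SELECTED_ROWS };

struct ReuseDesc {
  ReuseKind kind;
  bool persistable;
};

// Sketch of the filter: tensors only in 1.4, and parameters never reused.
bool NodeCanReusedSketch(const ReuseDesc& node) {
  if (node.kind != ReuseKind::LOD_TENSOR) return false;  // 1.4: LOD_TENSOR only
  if (node.persistable) return false;                    // parameters are kept
  return true;
}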