@@ -261,35 +261,37 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
261
261
for (auto &o : Output (kOutputs )) {
262
262
block_ins.insert (o);
263
263
}
264
- std::unordered_set<std::string> extra_inputs ;
264
+ std::unordered_set<std::string> output_grads ;
265
265
for (const auto *op : grad_block->AllOps ()) {
266
266
for (auto &input_name : op->InputArgumentNames ()) {
267
267
// If the input of Op has been recorded or is generated by the forward
268
268
// block, do not make it as input again.
269
269
270
+ // The input is located in I/O or other op's outputs or the variable is
271
+ // located in grad_block's parents
270
272
if (block_ins.find (input_name) != block_ins.end () ||
271
- fwd_block->FindVar (input_name) != nullptr ||
272
- parent_block->FindVar (input_name) != nullptr ) {
273
+ ( fwd_block->FindVarRecursive (input_name) != nullptr ||
274
+ parent_block->FindVarRecursive (input_name) != nullptr ) ) {
273
275
continue ;
274
276
}
275
- extra_inputs .insert (input_name);
277
+ output_grads .insert (input_name);
276
278
}
277
279
for (auto &output_name : op->OutputArgumentNames ()) {
278
280
block_ins.insert (output_name);
279
281
}
280
282
}
281
283
282
- std::vector<std::string> extra_inputs_list ;
283
- extra_inputs_list .resize (extra_inputs .size ());
284
- std::copy (extra_inputs .begin (), extra_inputs .end (),
285
- extra_inputs_list .begin ());
286
- while_grad->SetInput (framework::GradVarName (kOutputs ), extra_inputs_list );
284
+ std::vector<std::string> output_grads_list ;
285
+ output_grads_list .resize (output_grads .size ());
286
+ std::copy (output_grads .begin (), output_grads .end (),
287
+ output_grads_list .begin ());
288
+ while_grad->SetInput (framework::GradVarName (kOutputs ), output_grads_list );
287
289
288
290
while_grad->SetAttrMap (this ->Attrs ());
289
291
while_grad->SetBlockAttr (kStepBlock , *grad_block);
290
292
// record the original output gradient names, since the gradient name of
291
293
// while operator could be renamed.
292
- while_grad->SetAttr (" original_output_grad" , extra_inputs_list );
294
+ while_grad->SetAttr (" original_output_grad" , output_grads_list );
293
295
294
296
return std::unique_ptr<framework::OpDesc>(while_grad);
295
297
}
0 commit comments