Skip to content

Commit 574bcda

Browse files
committed
Add comments
1 parent 7ffd50b commit 574bcda

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

paddle/fluid/operators/while_op.cc

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -261,35 +261,37 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
261261
for (auto &o : Output(kOutputs)) {
262262
block_ins.insert(o);
263263
}
264-
std::unordered_set<std::string> extra_inputs;
264+
std::unordered_set<std::string> output_grads;
265265
for (const auto *op : grad_block->AllOps()) {
266266
for (auto &input_name : op->InputArgumentNames()) {
267267
// If the input of Op has been recorded or is generated by the forward
268268
// block, do not make it as input again.
269269

270+
// The input is located in I/O or other op's outputs or the variable is
271+
// located in grad_block's parents
270272
if (block_ins.find(input_name) != block_ins.end() ||
271-
fwd_block->FindVar(input_name) != nullptr ||
272-
parent_block->FindVar(input_name) != nullptr) {
273+
(fwd_block->FindVarRecursive(input_name) != nullptr ||
274+
parent_block->FindVarRecursive(input_name) != nullptr)) {
273275
continue;
274276
}
275-
extra_inputs.insert(input_name);
277+
output_grads.insert(input_name);
276278
}
277279
for (auto &output_name : op->OutputArgumentNames()) {
278280
block_ins.insert(output_name);
279281
}
280282
}
281283

282-
std::vector<std::string> extra_inputs_list;
283-
extra_inputs_list.resize(extra_inputs.size());
284-
std::copy(extra_inputs.begin(), extra_inputs.end(),
285-
extra_inputs_list.begin());
286-
while_grad->SetInput(framework::GradVarName(kOutputs), extra_inputs_list);
284+
std::vector<std::string> output_grads_list;
285+
output_grads_list.resize(output_grads.size());
286+
std::copy(output_grads.begin(), output_grads.end(),
287+
output_grads_list.begin());
288+
while_grad->SetInput(framework::GradVarName(kOutputs), output_grads_list);
287289

288290
while_grad->SetAttrMap(this->Attrs());
289291
while_grad->SetBlockAttr(kStepBlock, *grad_block);
290292
// record the original output gradient names, since the gradient name of
291293
// while operator could be renamed.
292-
while_grad->SetAttr("original_output_grad", extra_inputs_list);
294+
while_grad->SetAttr("original_output_grad", output_grads_list);
293295

294296
return std::unique_ptr<framework::OpDesc>(while_grad);
295297
}

0 commit comments

Comments
 (0)