Skip to content

Commit 8f962f7

Browse files
committed
Update
1 parent 32b09b5 commit 8f962f7

File tree

1 file changed

+42
-37
lines changed

1 file changed

+42
-37
lines changed

paddle/operators/while_op.cc

Lines changed: 42 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -211,59 +211,64 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
211211

212212
protected:
213213
std::unique_ptr<framework::OpDesc> Apply() const override {
214-
auto *grad = new framework::OpDesc();
215-
grad->SetType("while_grad");
216-
grad->SetInput(kX, Input(kX));
214+
auto *while_grad = new framework::OpDesc();
215+
while_grad->SetType("while_grad");
216+
while_grad->SetInput(kX, Input(kX));
217+
while_grad->SetInput(kOutputs, Output(kOutputs));
218+
while_grad->SetInput(kStepScopes, Output(kStepScopes));
219+
220+
auto *grad_block = this->grad_block_[0];
221+
auto *fwd_block = grad_block->ParentBlock();
222+
// auto *parent_block = fwd_block->ParentBlock();
217223

218224
// Not all IGs will be generated by inner gradient operators of while op.
219225
// Ignore IGs that are not generated by the inside block.
220-
auto igs = InputGrad(kX, /*do not drop empty gradient*/ false);
221-
std::unordered_set<std::string> all_outs;
222-
for (size_t i = 0; i < grad_block_[0]->OpSize(); ++i) {
223-
for (auto &oname : grad_block_[0]->Op(i)->OutputArgumentNames()) {
224-
all_outs.insert(oname);
226+
std::unordered_set<std::string> inner_op_outputs;
227+
LOG(INFO) << "FUCK1";
228+
for (const auto *op : grad_block->AllOps()) {
229+
for (auto &oname : op->OutputArgumentNames()) {
230+
inner_op_outputs.insert(oname);
225231
}
226232
}
233+
LOG(INFO) << "FUCK2";
234+
auto igs = InputGrad(kX, /*do not drop empty gradient*/ false);
227235
for (auto &each_ig : igs) {
228-
if (all_outs.find(each_ig) == all_outs.end()) {
236+
if (inner_op_outputs.find(each_ig) == inner_op_outputs.end()) {
229237
VLOG(10) << "Ignore " << each_ig;
230238
each_ig = framework::kEmptyVarName;
231239
}
232240
}
233-
234-
grad->SetOutput(framework::GradVarName(kX), igs);
235-
236-
grad->SetInput(kOutputs, Output(kOutputs));
241+
while_grad->SetOutput(framework::GradVarName(kX), igs);
237242

238243
// OG should be re-calculated by step blocks, since many outputs of while op
239244
// do not need to calculate gradients.
240245
std::unordered_set<std::string> block_ins;
241-
auto *fwd_block = this->grad_block_[0]->ParentBlock();
242-
{
243-
for (auto &p : Input(kX)) {
244-
block_ins.insert(p);
245-
}
246-
for (auto &o : Output(kOutputs)) {
247-
block_ins.insert(o);
248-
}
249-
}
246+
std::copy(Input(kX).begin(), Input(kX).end(),
247+
std::inserter(block_ins, block_ins.end()));
248+
std::copy(Output(kOutputs).begin(), Output(kOutputs).end(),
249+
std::inserter(block_ins, block_ins.end()));
250+
250251
std::unordered_set<std::string> extra_inputs;
251-
for (size_t i = 0; i < grad_block_[0]->OpSize(); ++i) {
252-
for (auto &input_name : grad_block_[0]->Op(i)->InputArgumentNames()) {
253-
if (block_ins.find(input_name) != block_ins.end()) {
252+
for (const auto *op : grad_block->AllOps()) {
253+
for (auto &input_name : op->InputArgumentNames()) {
254+
// If the input of Op has been recorded or is generated by the forward
255+
// block, do not add it as an input again.
256+
if (block_ins.find(input_name) != block_ins.end() ||
257+
fwd_block->FindVar(input_name) != nullptr) {
254258
continue;
255259
}
256260

257-
// If the input of Op is generated by the forward block, do not make it
258-
// as input again.
259-
if (fwd_block->FindVar(input_name) != nullptr) {
261+
/*
262+
if (parent_block->FindVarRecursive(input_name) == nullptr) {
263+
VLOG(5) << "WARNING! Variable '" << input_name
264+
<< "' is the input of '" << op->Type()
265+
<< "'. But can not be found in any block.";
260266
continue;
261267
}
262-
268+
*/
263269
extra_inputs.insert(input_name);
264270
}
265-
266-
for (auto &output_name : grad_block_[0]->Op(i)->OutputArgumentNames()) {
271+
for (auto &output_name : op->OutputArgumentNames()) {
267272
block_ins.insert(output_name);
268273
}
269274
}
@@ -272,15 +277,15 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
272277
extra_inputs_list.resize(extra_inputs.size());
273278
std::copy(extra_inputs.begin(), extra_inputs.end(),
274279
extra_inputs_list.begin());
275-
grad->SetInput(framework::GradVarName(kOutputs), extra_inputs_list);
276-
grad->SetInput(kStepScopes, Output(kStepScopes));
277-
grad->SetAttrMap(this->Attrs());
278-
grad->SetBlockAttr(kStepBlock, *grad_block_[0]);
280+
while_grad->SetInput(framework::GradVarName(kOutputs), extra_inputs_list);
281+
282+
while_grad->SetAttrMap(this->Attrs());
283+
while_grad->SetBlockAttr(kStepBlock, *grad_block);
279284
// record the original output gradient names, since the gradient name of
280285
// while operator could be renamed.
281-
grad->SetAttr("original_output_grad", extra_inputs_list);
286+
while_grad->SetAttr("original_output_grad", extra_inputs_list);
282287

283-
return std::unique_ptr<framework::OpDesc>(grad);
288+
return std::unique_ptr<framework::OpDesc>(while_grad);
284289
}
285290
};
286291

0 commit comments

Comments
 (0)