@@ -211,59 +211,64 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
211
211
212
212
protected:
213
213
std::unique_ptr<framework::OpDesc> Apply () const override {
214
- auto *grad = new framework::OpDesc ();
215
- grad->SetType (" while_grad" );
216
- grad->SetInput (kX , Input (kX ));
214
+ auto *while_grad = new framework::OpDesc ();
215
+ while_grad->SetType (" while_grad" );
216
+ while_grad->SetInput (kX , Input (kX ));
217
+ while_grad->SetInput (kOutputs , Output (kOutputs ));
218
+ while_grad->SetInput (kStepScopes , Output (kStepScopes ));
219
+
220
+ auto *grad_block = this ->grad_block_ [0 ];
221
+ auto *fwd_block = grad_block->ParentBlock ();
222
+ // auto *parent_block = fwd_block->ParentBlock();
217
223
218
224
// Not all of IGs will be generated by inner gradient operators of while op.
219
225
// Ignore IGs that is not generated by the inside block.
220
- auto igs = InputGrad ( kX , /* do not drop empty gradient */ false ) ;
221
- std::unordered_set<std::string> all_outs ;
222
- for (size_t i = 0 ; i < grad_block_[ 0 ]-> OpSize (); ++i ) {
223
- for (auto &oname : grad_block_[ 0 ]-> Op (i) ->OutputArgumentNames ()) {
224
- all_outs .insert (oname);
226
+ std::unordered_set<std::string> inner_op_outputs ;
227
+ LOG (INFO) << " FUCK1 " ;
228
+ for (const auto *op : grad_block-> AllOps () ) {
229
+ for (auto &oname : op ->OutputArgumentNames ()) {
230
+ inner_op_outputs .insert (oname);
225
231
}
226
232
}
233
+ LOG (INFO) << " FUCK2" ;
234
+ auto igs = InputGrad (kX , /* do not drop empty gradient*/ false );
227
235
for (auto &each_ig : igs) {
228
- if (all_outs .find (each_ig) == all_outs .end ()) {
236
+ if (inner_op_outputs .find (each_ig) == inner_op_outputs .end ()) {
229
237
VLOG (10 ) << " Ignore " << each_ig;
230
238
each_ig = framework::kEmptyVarName ;
231
239
}
232
240
}
233
-
234
- grad->SetOutput (framework::GradVarName (kX ), igs);
235
-
236
- grad->SetInput (kOutputs , Output (kOutputs ));
241
+ while_grad->SetOutput (framework::GradVarName (kX ), igs);
237
242
238
243
// OG should be re-calculated by step blocks, since many outputs of while op
239
244
// do not need to calculate gradients.
240
245
std::unordered_set<std::string> block_ins;
241
- auto *fwd_block = this ->grad_block_ [0 ]->ParentBlock ();
242
- {
243
- for (auto &p : Input (kX )) {
244
- block_ins.insert (p);
245
- }
246
- for (auto &o : Output (kOutputs )) {
247
- block_ins.insert (o);
248
- }
249
- }
246
+ std::copy (Input (kX ).begin (), Input (kX ).end (),
247
+ std::inserter (block_ins, block_ins.end ()));
248
+ std::copy (Output (kOutputs ).begin (), Output (kOutputs ).end (),
249
+ std::inserter (block_ins, block_ins.end ()));
250
+
250
251
std::unordered_set<std::string> extra_inputs;
251
- for (size_t i = 0 ; i < grad_block_[0 ]->OpSize (); ++i) {
252
- for (auto &input_name : grad_block_[0 ]->Op (i)->InputArgumentNames ()) {
253
- if (block_ins.find (input_name) != block_ins.end ()) {
252
+ for (const auto *op : grad_block->AllOps ()) {
253
+ for (auto &input_name : op->InputArgumentNames ()) {
254
+ // If the input of Op has been recorded or is generated by the forward
255
+ // block, do not make it as input again.
256
+ if (block_ins.find (input_name) != block_ins.end () ||
257
+ fwd_block->FindVar (input_name) != nullptr ) {
254
258
continue ;
255
259
}
256
260
257
- // If the input of Op is generated by the forward block, do not make it
258
- // as input again.
259
- if (fwd_block->FindVar (input_name) != nullptr ) {
261
+ /*
262
+ if (parent_block->FindVarRecursive(input_name) == nullptr) {
263
+ VLOG(5) << "WARNING! Variable '" << input_name
264
+ << "' is the input of '" << op->Type()
265
+ << "'. But can not be found in any block.";
260
266
continue;
261
267
}
262
-
268
+ */
263
269
extra_inputs.insert (input_name);
264
270
}
265
-
266
- for (auto &output_name : grad_block_[0 ]->Op (i)->OutputArgumentNames ()) {
271
+ for (auto &output_name : op->OutputArgumentNames ()) {
267
272
block_ins.insert (output_name);
268
273
}
269
274
}
@@ -272,15 +277,15 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
272
277
extra_inputs_list.resize (extra_inputs.size ());
273
278
std::copy (extra_inputs.begin (), extra_inputs.end (),
274
279
extra_inputs_list.begin ());
275
- grad ->SetInput (framework::GradVarName (kOutputs ), extra_inputs_list);
276
- grad-> SetInput ( kStepScopes , Output ( kStepScopes ));
277
- grad ->SetAttrMap (this ->Attrs ());
278
- grad ->SetBlockAttr (kStepBlock , *grad_block_[ 0 ] );
280
+ while_grad ->SetInput (framework::GradVarName (kOutputs ), extra_inputs_list);
281
+
282
+ while_grad ->SetAttrMap (this ->Attrs ());
283
+ while_grad ->SetBlockAttr (kStepBlock , *grad_block );
279
284
// record the original output gradient names, since the gradient name of
280
285
// while operator could be renamed.
281
- grad ->SetAttr (" original_output_grad" , extra_inputs_list);
286
+ while_grad ->SetAttr (" original_output_grad" , extra_inputs_list);
282
287
283
- return std::unique_ptr<framework::OpDesc>(grad );
288
+ return std::unique_ptr<framework::OpDesc>(while_grad );
284
289
}
285
290
};
286
291
0 commit comments