@@ -24,15 +24,16 @@ namespace imperative {
 void CreateGradOp(const framework::OpDesc& op_desc,
                   const std::unordered_set<std::string>& no_grad_set,
                   const std::vector<framework::BlockDesc*>& grad_sub_block,
-                  framework::OpDesc** grad_op_desc,
+                  std::vector<framework::OpDesc*>* grad_op_descs,
                   std::unordered_map<std::string, std::string>* grad_to_var) {
-  std::vector<std::unique_ptr<framework::OpDesc>> grad_op_descs =
+  PADDLE_ENFORCE(grad_op_descs->empty());
+  std::vector<std::unique_ptr<framework::OpDesc>> descs =
       framework::OpInfoMap::Instance()
           .Get(op_desc.Type())
           .GradOpMaker()(op_desc, no_grad_set, grad_to_var, grad_sub_block);
-  PADDLE_ENFORCE(grad_op_descs.size() == 1, "Only support 1 grad op now.");
-  // TODO(panyx0718): Leak?
-  *grad_op_desc = grad_op_descs[0].release();
+  for (auto& desc : descs) {
+    grad_op_descs->emplace_back(desc.release());
+  }
 }

 void InitVar(framework::Variable* var, framework::Variable* grad_var,
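With the hunk above, CreateGradOp no longer enforces a single grad op: every OpDesc returned by the operator's GradOpMaker is released into the caller-supplied vector, and the only precondition is that the vector starts out empty. A rough caller-side fragment, assuming Paddle's framework headers and an existing fwd_op_desc and block (the real caller is Tracer::Trace in the next hunk):

    // Hypothetical call site; fwd_op_desc is a framework::OpDesc and block a
    // framework::BlockDesc* that are assumed to exist in the surrounding code.
    std::vector<framework::OpDesc*> grad_op_descs;
    std::unordered_map<std::string, std::string> grad_to_var;
    CreateGradOp(fwd_op_desc, /*no_grad_set=*/{}, /*grad_sub_block=*/{block},
                 &grad_op_descs, &grad_to_var);
    // The vector now owns the raw pointers released from the GradOpMaker's
    // unique_ptrs; some operators emit more than one grad op, so iterate.
    for (framework::OpDesc* grad_op : grad_op_descs) {
      /* wire up grad_op's inputs/outputs, as Trace() does below */
    }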
@@ -138,49 +139,52 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
       prepared_op.op, scope, *prepared_op.dev_ctx, prepared_op.ctx));

   if (!stop_gradient) {
-    framework::OpDesc* grad_op_desc;
-    // TODO(panyx): Is this leaked?
     std::unique_ptr<std::unordered_map<std::string, std::string>> grad_to_var(
         new std::unordered_map<std::string, std::string>());
-    CreateGradOp(*op_desc, {}, {block}, &grad_op_desc, grad_to_var.get());
-    op->grad_op_desc_ = grad_op_desc;
-
-    for (auto it : grad_op_desc->Inputs()) {
-      auto& grad_in_vars = op->grad_input_vars_[it.first];
-      for (const std::string& grad_invar : it.second) {
-        block->FindRecursiveOrCreateVar(grad_invar);
-        auto var_it = grad_to_var->find(grad_invar);
-        if (var_it == grad_to_var->end()) {
-          auto fwd_var_it = vars.find(grad_invar);
-          PADDLE_ENFORCE(fwd_var_it != vars.end());
-          // Forward inputs or outputs.
-          grad_in_vars.push_back(fwd_var_it->second->var_);
-        } else {
+    CreateGradOp(*op_desc, {}, {block}, &op->grad_op_descs_, grad_to_var.get());
+
+    op->grad_input_vars_.resize(op->grad_op_descs_.size());
+    op->grad_output_vars_.resize(op->grad_op_descs_.size());
+    for (size_t i = 0; i < op->grad_op_descs_.size(); ++i) {
+      framework::OpDesc* grad_op_desc = op->grad_op_descs_[i];
+      for (auto it : grad_op_desc->Inputs()) {
+        auto& grad_in_vars = op->grad_input_vars_[i][it.first];
+        for (const std::string& grad_invar : it.second) {
+          block->FindRecursiveOrCreateVar(grad_invar);
+          auto var_it = grad_to_var->find(grad_invar);
+          if (var_it == grad_to_var->end()) {
+            auto fwd_var_it = vars.find(grad_invar);
+            PADDLE_ENFORCE(fwd_var_it != vars.end());
+            // Forward inputs or outputs.
+            grad_in_vars.push_back(fwd_var_it->second->var_);
+          } else {
+            VarBase* var = vars[var_it->second];
+            if (!var->grads_->var_->IsInitialized()) {
+              InitVar(var->var_, var->grads_->var_,
+                      prepared_op.GetDeviceContext());
+            }
+            // Douts.
+            grad_in_vars.push_back(var->grads_->var_);
+          }
+        }
+      }
+
+      for (auto it : grad_op_desc->Outputs()) {
+        auto& grad_out_vars = op->grad_output_vars_[i][it.first];
+        for (const std::string& grad_outvar : it.second) {
+          block->FindRecursiveOrCreateVar(grad_outvar);
+          auto var_it = grad_to_var->find(grad_outvar);
+          PADDLE_ENFORCE(var_it != grad_to_var->end(),
+                         "Could not found the grad op output var, should this "
+                         "operator %s's stop gradient be True",
+                         op_desc->Type());
           VarBase* var = vars[var_it->second];
           if (!var->grads_->var_->IsInitialized()) {
             InitVar(var->var_, var->grads_->var_,
                     prepared_op.GetDeviceContext());
           }
-          // Douts.
-          grad_in_vars.push_back(var->grads_->var_);
-        }
-      }
-    }
-
-    for (auto it : grad_op_desc->Outputs()) {
-      auto& grad_out_vars = op->grad_output_vars_[it.first];
-      for (const std::string& grad_outvar : it.second) {
-        block->FindRecursiveOrCreateVar(grad_outvar);
-        auto var_it = grad_to_var->find(grad_outvar);
-        PADDLE_ENFORCE(var_it != grad_to_var->end(),
-                       "Could not found the grad op output var, should this "
-                       "operator %s's stop gradient be True",
-                       op_desc->Type());
-        VarBase* var = vars[var_it->second];
-        if (!var->grads_->var_->IsInitialized()) {
-          InitVar(var->var_, var->grads_->var_, prepared_op.GetDeviceContext());
+          grad_out_vars.push_back(var->grads_->var_);
         }
-        grad_out_vars.push_back(var->grads_->var_);
       }
     }
   }
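The Trace hunk above switches OpBase from a single grad_op_desc_ to parallel per-grad-op containers: slot i of grad_input_vars_ and grad_output_vars_ holds the variable maps for grad_op_descs_[i], which is why both vectors are resized to grad_op_descs_.size() before the loop. A minimal, self-contained sketch of that layout using stand-in types (the real member types of OpBase are not shown in this diff and are assumed here):

    #include <map>
    #include <string>
    #include <vector>

    struct Var {};         // stand-in for framework::Variable
    struct GradOpDesc {};  // stand-in for framework::OpDesc

    using VarMap = std::map<std::string, std::vector<Var*>>;

    struct OpBaseSketch {
      std::vector<GradOpDesc*> grad_op_descs_;
      std::vector<VarMap> grad_input_vars_;   // slot i pairs with grad_op_descs_[i]
      std::vector<VarMap> grad_output_vars_;  // same indexing for the grad op's outputs

      void PrepareSlots() {
        // Mirrors the two resize() calls added in Trace(): one slot per grad op,
        // which the loop then fills from grad_op_desc->Inputs()/Outputs().
        grad_input_vars_.resize(grad_op_descs_.size());
        grad_output_vars_.resize(grad_op_descs_.size());
      }
    };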
@@ -209,10 +213,12 @@ std::vector<VarBase*> Tracer::PyTrace(OpBase* op,
     out->TrackPreOp(op, PyLayer::kFwdOut, i, stop_gradient);
   }
   if (!stop_gradient) {
+    op->grad_input_vars_.resize(1);
+    op->grad_output_vars_.resize(1);
     auto& grad_input_vars =
-        op->grad_input_vars_[framework::GradVarName(PyLayer::kFwdInp)];
+        op->grad_input_vars_[0][framework::GradVarName(PyLayer::kFwdInp)];
     auto& grad_output_vars =
-        op->grad_output_vars_[framework::GradVarName(PyLayer::kFwdOut)];
+        op->grad_output_vars_[0][framework::GradVarName(PyLayer::kFwdOut)];

     for (const VarBase* inp : inputs) {
       grad_input_vars.push_back(inp->var_);
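The PyTrace hunk keeps the same layout but with exactly one slot: the two resize(1) calls give the per-grad-op vectors a single entry and every access goes through index 0, presumably because a PyLayer traces a single backward computation. Continuing the stand-in types from the sketch after the previous hunk, the single-slot access pattern looks roughly like this:

    int main() {
      OpBaseSketch py_op;
      py_op.grad_input_vars_.resize(1);   // one backward entry for the PyLayer
      py_op.grad_output_vars_.resize(1);
      Var forward_input;                  // stands in for inp->var_
      // The map key stands in for framework::GradVarName(PyLayer::kFwdInp).
      py_op.grad_input_vars_[0]["fwd_inp_grad_slot"].push_back(&forward_input);
      return 0;
    }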