@@ -157,11 +157,13 @@ class RecurrentBase : public framework::OperatorBase {
                                      const std::vector<std::string> &src_vars,
                                      framework::Scope *dst_scope,
                                      const std::vector<std::string> &dst_vars,
-                                     Callback callback) {
+                                     Callback callback,
+                                     bool is_backward = false) {
     PADDLE_ENFORCE_EQ(src_vars.size(), dst_vars.size());
     for (size_t i = 0; i < dst_vars.size(); ++i) {
       VLOG(10) << "Link " << src_vars[i] << " to " << dst_vars[i];
-      AccessTensor(src_scope, src_vars[i], dst_scope, dst_vars[i], callback);
+      AccessTensor(src_scope, src_vars[i], dst_scope, dst_vars[i], callback,
+                   is_backward);
     }
   }
 
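The first `LinkTensorWithCallback` overload (mutable destination scope, used when linking tensors from the outer scope into a step scope) gains an `is_backward` flag. It defaults to `false`, so existing forward-pass call sites compile and behave unchanged; the flag is simply forwarded to `AccessTensor`, which owns the skip-on-missing policy shown further down.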
@@ -173,11 +175,13 @@ class RecurrentBase : public framework::OperatorBase {
                                      const std::vector<std::string> &src_vars,
                                      const framework::Scope &dst_scope,
                                      const std::vector<std::string> &dst_vars,
-                                     Callback callback) {
+                                     Callback callback,
+                                     bool is_backward = false) {
     PADDLE_ENFORCE_EQ(src_vars.size(), dst_vars.size());
     for (size_t i = 0; i < dst_vars.size(); ++i) {
       VLOG(10) << "Link " << src_vars[i] << " to " << dst_vars[i];
-      AccessTensor(src_scope, src_vars[i], dst_scope, dst_vars[i], callback);
+      AccessTensor(src_scope, src_vars[i], dst_scope, dst_vars[i], callback,
+                   is_backward);
     }
   }
 
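The const-destination overload, used when linking tensors out of a step scope back into an existing outer scope, is extended identically, so either overload can take the new trailing argument.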
@@ -194,9 +198,13 @@ class RecurrentBase : public framework::OperatorBase {
   static void AccessTensor(const framework::Scope &src_scope,
                            const std::string &src_var_name,
                            framework::Scope *dst_scope,
-                           const std::string &dst_var_name, Callback callback) {
+                           const std::string &dst_var_name, Callback callback,
+                           bool is_backward = false) {
     auto *src_var = src_scope.FindVar(src_var_name);
-    PADDLE_ENFORCE(src_var != nullptr);
+    if (is_backward && src_var == nullptr) {
+      return;
+    }
+    PADDLE_ENFORCE(src_var != nullptr, "%s is not found.", src_var_name);
     auto &src_tensor = src_var->Get<framework::LoDTensor>();
 
     auto *dst_var = dst_scope->Var(dst_var_name);
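In the source-side `AccessTensor`, a missing source variable becomes legal during the backward pass: when `is_backward` is set and `FindVar` returns null, the link is skipped, since an input that does not contribute to the loss has no gradient tensor to copy. The forward pass still enforces presence, and the `PADDLE_ENFORCE` message now names the missing variable instead of aborting without identifying it.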
@@ -208,12 +216,16 @@ class RecurrentBase : public framework::OperatorBase {
   static void AccessTensor(const framework::Scope &src_scope,
                            const std::string &src_var_name,
                            const framework::Scope &dst_scope,
-                           const std::string &dst_var_name, Callback callback) {
+                           const std::string &dst_var_name, Callback callback,
+                           bool is_backward = false) {
+    auto *dst_var = dst_scope.FindVar(dst_var_name);
+    if (is_backward && dst_var == nullptr) {
+      return;
+    }
     auto *src_var = src_scope.FindVar(src_var_name);
-    PADDLE_ENFORCE(src_var != nullptr);
+    PADDLE_ENFORCE(src_var != nullptr, "%s is not found.", src_var_name);
     auto &src_tensor = src_var->Get<framework::LoDTensor>();
-    auto *dst_var = dst_scope.FindVar(dst_var_name);
-    PADDLE_ENFORCE(dst_var != nullptr);
+    PADDLE_ENFORCE(dst_var != nullptr, "%s is not found.", dst_var_name);
     auto *dst_tensor = dst_var->GetMutable<framework::LoDTensor>();
     callback(src_tensor, dst_tensor);
   }
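The const-destination `AccessTensor` applies the mirror-image rule: the destination lookup is hoisted above the source lookup so that, in backward mode, a missing gradient sink short-circuits the whole link before the source is even consulted. A minimal, self-contained sketch of the same skip-on-missing pattern, with a `std::map` standing in for `framework::Scope` and toy names that are not Paddle's API:

    // Toy model of the pattern above: a std::map stands in for
    // framework::Scope, floats stand in for LoDTensors.
    #include <cassert>
    #include <functional>
    #include <map>
    #include <string>

    using Scope = std::map<std::string, float>;

    // Links src[name] into (*dst)[name] via callback. In backward mode a
    // missing source means "no gradient flows here", so the link is skipped.
    void AccessValue(const Scope &src, const std::string &name, Scope *dst,
                     const std::function<void(float, float *)> &callback,
                     bool is_backward = false) {
      auto it = src.find(name);
      if (is_backward && it == src.end()) {
        return;  // tolerated in the backward pass
      }
      assert(it != src.end() && "source must exist in the forward pass");
      callback(it->second, &(*dst)[name]);
    }

    int main() {
      Scope grads{{"x@GRAD", 1.5f}};
      Scope step;
      auto copy = [](float s, float *d) { *d = s; };
      AccessValue(grads, "x@GRAD", &step, copy, /*is_backward=*/true);  // linked
      AccessValue(grads, "y@GRAD", &step, copy, /*is_backward=*/true);  // skipped
      assert(step.count("x@GRAD") == 1 && step.count("y@GRAD") == 0);
      return 0;
    }

Running it links `x@GRAD` and silently skips `y@GRAD`; with `is_backward` left `false`, the second call would trip the assertion instead.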
@@ -345,7 +357,8 @@ class RecurrentGradOp : public RecurrentBase {
           auto dims = framework::vectorize(inside->dims());
           dims.erase(dims.begin());
           inside->Resize(framework::make_ddim(dims));
-        });
+        },
+        true /*is_backward*/);
     auto og_set = List2Set(Inputs(kOutputGrads));
 
     if (VLOG_IS_ON(10)) {
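First call site in `RecurrentGradOp`: linking the outer output gradients (`kOutputGrads`) into the current step scope now passes `true /*is_backward*/`, so an output whose gradient was never produced is skipped rather than fatal. The inline `/*is_backward*/` comment documents the otherwise opaque boolean literal.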
@@ -454,7 +467,8 @@ class RecurrentGradOp : public RecurrentBase {
 
             auto dst = outside->Slice(seq_offset, seq_offset + 1);
             framework::TensorCopy(inside, place, dev_ctx, &dst);
-          });
+          },
+          true /*is_backward*/);
       VLOG(5) << "Link outside gradient finished";
 
       if (step_id + 1 == seq_len) {  // at_end
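The same flag is threaded through the reverse direction: when per-step gradients are copied back into the matching sequence slice of the outer gradient tensor, a variable that produced no gradient in this step simply contributes nothing.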
@@ -467,7 +481,8 @@ class RecurrentGradOp : public RecurrentBase {
               outside->Resize(inside.dims());
               outside->mutable_data(place, inside.type());
               framework::TensorCopy(inside, place, dev_ctx, outside);
-            });
+            },
+            true /*is_backward*/);
         VLOG(5) << "Link initialize state gradient finished";
       }
       scopes.Next();
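Likewise for the initial-state gradients, copied out once at the step marked `// at_end`: if a state has no gradient variable, the resize-allocate-copy lambda is never invoked for it.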
@@ -608,10 +623,8 @@ class RecurrentGradOpShapeInference : public framework::InferShapeBase {
     std::vector<std::string> input{kInputs, kInitialStates};
     std::vector<std::string> output{kOutputs};
     for (auto &s : input) {
+      // NOTE(zcd): In some case, some of kInputs doesn't have gradient.
       PADDLE_ENFORCE(ctx->HasInputs(s));
-      PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName(s)),
-                     "Cannot find the gradient variable %s",
-                     framework::GradVarName(s));
     }
     for (auto &s : output) {
       PADDLE_ENFORCE(ctx->HasInputs(s));
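With the linking code tolerant of absent gradients, `RecurrentGradOpShapeInference` drops the hard requirement that every entry of `kInputs` and `kInitialStates` has a declared gradient output; as the new NOTE says, some inputs legitimately have none. A toy, self-contained illustration of the relaxed check (illustrative names, not Paddle's `InferShapeContext` API):

    // Toy model of the relaxed check: only inputs whose gradient output was
    // actually declared get validated and propagated.
    #include <iostream>
    #include <set>
    #include <string>
    #include <vector>

    std::string GradVarName(const std::string &n) { return n + "@GRAD"; }

    int main() {
      std::set<std::string> declared{"x@GRAD"};   // y has no gradient output
      std::vector<std::string> inputs{"x", "y"};
      for (const auto &s : inputs) {
        // The old check aborted here because y@GRAD is missing; the new
        // behavior skips inputs without a declared gradient.
        if (declared.count(GradVarName(s))) {
          std::cout << "infer shape of " << GradVarName(s) << "\n";
        }
      }
      return 0;
    }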