
Commit a78a60c

Merge pull request #16365 from chengduoZH/cherry-pick_recurrent_op_fix
Cherry pick recurrent op fix
2 parents: b29dad2 + b008f5e

paddle/fluid/operators/recurrent_op.cc

Lines changed: 29 additions & 16 deletions
@@ -157,11 +157,13 @@ class RecurrentBase : public framework::OperatorBase {
                          const std::vector<std::string> &src_vars,
                          framework::Scope *dst_scope,
                          const std::vector<std::string> &dst_vars,
-                         Callback callback) {
+                         Callback callback,
+                         bool is_backward = false) {
     PADDLE_ENFORCE_EQ(src_vars.size(), dst_vars.size());
     for (size_t i = 0; i < dst_vars.size(); ++i) {
       VLOG(10) << "Link " << src_vars[i] << " to " << dst_vars[i];
-      AccessTensor(src_scope, src_vars[i], dst_scope, dst_vars[i], callback);
+      AccessTensor(src_scope, src_vars[i], dst_scope, dst_vars[i], callback,
+                   is_backward);
     }
   }

@@ -173,11 +175,13 @@ class RecurrentBase : public framework::OperatorBase {
                          const std::vector<std::string> &src_vars,
                          const framework::Scope &dst_scope,
                          const std::vector<std::string> &dst_vars,
-                         Callback callback) {
+                         Callback callback,
+                         bool is_backward = false) {
     PADDLE_ENFORCE_EQ(src_vars.size(), dst_vars.size());
     for (size_t i = 0; i < dst_vars.size(); ++i) {
       VLOG(10) << "Link " << src_vars[i] << " to " << dst_vars[i];
-      AccessTensor(src_scope, src_vars[i], dst_scope, dst_vars[i], callback);
+      AccessTensor(src_scope, src_vars[i], dst_scope, dst_vars[i], callback,
+                   is_backward);
     }
   }

@@ -194,9 +198,13 @@ class RecurrentBase : public framework::OperatorBase {
   static void AccessTensor(const framework::Scope &src_scope,
                            const std::string &src_var_name,
                            framework::Scope *dst_scope,
-                           const std::string &dst_var_name, Callback callback) {
+                           const std::string &dst_var_name, Callback callback,
+                           bool is_backward = false) {
     auto *src_var = src_scope.FindVar(src_var_name);
-    PADDLE_ENFORCE(src_var != nullptr);
+    if (is_backward && src_var == nullptr) {
+      return;
+    }
+    PADDLE_ENFORCE(src_var != nullptr, "%s is not found.", src_var_name);
     auto &src_tensor = src_var->Get<framework::LoDTensor>();
 
     auto *dst_var = dst_scope->Var(dst_var_name);

@@ -208,12 +216,16 @@ class RecurrentBase : public framework::OperatorBase {
   static void AccessTensor(const framework::Scope &src_scope,
                            const std::string &src_var_name,
                            const framework::Scope &dst_scope,
-                           const std::string &dst_var_name, Callback callback) {
+                           const std::string &dst_var_name, Callback callback,
+                           bool is_backward = false) {
+    auto *dst_var = dst_scope.FindVar(dst_var_name);
+    if (is_backward && dst_var == nullptr) {
+      return;
+    }
     auto *src_var = src_scope.FindVar(src_var_name);
-    PADDLE_ENFORCE(src_var != nullptr);
+    PADDLE_ENFORCE(src_var != nullptr, "%s is not found.", src_var_name);
     auto &src_tensor = src_var->Get<framework::LoDTensor>();
-    auto *dst_var = dst_scope.FindVar(dst_var_name);
-    PADDLE_ENFORCE(dst_var != nullptr);
+    PADDLE_ENFORCE(dst_var != nullptr, "%s is not found.", dst_var_name);
     auto *dst_tensor = dst_var->GetMutable<framework::LoDTensor>();
     callback(src_tensor, dst_tensor);
   }
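
The two AccessTensor overloads above carry the actual fix: in the backward pass, an input that takes no part in gradient computation has no corresponding gradient variable, so the lookup must tolerate a missing name instead of aborting. Below is a minimal, self-contained sketch of that pattern; the Scope and Tensor types are illustrative stand-ins, not the Paddle classes.

    #include <cassert>
    #include <iostream>
    #include <string>
    #include <unordered_map>

    // Illustrative stand-ins only; not framework::Scope / framework::LoDTensor.
    struct Tensor { int value = 0; };

    struct Scope {
      std::unordered_map<std::string, Tensor> vars;
      Tensor *FindVar(const std::string &name) {
        auto it = vars.find(name);
        return it == vars.end() ? nullptr : &it->second;
      }
    };

    // Mirrors the patched AccessTensor: when is_backward is true, a missing
    // source variable is skipped silently instead of tripping an assertion.
    template <typename Callback>
    void AccessTensor(Scope &src_scope, const std::string &src_var_name,
                      Scope &dst_scope, const std::string &dst_var_name,
                      Callback callback, bool is_backward = false) {
      Tensor *src = src_scope.FindVar(src_var_name);
      if (is_backward && src == nullptr) {
        return;  // this input has no gradient in the backward pass
      }
      assert(src != nullptr && "source variable not found");
      Tensor &dst = dst_scope.vars[dst_var_name];  // create if absent
      callback(*src, &dst);
    }

    int main() {
      Scope src, dst;
      src.vars["x@GRAD"] = Tensor{42};
      auto copy = [](const Tensor &s, Tensor *d) { d->value = s.value; };
      AccessTensor(src, "x@GRAD", dst, "x@GRAD", copy, /*is_backward=*/true);
      // "y@GRAD" does not exist; with is_backward == true this is a no-op.
      AccessTensor(src, "y@GRAD", dst, "y@GRAD", copy, /*is_backward=*/true);
      std::cout << dst.vars["x@GRAD"].value << "\n";  // prints 42
    }

Before the patch, the equivalent of the assert fired unconditionally; the is_backward flag turns a missing gradient variable into a no-op.
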
@@ -345,7 +357,8 @@ class RecurrentGradOp : public RecurrentBase {
           auto dims = framework::vectorize(inside->dims());
           dims.erase(dims.begin());
           inside->Resize(framework::make_ddim(dims));
-        });
+        },
+        true /*is_backward*/);
     auto og_set = List2Set(Inputs(kOutputGrads));
 
     if (VLOG_IS_ON(10)) {

@@ -454,7 +467,8 @@ class RecurrentGradOp : public RecurrentBase {
 
           auto dst = outside->Slice(seq_offset, seq_offset + 1);
           framework::TensorCopy(inside, place, dev_ctx, &dst);
-        });
+        },
+        true /*is_backward*/);
     VLOG(5) << "Link outside gradient finished ";
 
     if (step_id + 1 == seq_len) {  // at_end

@@ -467,7 +481,8 @@ class RecurrentGradOp : public RecurrentBase {
             outside->Resize(inside.dims());
             outside->mutable_data(place, inside.type());
             framework::TensorCopy(inside, place, dev_ctx, outside);
-          });
+          },
+          true /*is_backward*/);
       VLOG(5) << "Link initialize state gradient finished ";
     }
     scopes.Next();
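
All three call sites above link gradient tensors inside RecurrentGradOp, so each now passes true /*is_backward*/ as the new trailing argument, with an inline comment naming the bare boolean. A rough sketch of the call shape, using a hypothetical LinkTensorLike helper in place of the real LinkTensorWithCallback, with an empty name standing in for a variable absent from the scope:

    #include <functional>
    #include <iostream>
    #include <string>
    #include <vector>

    // Hypothetical stand-in for LinkTensorWithCallback, reduced to names only.
    void LinkTensorLike(
        const std::vector<std::string> &src_vars,
        const std::vector<std::string> &dst_vars,
        const std::function<void(const std::string &, const std::string &)> &cb,
        bool is_backward = false) {
      for (size_t i = 0; i < src_vars.size(); ++i) {
        // Stand-in for "variable not found in scope": an empty name. With
        // is_backward set, it is skipped instead of treated as an error.
        if (is_backward && src_vars[i].empty()) continue;
        cb(src_vars[i], dst_vars[i]);
      }
    }

    int main() {
      std::vector<std::string> src{"h@GRAD", ""}, dst{"h_grad_inside", "unused"};
      LinkTensorLike(
          src, dst,
          [](const std::string &s, const std::string &d) {
            std::cout << "link " << s << " -> " << d << "\n";
          },
          true /*is_backward*/);  // inline comment names the bare boolean
    }
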
@@ -608,10 +623,8 @@ class RecurrentGradOpShapeInference : public framework::InferShapeBase {
     std::vector<std::string> input{kInputs, kInitialStates};
     std::vector<std::string> output{kOutputs};
     for (auto &s : input) {
+      // NOTE(zcd): In some case, some of kInputs doesn't have gradient.
       PADDLE_ENFORCE(ctx->HasInputs(s));
-      PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName(s)),
-                     "Cannot find the gradient variable %s",
-                     framework::GradVarName(s));
     }
     for (auto &s : output) {
       PADDLE_ENFORCE(ctx->HasInputs(s));
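
This last hunk relaxes shape inference for the same reason: some of kInputs may have no gradient, so RecurrentGradOpShapeInference no longer requires a matching gradient output for every input. A small sketch of the relaxed check, assuming only Paddle's convention of suffixing @GRAD onto a forward variable name; grad_outputs and the variable names are invented for illustration:

    #include <iostream>
    #include <set>
    #include <string>
    #include <vector>

    // Follows Paddle's convention of suffixing @GRAD onto a forward name.
    std::string GradVarName(const std::string &name) { return name + "@GRAD"; }

    int main() {
      // Invented example: "x" has a gradient output, "ids" does not
      // (e.g. an integer index input that no gradient flows through).
      std::set<std::string> grad_outputs{GradVarName("x")};

      std::vector<std::string> inputs{"x", "ids"};
      for (const auto &name : inputs) {
        // Post-patch behaviour: a missing gradient output is acceptable,
        // so only the forward input itself is enforced to exist.
        if (grad_outputs.count(GradVarName(name)) == 0) {
          std::cout << name << " has no gradient; skipped\n";
          continue;
        }
        std::cout << "would infer shape for " << GradVarName(name) << "\n";
      }
    }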
