Skip to content

Commit 03dc7b7

Browse files
authored
Merge pull request #12966 from jerrywgz/fix_rnn_memory_helper
Add error info & remove data sharing between input and output in rnn_…
2 parents 437debf + 6033c1a commit 03dc7b7

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

paddle/fluid/operators/rnn_memory_helper_op.cc

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,16 +42,18 @@ class RNNMemoryHelperOp : public framework::OperatorBase {
4242

4343
auto *out_tensor = out_var->GetMutable<framework::LoDTensor>();
4444
auto &mem_tensor = mem_var->Get<framework::LoDTensor>();
45-
out_tensor->ShareDataWith(mem_tensor);
45+
framework::TensorCopySync(mem_tensor, dev_place, out_tensor);
4646
out_tensor->set_lod(mem_tensor.lod());
4747
}
4848
};
4949

5050
class RNNMemoryHelperOpShapeInference : public framework::InferShapeBase {
5151
public:
5252
void operator()(framework::InferShapeContext *ctx) const override {
53-
PADDLE_ENFORCE(ctx->HasInput("X"), "");
54-
PADDLE_ENFORCE(ctx->HasOutput("Out"), "");
53+
PADDLE_ENFORCE(ctx->HasInput("X"),
54+
"Input(X) of rnn_memory_helper op should not be null.");
55+
PADDLE_ENFORCE(ctx->HasOutput("Out"),
56+
"Output of rnn_memory_helper op should not be null.");
5557
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
5658
ctx->ShareLoD("X", /*->*/ "Out");
5759
}
@@ -107,7 +109,7 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase {
107109
} else {
108110
auto &out_grad_tensor = out_grad_var->Get<framework::LoDTensor>();
109111
auto *in_grad_tensor = in_grad_var->GetMutable<framework::LoDTensor>();
110-
in_grad_tensor->ShareDataWith(out_grad_tensor);
112+
framework::TensorCopySync(out_grad_tensor, dev_place, in_grad_tensor);
111113
in_grad_tensor->set_lod(out_grad_tensor.lod());
112114
}
113115
}
@@ -133,8 +135,11 @@ class RNNMemoryHelperGradOpShapeInference : public framework::InferShapeBase {
133135
public:
134136
void operator()(framework::InferShapeContext *ctx) const override {
135137
auto x_grad_name = framework::GradVarName("X");
136-
PADDLE_ENFORCE(ctx->HasOutput(x_grad_name), "");
137-
PADDLE_ENFORCE(ctx->HasInput("X"), "");
138+
PADDLE_ENFORCE(ctx->HasOutput(x_grad_name),
139+
"Gradient of Input(X) in rnn_memory_helper_grad of should "
140+
"not be null.");
141+
PADDLE_ENFORCE(ctx->HasInput("X"),
142+
"Input(X) of rnn_memory_helper_grad of should not be null.");
138143
ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
139144
ctx->ShareLoD("X", /*->*/ x_grad_name);
140145
}

0 commit comments

Comments
 (0)