@@ -42,16 +42,18 @@ class RNNMemoryHelperOp : public framework::OperatorBase {
42
42
43
43
auto *out_tensor = out_var->GetMutable <framework::LoDTensor>();
44
44
auto &mem_tensor = mem_var->Get <framework::LoDTensor>();
45
- out_tensor-> ShareDataWith (mem_tensor);
45
+ framework::TensorCopySync (mem_tensor, dev_place, out_tensor );
46
46
out_tensor->set_lod (mem_tensor.lod ());
47
47
}
48
48
};
49
49
50
50
class RNNMemoryHelperOpShapeInference : public framework ::InferShapeBase {
51
51
public:
52
52
void operator ()(framework::InferShapeContext *ctx) const override {
53
- PADDLE_ENFORCE (ctx->HasInput (" X" ), " " );
54
- PADDLE_ENFORCE (ctx->HasOutput (" Out" ), " " );
53
+ PADDLE_ENFORCE (ctx->HasInput (" X" ),
54
+ " Input(X) of rnn_memory_helper op should not be null." );
55
+ PADDLE_ENFORCE (ctx->HasOutput (" Out" ),
56
+ " Output of rnn_memory_helper op should not be null." );
55
57
ctx->SetOutputDim (" Out" , ctx->GetInputDim (" X" ));
56
58
ctx->ShareLoD (" X" , /* ->*/ " Out" );
57
59
}
@@ -107,7 +109,7 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase {
107
109
} else {
108
110
auto &out_grad_tensor = out_grad_var->Get <framework::LoDTensor>();
109
111
auto *in_grad_tensor = in_grad_var->GetMutable <framework::LoDTensor>();
110
- in_grad_tensor-> ShareDataWith (out_grad_tensor);
112
+ framework::TensorCopySync (out_grad_tensor, dev_place, in_grad_tensor );
111
113
in_grad_tensor->set_lod (out_grad_tensor.lod ());
112
114
}
113
115
}
@@ -133,8 +135,11 @@ class RNNMemoryHelperGradOpShapeInference : public framework::InferShapeBase {
133
135
public:
134
136
void operator ()(framework::InferShapeContext *ctx) const override {
135
137
auto x_grad_name = framework::GradVarName (" X" );
136
- PADDLE_ENFORCE (ctx->HasOutput (x_grad_name), " " );
137
- PADDLE_ENFORCE (ctx->HasInput (" X" ), " " );
138
+ PADDLE_ENFORCE (ctx->HasOutput (x_grad_name),
139
+ " Gradient of Input(X) in rnn_memory_helper_grad of should "
140
+ " not be null." );
141
+ PADDLE_ENFORCE (ctx->HasInput (" X" ),
142
+ " Input(X) of rnn_memory_helper_grad of should not be null." );
138
143
ctx->SetOutputDim (x_grad_name, ctx->GetInputDim (" X" ));
139
144
ctx->ShareLoD (" X" , /* ->*/ x_grad_name);
140
145
}
0 commit comments