Skip to content

Commit 06ddaa7

Browse files
authored
Merge pull request #9840 from reyoung/feature/polish_reshape_op
Polish reshape op
2 parents 6f0df4f + daa5011 commit 06ddaa7

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

paddle/fluid/operators/reshape_op.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ class ReshapeOp : public framework::OperatorWithKernel {
6060
static framework::DDim ValidateShape(const std::vector<int> shape,
6161
const framework::DDim &in_dims) {
6262
const int64_t in_size = framework::product(in_dims);
63-
// only one dimension canbe set to -1, whose size will be automatically
63+
// only one dimension can be set to -1, whose size will be automatically
6464
// infered.
6565
const int64_t unk_dim_val = -1;
6666
const int64_t copy_dim_val = 0;
@@ -119,13 +119,15 @@ class ReshapeKernel : public framework::OpKernel<T> {
119119
auto *shape_tensor = ctx.Input<framework::LoDTensor>("Shape");
120120

121121
framework::DDim out_dims = out->dims();
122+
122123
if (shape_tensor) {
123124
auto *shape_data = shape_tensor->data<int>();
125+
framework::Tensor cpu_shape_tensor;
124126
if (platform::is_gpu_place(ctx.GetPlace())) {
125-
framework::Tensor cpu_shape_tensor;
126127
TensorCopy(*shape_tensor, platform::CPUPlace(), ctx.device_context(),
127128
&cpu_shape_tensor);
128129
shape_data = cpu_shape_tensor.data<int>();
130+
ctx.device_context().Wait();
129131
}
130132
auto shape =
131133
std::vector<int>(shape_data, shape_data + shape_tensor->numel());

0 commit comments

Comments
 (0)