File tree Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Original file line number Diff line number Diff line change @@ -60,7 +60,7 @@ class ReshapeOp : public framework::OperatorWithKernel {
60
60
static framework::DDim ValidateShape (const std::vector<int > shape,
61
61
const framework::DDim &in_dims) {
62
62
const int64_t in_size = framework::product (in_dims);
63
- // only one dimension canbe set to -1, whose size will be automatically
63
+ // only one dimension can be set to -1, whose size will be automatically
64
64
// infered.
65
65
const int64_t unk_dim_val = -1 ;
66
66
const int64_t copy_dim_val = 0 ;
@@ -119,13 +119,15 @@ class ReshapeKernel : public framework::OpKernel<T> {
119
119
auto *shape_tensor = ctx.Input <framework::LoDTensor>(" Shape" );
120
120
121
121
framework::DDim out_dims = out->dims ();
122
+
122
123
if (shape_tensor) {
123
124
auto *shape_data = shape_tensor->data <int >();
125
+ framework::Tensor cpu_shape_tensor;
124
126
if (platform::is_gpu_place (ctx.GetPlace ())) {
125
- framework::Tensor cpu_shape_tensor;
126
127
TensorCopy (*shape_tensor, platform::CPUPlace (), ctx.device_context (),
127
128
&cpu_shape_tensor);
128
129
shape_data = cpu_shape_tensor.data <int >();
130
+ ctx.device_context ().Wait ();
129
131
}
130
132
auto shape =
131
133
std::vector<int >(shape_data, shape_data + shape_tensor->numel ());
You can’t perform that action at this time.
0 commit comments