@@ -164,7 +164,7 @@ dimension value will be copied from Input(X) at runtime. Note that the index of
164
164
[2, 3, 4], Attr(shape) = [2, 3, 2, 0] is an invalid input.
165
165
166
166
3. Input(Shape) has a higher priority than Attr(shape) if it is provided, while
167
- Attr(shape) still should be set correctly to gurantee shape inference in
167
+ Attr(shape) still should be set correctly to gurantee shape inference in
168
168
compile-time.
169
169
170
170
)DOC" );
@@ -195,6 +195,7 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
195
195
}
196
196
};
197
197
198
+ template <typename T>
198
199
class ReshapeKernel {
199
200
public:
200
201
void operator ()(const framework::ExecutionContext &ctx) const {
@@ -227,21 +228,25 @@ class ReshapeKernel {
227
228
" sequence_reshape op." );
228
229
}
229
230
230
- out->mutable_data (ctx.GetPlace (), in->type ());
231
- framework::TensorCopySync (*in, ctx.GetPlace (), out);
231
+ if (in->data <T>() !=
232
+ reinterpret_cast <T *>(out->mutable_data (ctx.GetPlace (), in->type ()))) {
233
+ framework::TensorCopySync (*in, ctx.GetPlace (), out);
234
+ }
232
235
out->Resize (out_dims);
233
236
}
234
237
};
235
238
239
+ template <typename T>
236
240
class ReshapeGradKernel {
237
241
public:
238
242
void operator ()(const framework::ExecutionContext &ctx) const {
239
243
auto *d_out = ctx.Input <framework::Tensor>(framework::GradVarName (" Out" ));
240
244
auto *d_x = ctx.Output <framework::Tensor>(framework::GradVarName (" X" ));
241
245
auto in_dims = d_x->dims ();
242
246
243
- d_x->mutable_data (ctx.GetPlace (), d_out->type ());
244
- framework::TensorCopySync (*d_out, ctx.GetPlace (), d_x);
247
+ if (d_out->data <T>() != d_x->mutable_data (ctx.GetPlace (), d_out->type ())) {
248
+ framework::TensorCopySync (*d_out, ctx.GetPlace (), d_x);
249
+ }
245
250
d_x->Resize (in_dims);
246
251
}
247
252
};
@@ -259,7 +264,6 @@ class Reshape2Op : public ReshapeOp {
259
264
: ReshapeOp(type, inputs, outputs, attrs) {}
260
265
261
266
void InferShape (framework::InferShapeContext *ctx) const override {
262
- ReshapeOp::InferShape (ctx);
263
267
PADDLE_ENFORCE (ctx->HasOutput (" XShape" ),
264
268
" Output(XShape) of ReshapeOp should not be null." );
265
269
const auto &x_dims = ctx->GetInputDim (" X" );
@@ -270,6 +274,8 @@ class Reshape2Op : public ReshapeOp {
270
274
}
271
275
ctx->SetOutputDim (" XShape" , framework::make_ddim (xshape_dims));
272
276
ctx->ShareLoD (" X" , /* ->*/ " XShape" );
277
+
278
+ ReshapeOp::InferShape (ctx);
273
279
}
274
280
};
275
281
@@ -335,38 +341,46 @@ namespace ops = paddle::operators;
335
341
REGISTER_OPERATOR (reshape, ops::ReshapeOp, ops::ReshapeOpMaker,
336
342
paddle::framework::DefaultGradOpDescMaker<true >);
337
343
REGISTER_OPERATOR (reshape_grad, ops::ReshapeGradOp);
338
- REGISTER_OP_CPU_KERNEL_FUNCTOR (reshape, float , ops::ReshapeKernel, double ,
339
- ops::ReshapeKernel, int , ops::ReshapeKernel,
340
- int64_t , ops::ReshapeKernel);
341
- REGISTER_OP_CPU_KERNEL_FUNCTOR (reshape_grad, float , ops::ReshapeGradKernel,
342
- double , ops::ReshapeGradKernel, int ,
343
- ops::ReshapeGradKernel, int64_t ,
344
- ops::ReshapeGradKernel);
344
+ REGISTER_OP_CPU_KERNEL_FUNCTOR (reshape, float , ops::ReshapeKernel<float >,
345
+ double , ops::ReshapeKernel<double >, int ,
346
+ ops::ReshapeKernel<int >, int64_t ,
347
+ ops::ReshapeKernel<int64_t >);
348
+ REGISTER_OP_CPU_KERNEL_FUNCTOR (reshape_grad, float ,
349
+ ops::ReshapeGradKernel<float >, double ,
350
+ ops::ReshapeGradKernel<double >, int ,
351
+ ops::ReshapeGradKernel<int >, int64_t ,
352
+ ops::ReshapeGradKernel<int64_t >);
345
353
346
354
REGISTER_OPERATOR (reshape2, ops::Reshape2Op, ops::Reshape2OpMaker,
347
355
ops::Reshape2GradMaker);
348
356
REGISTER_OPERATOR (reshape2_grad, ops::Reshape2GradOp);
349
- REGISTER_OP_CPU_KERNEL_FUNCTOR (reshape2, float , ops::ReshapeKernel, double ,
350
- ops::ReshapeKernel, int , ops::ReshapeKernel,
351
- int64_t , ops::ReshapeKernel);
352
- REGISTER_OP_CPU_KERNEL_FUNCTOR (reshape2_grad, float , ops::ReshapeGradKernel,
353
- double , ops::ReshapeGradKernel, int ,
354
- ops::ReshapeGradKernel, int64_t ,
355
- ops::ReshapeGradKernel);
357
+ REGISTER_OP_CPU_KERNEL_FUNCTOR (reshape2, float , ops::ReshapeKernel<float >,
358
+ double , ops::ReshapeKernel<double >, int ,
359
+ ops::ReshapeKernel<int >, int64_t ,
360
+ ops::ReshapeKernel<int64_t >);
361
+ REGISTER_OP_CPU_KERNEL_FUNCTOR (reshape2_grad, float ,
362
+ ops::ReshapeGradKernel<float >, double ,
363
+ ops::ReshapeGradKernel<double >, int ,
364
+ ops::ReshapeGradKernel<int >, int64_t ,
365
+ ops::ReshapeGradKernel<int64_t >);
356
366
357
367
#ifdef PADDLE_WITH_CUDA
358
- REGISTER_OP_CUDA_KERNEL_FUNCTOR (reshape, float , ops::ReshapeKernel, double ,
359
- ops::ReshapeKernel, int , ops::ReshapeKernel,
360
- int64_t , ops::ReshapeKernel);
361
- REGISTER_OP_CUDA_KERNEL_FUNCTOR (reshape_grad, float , ops::ReshapeGradKernel,
362
- double , ops::ReshapeGradKernel, int ,
363
- ops::ReshapeGradKernel, int64_t ,
364
- ops::ReshapeGradKernel);
365
- REGISTER_OP_CUDA_KERNEL_FUNCTOR (reshape2, float , ops::ReshapeKernel, double ,
366
- ops::ReshapeKernel, int , ops::ReshapeKernel,
367
- int64_t , ops::ReshapeKernel);
368
- REGISTER_OP_CUDA_KERNEL_FUNCTOR (reshape2_grad, float , ops::ReshapeGradKernel,
369
- double , ops::ReshapeGradKernel, int ,
370
- ops::ReshapeGradKernel, int64_t ,
371
- ops::ReshapeGradKernel);
368
+ REGISTER_OP_CUDA_KERNEL_FUNCTOR (reshape, float , ops::ReshapeKernel<float >,
369
+ double , ops::ReshapeKernel<double >, int ,
370
+ ops::ReshapeKernel<int >, int64_t ,
371
+ ops::ReshapeKernel<int64_t >);
372
+ REGISTER_OP_CUDA_KERNEL_FUNCTOR (reshape_grad, float ,
373
+ ops::ReshapeGradKernel<float >, double ,
374
+ ops::ReshapeGradKernel<double >, int ,
375
+ ops::ReshapeGradKernel<int >, int64_t ,
376
+ ops::ReshapeGradKernel<int64_t >);
377
+ REGISTER_OP_CUDA_KERNEL_FUNCTOR (reshape2, float , ops::ReshapeKernel<float >,
378
+ double , ops::ReshapeKernel<double >, int ,
379
+ ops::ReshapeKernel<int >, int64_t ,
380
+ ops::ReshapeKernel<int64_t >);
381
+ REGISTER_OP_CUDA_KERNEL_FUNCTOR (reshape2_grad, float ,
382
+ ops::ReshapeGradKernel<float >, double ,
383
+ ops::ReshapeGradKernel<double >, int ,
384
+ ops::ReshapeGradKernel<int >, int64_t ,
385
+ ops::ReshapeGradKernel<int64_t >);
372
386
#endif
0 commit comments