@@ -195,7 +195,6 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
195
195
}
196
196
};
197
197
198
- template <typename T>
199
198
class ReshapeKernel {
200
199
public:
201
200
void operator ()(const framework::ExecutionContext &ctx) const {
@@ -228,25 +227,21 @@ class ReshapeKernel {
228
227
" sequence_reshape op." );
229
228
}
230
229
231
- if (in->data <T>() !=
232
- reinterpret_cast <T *>(out->mutable_data (ctx.GetPlace (), in->type ()))) {
233
- framework::TensorCopySync (*in, ctx.GetPlace (), out);
234
- }
230
+ out->mutable_data (ctx.GetPlace (), in->type ());
231
+ framework::TensorCopySync (*in, ctx.GetPlace (), out);
235
232
out->Resize (out_dims);
236
233
}
237
234
};
238
235
239
- template <typename T>
240
236
class ReshapeGradKernel {
241
237
public:
242
238
void operator ()(const framework::ExecutionContext &ctx) const {
243
239
auto *d_out = ctx.Input <framework::Tensor>(framework::GradVarName (" Out" ));
244
240
auto *d_x = ctx.Output <framework::Tensor>(framework::GradVarName (" X" ));
245
241
auto in_dims = d_x->dims ();
246
242
247
- if (d_out->data <T>() != d_x->mutable_data (ctx.GetPlace (), d_out->type ())) {
248
- framework::TensorCopySync (*d_out, ctx.GetPlace (), d_x);
249
- }
243
+ d_x->mutable_data (ctx.GetPlace (), d_out->type ());
244
+ framework::TensorCopySync (*d_out, ctx.GetPlace (), d_x);
250
245
d_x->Resize (in_dims);
251
246
}
252
247
};
@@ -341,46 +336,38 @@ namespace ops = paddle::operators;
341
336
REGISTER_OPERATOR (reshape, ops::ReshapeOp, ops::ReshapeOpMaker,
342
337
paddle::framework::DefaultGradOpDescMaker<true >);
343
338
REGISTER_OPERATOR (reshape_grad, ops::ReshapeGradOp);
344
- REGISTER_OP_CPU_KERNEL_FUNCTOR (reshape, float , ops::ReshapeKernel<float >,
345
- double , ops::ReshapeKernel<double >, int ,
346
- ops::ReshapeKernel<int >, int64_t ,
347
- ops::ReshapeKernel<int64_t >);
348
- REGISTER_OP_CPU_KERNEL_FUNCTOR (reshape_grad, float ,
349
- ops::ReshapeGradKernel<float >, double ,
350
- ops::ReshapeGradKernel<double >, int ,
351
- ops::ReshapeGradKernel<int >, int64_t ,
352
- ops::ReshapeGradKernel<int64_t >);
339
+ REGISTER_OP_CPU_KERNEL_FUNCTOR (reshape, float , ops::ReshapeKernel, double ,
340
+ ops::ReshapeKernel, int , ops::ReshapeKernel,
341
+ int64_t , ops::ReshapeKernel);
342
+ REGISTER_OP_CPU_KERNEL_FUNCTOR (reshape_grad, float , ops::ReshapeGradKernel,
343
+ double , ops::ReshapeGradKernel, int ,
344
+ ops::ReshapeGradKernel, int64_t ,
345
+ ops::ReshapeGradKernel);
353
346
354
347
REGISTER_OPERATOR (reshape2, ops::Reshape2Op, ops::Reshape2OpMaker,
355
348
ops::Reshape2GradMaker);
356
349
REGISTER_OPERATOR (reshape2_grad, ops::Reshape2GradOp);
357
- REGISTER_OP_CPU_KERNEL_FUNCTOR (reshape2, float , ops::ReshapeKernel<float >,
358
- double , ops::ReshapeKernel<double >, int ,
359
- ops::ReshapeKernel<int >, int64_t ,
360
- ops::ReshapeKernel<int64_t >);
361
- REGISTER_OP_CPU_KERNEL_FUNCTOR (reshape2_grad, float ,
362
- ops::ReshapeGradKernel<float >, double ,
363
- ops::ReshapeGradKernel<double >, int ,
364
- ops::ReshapeGradKernel<int >, int64_t ,
365
- ops::ReshapeGradKernel<int64_t >);
350
+ REGISTER_OP_CPU_KERNEL_FUNCTOR (reshape2, float , ops::ReshapeKernel, double ,
351
+ ops::ReshapeKernel, int , ops::ReshapeKernel,
352
+ int64_t , ops::ReshapeKernel);
353
+ REGISTER_OP_CPU_KERNEL_FUNCTOR (reshape2_grad, float , ops::ReshapeGradKernel,
354
+ double , ops::ReshapeGradKernel, int ,
355
+ ops::ReshapeGradKernel, int64_t ,
356
+ ops::ReshapeGradKernel);
366
357
367
358
#ifdef PADDLE_WITH_CUDA
368
- REGISTER_OP_CUDA_KERNEL_FUNCTOR (reshape, float , ops::ReshapeKernel<float >,
369
- double , ops::ReshapeKernel<double >, int ,
370
- ops::ReshapeKernel<int >, int64_t ,
371
- ops::ReshapeKernel<int64_t >);
372
- REGISTER_OP_CUDA_KERNEL_FUNCTOR (reshape_grad, float ,
373
- ops::ReshapeGradKernel<float >, double ,
374
- ops::ReshapeGradKernel<double >, int ,
375
- ops::ReshapeGradKernel<int >, int64_t ,
376
- ops::ReshapeGradKernel<int64_t >);
377
- REGISTER_OP_CUDA_KERNEL_FUNCTOR (reshape2, float , ops::ReshapeKernel<float >,
378
- double , ops::ReshapeKernel<double >, int ,
379
- ops::ReshapeKernel<int >, int64_t ,
380
- ops::ReshapeKernel<int64_t >);
381
- REGISTER_OP_CUDA_KERNEL_FUNCTOR (reshape2_grad, float ,
382
- ops::ReshapeGradKernel<float >, double ,
383
- ops::ReshapeGradKernel<double >, int ,
384
- ops::ReshapeGradKernel<int >, int64_t ,
385
- ops::ReshapeGradKernel<int64_t >);
359
+ REGISTER_OP_CUDA_KERNEL_FUNCTOR (reshape, float , ops::ReshapeKernel, double ,
360
+ ops::ReshapeKernel, int , ops::ReshapeKernel,
361
+ int64_t , ops::ReshapeKernel);
362
+ REGISTER_OP_CUDA_KERNEL_FUNCTOR (reshape_grad, float , ops::ReshapeGradKernel,
363
+ double , ops::ReshapeGradKernel, int ,
364
+ ops::ReshapeGradKernel, int64_t ,
365
+ ops::ReshapeGradKernel);
366
+ REGISTER_OP_CUDA_KERNEL_FUNCTOR (reshape2, float , ops::ReshapeKernel, double ,
367
+ ops::ReshapeKernel, int , ops::ReshapeKernel,
368
+ int64_t , ops::ReshapeKernel);
369
+ REGISTER_OP_CUDA_KERNEL_FUNCTOR (reshape2_grad, float , ops::ReshapeGradKernel,
370
+ double , ops::ReshapeGradKernel, int ,
371
+ ops::ReshapeGradKernel, int64_t ,
372
+ ops::ReshapeGradKernel);
386
373
#endif
0 commit comments