@@ -147,10 +147,11 @@ void InferMetaFromVecValue(const phi::DenseTensor& x,
147
147
} // namespace
148
148
149
149
template <typename T, typename Context>
150
- void ReshapeInferKernel (const Context& dev_ctx,
151
- const phi::DenseTensor& x,
152
- const phi::IntArray& shape,
153
- phi::DenseTensor* out) {
150
+ void ReshapeKernel (const Context& dev_ctx,
151
+ const phi::DenseTensor& x,
152
+ const phi::IntArray& shape,
153
+ phi::DenseTensor* out) {
154
+ PADDLE_GCU_KERNEL_TRACE (" reshape" );
154
155
PADDLE_ENFORCE_NE (
155
156
x.layout (),
156
157
common::DataLayout::kNDHWC ,
@@ -180,13 +181,13 @@ void ReshapeInferKernel(const Context& dev_ctx,
180
181
}
181
182
182
183
template <typename T, typename Context>
183
- void ReshapeKernel (const Context& dev_ctx,
184
- const phi::DenseTensor& x,
185
- const phi::IntArray& shape,
186
- phi::DenseTensor* out,
187
- phi::DenseTensor* xshape) {
188
- PADDLE_GCU_KERNEL_TRACE (" reshape " );
189
- ReshapeInferKernel <T>(dev_ctx, x, shape, out);
184
+ void ReshapeWithXShapeKernel (const Context& dev_ctx,
185
+ const phi::DenseTensor& x,
186
+ const phi::IntArray& shape,
187
+ phi::DenseTensor* out,
188
+ phi::DenseTensor* xshape) {
189
+ PADDLE_GCU_KERNEL_TRACE (" reshape_with_xshape " );
190
+ ReshapeKernel <T>(dev_ctx, x, shape, out);
190
191
}
191
192
192
193
template <typename T, typename Context>
@@ -251,6 +252,19 @@ PD_REGISTER_PLUGIN_KERNEL(reshape,
251
252
uint8_t ,
252
253
bool ) {}
253
254
255
+ PD_REGISTER_PLUGIN_KERNEL (reshape_with_xshape,
256
+ gcu,
257
+ ALL_LAYOUT,
258
+ custom_kernel::ReshapeWithXShapeKernel,
259
+ float ,
260
+ phi::dtype::float16,
261
+ double ,
262
+ int8_t ,
263
+ int16_t ,
264
+ int32_t ,
265
+ int64_t ,
266
+ uint8_t ,
267
+ bool ) {}
254
268
// PD_REGISTER_PLUGIN_KERNEL(reshape_grad,
255
269
// gcu,
256
270
// ALL_LAYOUT,
0 commit comments