Skip to content

Commit f449180

Browse files
authored
Register more data type for reshape operator. (#8617)
1 parent a67ceba commit f449180

File tree

2 files changed

+19
-10
lines changed

2 files changed

+19
-10
lines changed

paddle/fluid/operators/reshape_op.cc

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,10 +121,15 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
121121
} // namespace operators
122122
} // namespace paddle
123123
namespace ops = paddle::operators;
124+
using CPU = paddle::platform::CPUDeviceContext;
124125

125126
REGISTER_OP(reshape, ops::ReshapeOp, ops::ReshapeOpMaker, reshape_grad,
126127
ops::ReshapeGradOp);
127-
REGISTER_OP_CPU_KERNEL(reshape,
128-
ops::ReshapeKernel<paddle::platform::CPUPlace, float>);
129-
REGISTER_OP_CPU_KERNEL(
130-
reshape_grad, ops::ReshapeGradKernel<paddle::platform::CPUPlace, float>);
128+
REGISTER_OP_CPU_KERNEL(reshape, ops::ReshapeKernel<CPU, float>,
129+
ops::ReshapeKernel<CPU, double>,
130+
ops::ReshapeKernel<CPU, int>,
131+
ops::ReshapeKernel<CPU, int64_t>);
132+
REGISTER_OP_CPU_KERNEL(reshape_grad, ops::ReshapeGradKernel<CPU, float>,
133+
ops::ReshapeGradKernel<CPU, double>,
134+
ops::ReshapeGradKernel<CPU, int>,
135+
ops::ReshapeGradKernel<CPU, int64_t>);

paddle/fluid/operators/reshape_op.cu

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,14 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/operators/reshape_op.h"
16+
using CUDA = paddle::platform::CUDADeviceContext;
1617

17-
REGISTER_OP_CUDA_KERNEL(
18-
reshape,
19-
paddle::operators::ReshapeKernel<paddle::platform::CUDAPlace, float>);
20-
REGISTER_OP_CUDA_KERNEL(
21-
reshape_grad,
22-
paddle::operators::ReshapeGradKernel<paddle::platform::CUDAPlace, float>);
18+
REGISTER_OP_CUDA_KERNEL(reshape, paddle::operators::ReshapeKernel<CUDA, float>,
19+
paddle::operators::ReshapeKernel<CUDA, double>,
20+
paddle::operators::ReshapeKernel<CUDA, int>,
21+
paddle::operators::ReshapeKernel<CUDA, int64_t>);
22+
REGISTER_OP_CUDA_KERNEL(reshape_grad,
23+
paddle::operators::ReshapeGradKernel<CUDA, float>,
24+
paddle::operators::ReshapeGradKernel<CUDA, double>,
25+
paddle::operators::ReshapeGradKernel<CUDA, int>,
26+
paddle::operators::ReshapeGradKernel<CUDA, int64_t>);

0 commit comments

Comments
 (0)