Commit 8d8527f

Author: Yibing Liu

register fp16 kernel for some ops (#22650)
test=release/1.7

1 parent 5d96b6e · commit 8d8527f

File tree

5 files changed: +29 −4 lines changed

  paddle/fluid/operators/conv_transpose_cudnn_op.cu
  paddle/fluid/operators/expand_op.cu
  paddle/fluid/operators/pad2d_op.cu
  paddle/fluid/operators/squeeze_op.cu.cc
  paddle/fluid/operators/unsqueeze_op.cu.cc
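
All five files follow the same pattern: each op's CUDA kernel is already a class template over the element type, so enabling fp16 only requires adding an instantiation for paddle::platform::float16 to the existing registration macro. A minimal sketch of that pattern is shown below; "some_op" and SomeOpKernel are hypothetical placeholders, not part of this commit, and the snippet assumes the usual Paddle operator-registration headers are in scope.

// Illustrative sketch only: "some_op" and SomeOpKernel are hypothetical
// stand-ins; the real registrations appear in the diffs below.
namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_CUDA_KERNEL(
    some_op,
    ops::SomeOpKernel<paddle::platform::CUDADeviceContext, plat::float16>,
    ops::SomeOpKernel<paddle::platform::CUDADeviceContext, float>,
    ops::SomeOpKernel<paddle::platform::CUDADeviceContext, double>);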

paddle/fluid/operators/conv_transpose_cudnn_op.cu

Lines changed: 7 additions & 2 deletions

@@ -261,7 +261,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
     int output_offset =
         transformed_output.numel() / transformed_output.dims()[0] / groups;
     int filter_offset = filter->numel() / groups;
-    T alpha = 1.0f, beta = 0.0f;
+    T alpha = static_cast<T>(1.0), beta = static_cast<T>(0.0);
     auto workspace_handle = dev_ctx.cudnn_workspace_handle();
     for (int g = 0; g < groups; g++) {
       auto cudnn_func = [&](void* cudnn_workspace) {
@@ -507,7 +507,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
     int output_grad_offset = transformed_output_grad.numel() /
                              transformed_output_grad.dims()[0] / groups;
     int filter_offset = filter->numel() / groups;
-    T alpha = 1.0f, beta = 0.0f;
+    T alpha = static_cast<T>(1.0), beta = static_cast<T>(0.0);
     auto workspace_handle = dev_ctx.cudnn_workspace_handle();
     if (input_grad) {
       T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
@@ -569,17 +569,22 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
 } // namespace paddle

 namespace ops = paddle::operators;
+namespace plat = paddle::platform;

 REGISTER_OP_KERNEL(conv2d_transpose, CUDNN, ::paddle::platform::CUDAPlace,
+                   ops::CUDNNConvTransposeOpKernel<plat::float16>,
                    ops::CUDNNConvTransposeOpKernel<float>,
                    ops::CUDNNConvTransposeOpKernel<double>);
 REGISTER_OP_KERNEL(conv2d_transpose_grad, CUDNN, ::paddle::platform::CUDAPlace,
+                   ops::CUDNNConvTransposeGradOpKernel<plat::float16>,
                    ops::CUDNNConvTransposeGradOpKernel<float>,
                    ops::CUDNNConvTransposeGradOpKernel<double>);

 REGISTER_OP_KERNEL(conv3d_transpose, CUDNN, ::paddle::platform::CUDAPlace,
+                   ops::CUDNNConvTransposeOpKernel<plat::float16>,
                    ops::CUDNNConvTransposeOpKernel<float>,
                    ops::CUDNNConvTransposeOpKernel<double>);
 REGISTER_OP_KERNEL(conv3d_transpose_grad, CUDNN, ::paddle::platform::CUDAPlace,
+                   ops::CUDNNConvTransposeGradOpKernel<plat::float16>,
                    ops::CUDNNConvTransposeGradOpKernel<float>,
                    ops::CUDNNConvTransposeGradOpKernel<double>);
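
Besides registering the fp16 kernels, the two alpha/beta hunks above rewrite the float literals as static_cast<T>(...) so the same template body is usable when T is platform::float16. The standalone sketch below illustrates that general C++ point only; Half is a made-up stand-in for an fp16 wrapper whose converting constructor is explicit, which is an assumption offered as one plausible motivation, not a statement about Paddle's float16 implementation.

// Standalone illustration, not Paddle source. It shows why a template body
// that must also instantiate for a half-precision wrapper type tends to
// write its constants as static_cast<T>(1.0) rather than "T alpha = 1.0f".
#include <cstdint>
#include <iostream>

// Hypothetical fp16 stand-in whose converting constructor is explicit, so
// copy-initialization from a float literal ("Half h = 1.0f;") is rejected.
struct Half {
  uint16_t bits;
  explicit Half(float f) {
    bits = static_cast<uint16_t>(f);  // placeholder, not real fp16 rounding
  }
};

template <typename T>
void InitScalingFactors() {
  // Compiles for float, double, and Half alike; with an explicit
  // constructor, "T alpha = 1.0f" would fail to compile for T = Half.
  T alpha = static_cast<T>(1.0), beta = static_cast<T>(0.0);
  (void)alpha;
  (void)beta;
}

int main() {
  InitScalingFactors<float>();
  InitScalingFactors<double>();
  InitScalingFactors<Half>();
  std::cout << "all instantiations compile and run\n";
  return 0;
}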

paddle/fluid/operators/expand_op.cu

Lines changed: 4 additions & 0 deletions

@@ -14,15 +14,19 @@ limitations under the License. */
 #include "paddle/fluid/operators/expand_op.h"

 namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
 REGISTER_OP_CUDA_KERNEL(
     expand, ops::ExpandKernel<paddle::platform::CUDADeviceContext, float>,
     ops::ExpandKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::ExpandKernel<paddle::platform::CUDADeviceContext, plat::float16>,
     ops::ExpandKernel<paddle::platform::CUDADeviceContext, int>,
     ops::ExpandKernel<paddle::platform::CUDADeviceContext, int64_t>,
     ops::ExpandKernel<paddle::platform::CUDADeviceContext, bool>);
 REGISTER_OP_CUDA_KERNEL(
     expand_grad,
     ops::ExpandGradKernel<paddle::platform::CUDADeviceContext, float>,
     ops::ExpandGradKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::ExpandGradKernel<paddle::platform::CUDADeviceContext, plat::float16>,
     ops::ExpandGradKernel<paddle::platform::CUDADeviceContext, int>,
     ops::ExpandGradKernel<paddle::platform::CUDADeviceContext, int64_t>);

paddle/fluid/operators/pad2d_op.cu

Lines changed: 6 additions & 2 deletions

@@ -461,8 +461,12 @@ class Pad2dGradCUDAKernel : public framework::OpKernel<T> {
 } // namespace paddle

 namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(pad2d, ops::Pad2dCUDAKernel<float>,
+namespace plat = paddle::platform;
+
+REGISTER_OP_CUDA_KERNEL(pad2d, ops::Pad2dCUDAKernel<plat::float16>,
+                        ops::Pad2dCUDAKernel<float>,
                         ops::Pad2dCUDAKernel<double>, ops::Pad2dCUDAKernel<int>,
                         ops::Pad2dCUDAKernel<int64_t>);
-REGISTER_OP_CUDA_KERNEL(pad2d_grad, ops::Pad2dGradCUDAKernel<float>,
+REGISTER_OP_CUDA_KERNEL(pad2d_grad, ops::Pad2dGradCUDAKernel<plat::float16>,
+                        ops::Pad2dGradCUDAKernel<float>,
                         ops::Pad2dGradCUDAKernel<double>);

paddle/fluid/operators/squeeze_op.cu.cc

Lines changed: 5 additions & 0 deletions

@@ -15,30 +15,35 @@ limitations under the License. */
 #include "paddle/fluid/operators/squeeze_op.h"

 namespace ops = paddle::operators;
+namespace plat = paddle::platform;

 REGISTER_OP_CUDA_KERNEL(
     squeeze, ops::SqueezeKernel<paddle::platform::CUDADeviceContext, float>,
     ops::SqueezeKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::SqueezeKernel<paddle::platform::CUDADeviceContext, plat::float16>,
     ops::SqueezeKernel<paddle::platform::CUDADeviceContext, int>,
     ops::SqueezeKernel<paddle::platform::CUDADeviceContext, int8_t>,
     ops::SqueezeKernel<paddle::platform::CUDADeviceContext, int64_t>);
 REGISTER_OP_CUDA_KERNEL(
     squeeze_grad,
     ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, float>,
     ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, plat::float16>,
     ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, int>,
     ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, int8_t>,
     ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
 REGISTER_OP_CUDA_KERNEL(
     squeeze2, ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, float>,
     ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, double>,
+    ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, plat::float16>,
     ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, int>,
     ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, int8_t>,
     ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, int64_t>);
 REGISTER_OP_CUDA_KERNEL(
     squeeze2_grad,
     ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, float>,
     ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, plat::float16>,
     ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, int>,
     ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, int8_t>,
     ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, int64_t>);

paddle/fluid/operators/unsqueeze_op.cu.cc

Lines changed: 7 additions & 0 deletions

@@ -15,31 +15,38 @@ limitations under the License. */
 #include "paddle/fluid/operators/unsqueeze_op.h"

 namespace ops = paddle::operators;
+namespace plat = paddle::platform;

 REGISTER_OP_CUDA_KERNEL(
     unsqueeze, ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, float>,
     ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, plat::float16>,
     ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int>,
     ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int8_t>,
     ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int64_t>);
 REGISTER_OP_CUDA_KERNEL(
     unsqueeze_grad,
     ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, float>,
     ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext,
+                             plat::float16>,
     ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int>,
     ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int8_t>,
     ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
 REGISTER_OP_CUDA_KERNEL(
     unsqueeze2,
     ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, float>,
     ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, plat::float16>,
     ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int>,
     ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int8_t>,
     ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int64_t>);
 REGISTER_OP_CUDA_KERNEL(
     unsqueeze2_grad,
     ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, float>,
     ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext,
+                              plat::float16>,
     ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int>,
     ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int8_t>,
     ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int64_t>);
