@@ -18,6 +18,7 @@ limitations under the License. */
 #include <thrust/random.h>
 #include <thrust/transform.h>
 #include "paddle/fluid/operators/dropout_op.h"
+#include "paddle/fluid/platform/float16.h"
 
 namespace paddle {
 namespace operators {
@@ -51,7 +52,7 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
     auto* x = context.Input<Tensor>("X");
     auto* y = context.Output<Tensor>("Out");
     y->mutable_data<T>(context.GetPlace());
-    AttrType dropout_prob = context.Attr<AttrType>("dropout_prob");
+    AttrType dropout_prob = static_cast<AttrType>(context.Attr<float>("dropout_prob"));
 
     auto X = EigenMatrix<T>::Reshape(*x, 1);
     auto Y = EigenMatrix<T>::Reshape(*y, 1);
@@ -74,7 +75,7 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
                          context.cuda_device_context().stream()>>>(
         size, seed, dropout_prob, x_data, mask_data, y_data);
     } else {
-      Y.device(place) = X * (1.0f - dropout_prob);
+      Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
     }
   }
 };
@@ -83,9 +84,9 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
+namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(
-    dropout,
-    ops::GPUDropoutKernel<paddle::platform::CUDADeviceContext, float, float>);
-REGISTER_OP_CUDA_KERNEL(
-    dropout_grad,
-    ops::DropoutGradKernel<paddle::platform::CUDADeviceContext, float>);
+    dropout, ops::GPUDropoutKernel<plat::CUDADeviceContext, float, float>,
+    ops::GPUDropoutKernel<plat::CUDADeviceContext, plat::float16, float>);
+REGISTER_OP_CUDA_KERNEL(dropout_grad,
+                        ops::DropoutGradKernel<plat::CUDADeviceContext, float>);