@@ -23,13 +23,13 @@ limitations under the License. */
23
23
namespace paddle {
24
24
namespace operators {
25
25
26
- template <typename T, typename AttrType >
26
+ template <typename T>
27
27
__global__ void RandomGenerator (const size_t n, const int seed,
28
- const AttrType dropout_prob, const T* src,
28
+ const float dropout_prob, const T* src,
29
29
T* mask_data, T* dst) {
30
30
thrust::minstd_rand rng;
31
31
rng.seed (seed);
32
- thrust::uniform_real_distribution<AttrType > dist (0 , 1 );
32
+ thrust::uniform_real_distribution<float > dist (0 , 1 );
33
33
34
34
int idx = blockDim .x * blockIdx .x + threadIdx .x ;
35
35
for (; idx < n; idx += blockDim .x * gridDim .x ) {
@@ -45,14 +45,14 @@ __global__ void RandomGenerator(const size_t n, const int seed,
45
45
// Eigen::Tensor::setRandom appears to SEGFAULT when run on the GPU.
46
46
// Use std::random and thrust::random (Thrust ships with the CUDA toolkit) to
47
47
// implement uniform random number generation.
48
- template <typename Place, typename T, typename AttrType >
48
+ template <typename Place, typename T>
49
49
class GPUDropoutKernel : public framework ::OpKernel<T> {
50
50
public:
51
51
void Compute (const framework::ExecutionContext& context) const override {
52
52
auto * x = context.Input <Tensor>(" X" );
53
53
auto * y = context.Output <Tensor>(" Out" );
54
54
y->mutable_data <T>(context.GetPlace ());
55
- AttrType dropout_prob = context.Attr <AttrType >(" dropout_prob" ) );
55
+ float dropout_prob = context.Attr <float >(" dropout_prob" );
56
56
57
57
auto X = EigenMatrix<T>::Reshape (*x, 1 );
58
58
auto Y = EigenMatrix<T>::Reshape (*y, 1 );
@@ -71,8 +71,8 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
71
71
72
72
int threads = 512 ;
73
73
int grid = (x->numel () + threads - 1 ) / threads;
74
- RandomGenerator<T, AttrType> <<<grid, threads, 0 ,
75
- context.cuda_device_context().stream()>>> (
74
+ RandomGenerator<
75
+ T> <<<grid, threads, 0 , context.cuda_device_context().stream()>>> (
76
76
size, seed, dropout_prob, x_data, mask_data, y_data);
77
77
} else {
78
78
Y.device (place) = X * static_cast <T>(1 .0f - dropout_prob);
@@ -86,7 +86,7 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
86
86
namespace ops = paddle::operators;
87
87
namespace plat = paddle::platform;
88
88
REGISTER_OP_CUDA_KERNEL (
89
- dropout, ops::GPUDropoutKernel<plat::CUDADeviceContext, float , float >,
90
- ops::GPUDropoutKernel<plat::CUDADeviceContext, plat::float16, float >);
89
+ dropout, ops::GPUDropoutKernel<plat::CUDADeviceContext, float >,
90
+ ops::GPUDropoutKernel<plat::CUDADeviceContext, plat::float16>);
91
91
REGISTER_OP_CUDA_KERNEL (dropout_grad,
92
92
ops::DropoutGradKernel<plat::CUDADeviceContext, float >);
0 commit comments