@@ -18,120 +18,88 @@ namespace paddle {
 namespace operators {
 namespace math {

-static const int CUDA_NUM_THREADS = 1024;
-static const int CUDA_MAX_NUM_BLOCKS = 65535;
-inline static int GET_NUM_BLOCKS(const int N) {
+#define CUDA_NUM_THREADS 1024
+
+// CUDA: grid stride looping
+#define CUDA_KERNEL_LOOP(i, n)                                 \
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
+       i += blockDim.x * gridDim.x)
+
+inline static int PADDLE_GET_BLOCKS(const int N) {
   return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
 }

 template <typename T>
 __global__ void PReluChannelWiseKernel(const T *input, const T *alpha,
-                                       T *output, int channel,
-                                       size_t spatial_size) {
-  size_t offset = blockIdx.x * spatial_size;
-  const T *in = input + offset;
-  T *out = output + offset;
-  T scale = alpha[blockIdx.x % channel];
-
-  for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) {
-    T x = in[i];
-    out[i] = (x > 0) ? x : scale * x;
+                                       T *output, size_t channel_num,
+                                       size_t plane_size, size_t numel) {
+  size_t index;
+  CUDA_KERNEL_LOOP(index, numel) {
+    size_t temp = index / plane_size;
+    size_t channel_index = temp % channel_num;
+    T scale = alpha[channel_index];
+    T x = input[index];
+    output[index] = (x > 0) ? x : scale * x;
   }
 }

 template <typename T>
 __global__ void PReluElementWiseKernel(const T *input, const T *alpha,
-                                       T *output, size_t spatial_size) {
-  size_t offset = blockIdx.x * spatial_size;
-  const T *in = input + offset;
-  const T *scale = alpha + offset;
-  T *out = output + offset;
-
-  for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) {
-    T x = in[i];
-    out[i] = (x > 0) ? x : scale[i] * x;
+                                       T *output, size_t spatial_size,
+                                       size_t numel) {
+  size_t index;
+  CUDA_KERNEL_LOOP(index, numel) {
+    size_t element_index = index % spatial_size;
+    T scale = alpha[element_index];
+    T x = input[index];
+    output[index] = (x > 0) ? x : scale * x;
   }
 }

 template <typename T>
 __global__ void PReluScalarKernel(const T *input, const T *alpha, T *output,
-                                  size_t spatial_size) {
-  size_t offset = blockIdx.x * spatial_size;
-  const T *in = input + offset;
-  T scale = *alpha;
-  T *out = output + offset;
-
-  for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) {
-    T x = in[i];
-    out[i] = (x > 0) ? x : scale * x;
+                                  size_t numel) {
+  T scale = alpha[0];
+  size_t index;
+  CUDA_KERNEL_LOOP(index, numel) {
+    T x = input[index];
+    output[index] = (x > 0) ? x : scale * x;
   }
 }

-template <typename T>
-static inline void PReluChannelWise(cudaStream_t stream, const T *input,
-                                    const T *alpha, T *output,
-                                    std::vector<int> input_shape) {
-  size_t unroll = input_shape[0] * input_shape[1];
-  size_t spatial_size = input_shape[2] * input_shape[3];
-  CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
-  PReluChannelWiseKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
-      input, alpha, output, input_shape[1], spatial_size);
-}
-
-template <typename T>
-static inline void PReluElementWise(cudaStream_t stream, const T *input,
-                                    const T *alpha, T *output,
-                                    std::vector<int> input_shape) {
-  size_t unroll = input_shape[0] * input_shape[1];
-  size_t spatial_size = input_shape[2] * input_shape[3];
-  CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
-  PReluElementWiseKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
-      input, alpha, output, spatial_size);
-}
-
-template <typename T>
-static inline void PReluScalar(cudaStream_t stream, const T *input,
-                               const T *alpha, T *output,
-                               std::vector<int> input_shape) {
-  size_t unroll = input_shape[0] * input_shape[1];
-  size_t spatial_size = input_shape[2] * input_shape[3];
-  CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
-  PReluScalarKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
-      input, alpha, output, spatial_size);
-}
-
 template <typename T>
 void PreluChannelWiseDirectCUDAFunctor<T>::operator()(
     cudaStream_t stream, const T *input, const T *alpha, T *output,
     std::vector<int> input_shape) {
-  size_t unroll = input_shape[0] * input_shape[1];
-  size_t spatial_size = input_shape[2] * input_shape[3];
-  CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
-  PReluChannelWiseKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
-      input, alpha, output, input_shape[1], spatial_size);
+  size_t plane_size = input_shape[2] * input_shape[3];
+  size_t spatial_size = input_shape[1] * plane_size;
+  size_t numel = input_shape[0] * spatial_size;
+  PReluChannelWiseKernel<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0,
+                           stream>>>(input, alpha, output, input_shape[1],
+                                     plane_size, numel);
 }

 template <typename T>
 void PreluElementWiseDirectCUDAFunctor<T>::operator()(
     cudaStream_t stream, const T *input, const T *alpha, T *output,
     std::vector<int> input_shape) {
-  size_t unroll = input_shape[0] * input_shape[1];
-  size_t spatial_size = input_shape[2] * input_shape[3];
-  CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
-  PReluElementWiseKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
-      input, alpha, output, spatial_size);
+  size_t plane_size = input_shape[2] * input_shape[3];
+  size_t spatial_size = input_shape[1] * plane_size;
+  size_t numel = input_shape[0] * spatial_size;
+  PReluElementWiseKernel<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0,
+                           stream>>>(input, alpha, output, spatial_size, numel);
 }

 template <typename T>
 void PreluScalarDirectCUDAFunctor<T>::operator()(cudaStream_t stream,
                                                  const T *input, const T *alpha,
                                                  T *output,
                                                  std::vector<int> input_shape) {
-  size_t unroll = input_shape[0] * input_shape[1];
-  size_t spatial_size = input_shape[2] * input_shape[3];
-  CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
-  PReluScalarKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
-      input, alpha, output, spatial_size);
+  size_t plane_size = input_shape[2] * input_shape[3];
+  size_t spatial_size = input_shape[1] * plane_size;
+  size_t numel = input_shape[0] * spatial_size;
+  PReluScalarKernel<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0, stream>>>(
+      input, alpha, output, numel);
 }

 template class PreluChannelWiseDirectCUDAFunctor<float>;
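
For reference, a minimal self-contained sketch of the pattern this hunk adopts: a CUDA_KERNEL_LOOP-style grid-stride loop lets a grid of PADDLE_GET_BLOCKS(numel) blocks cover all numel elements regardless of grid size, and for an NCHW tensor the channel of flat element index is recovered as (index / plane_size) % channel_num. The kernel mirrors PReluChannelWiseKernel from the diff; the host driver and tensor sizes below are invented for illustration and are not part of the commit.

#include <cstdio>
#include <vector>

#define CUDA_NUM_THREADS 1024

// Grid-stride loop: each thread handles i, i + gridDim.x * blockDim.x, ...
#define CUDA_KERNEL_LOOP(i, n)                                 \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)

inline static int PADDLE_GET_BLOCKS(const int N) {
  return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}

// Same indexing as PReluChannelWiseKernel in the diff: for an NCHW
// tensor, channel(index) = (index / (H * W)) % C.
__global__ void PReluChannelWiseRef(const float *x, const float *alpha,
                                    float *y, size_t channel_num,
                                    size_t plane_size, size_t numel) {
  CUDA_KERNEL_LOOP(i, numel) {
    size_t c = (i / plane_size) % channel_num;
    y[i] = x[i] > 0 ? x[i] : alpha[c] * x[i];
  }
}

int main() {
  // Hypothetical sizes: N=2, C=3, H=W=4.
  const size_t n = 2, c = 3, h = 4, w = 4;
  const size_t plane = h * w, numel = n * c * plane;

  std::vector<float> hx(numel), halpha = {0.1f, 0.2f, 0.3f};
  for (size_t i = 0; i < numel; ++i) hx[i] = (i % 2 == 0) ? -1.0f : 1.0f;

  float *dx, *dalpha, *dy;
  cudaMalloc((void **)&dx, numel * sizeof(float));
  cudaMalloc((void **)&dalpha, c * sizeof(float));
  cudaMalloc((void **)&dy, numel * sizeof(float));
  cudaMemcpy(dx, hx.data(), numel * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(dalpha, halpha.data(), c * sizeof(float),
             cudaMemcpyHostToDevice);

  PReluChannelWiseRef<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS>>>(
      dx, dalpha, dy, c, plane, numel);

  std::vector<float> hy(numel);
  cudaMemcpy(hy.data(), dy, numel * sizeof(float), cudaMemcpyDeviceToHost);
  // Element 0 is negative and lies in channel 0, so it is scaled by 0.1.
  printf("y[0] = %f (expect -0.1)\n", hy[0]);

  cudaFree(dx);
  cudaFree(dalpha);
  cudaFree(dy);
  return 0;
}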
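Two observations on the indexing, assuming the NCHW layout implied by input_shape. First, the grid-stride form removes the old CUDA_MAX_NUM_BLOCKS = 65535 cap: instead of one block per N * C slab guarded by CHECK_LT, ceil(numel / 1024) blocks each stride by gridDim.x * blockDim.x until numel is exhausted, so any tensor size is covered. Second, in PReluElementWiseKernel the new index % spatial_size (with spatial_size = C * H * W) appears to make alpha hold one weight per element of a single sample, reused across the batch dimension, whereas the old kernel indexed alpha at the same per-block offset as the input.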