
Commit 4054947

Author: lilong12
[Cherry-pick] fix the computation for dx (grad for x) for prelu operation. (#20949) (#21514)

* fix the computation for dx (grad for x) for prelu operation. (#20949)
* set the default value of alpha for prelu to 0.25, test=develop
* add the call to __syncthreads(), test=develop
* fix the implementation of cpu prelu, test=develop
* repair the implementation of element mode prelu, test=develop
* modify test_prelu_op.py, test=develop

1 parent e3dd13b commit 4054947
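
For context, the heart of this fix is the dx rule in the backward kernel (see the prelu_op.cu diff below): the old code branched on the output y, the corrected code branches on the input x, and with a negative alpha the two disagree for x < 0. A minimal reference sketch of PReLU and its per-element gradients, plain C++ and not part of the commit, with illustrative names:

// Reference PReLU forward/backward for a single element.
//   y     = x  if x > 0, else alpha * x
//   dL/dx = dy if x > 0, else alpha * dy   (must test x, not y)
//   dL/da = 0  if x > 0, else x * dy
template <typename T>
void PReluReference(T x, T alpha, T dy, T* y, T* dx, T* dalpha) {
  *y = (x > 0) ? x : alpha * x;
  *dx = (x > 0) ? dy : alpha * dy;
  *dalpha = (x > 0) ? T(0) : x * dy;
}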

File tree

10 files changed, +147 -173 lines changed


paddle/fluid/operators/math/prelu.cu

Lines changed: 47 additions & 79 deletions
@@ -18,120 +18,88 @@ namespace paddle {
 namespace operators {
 namespace math {
 
-static const int CUDA_NUM_THREADS = 1024;
-static const int CUDA_MAX_NUM_BLOCKS = 65535;
-inline static int GET_NUM_BLOCKS(const int N) {
+#define CUDA_NUM_THREADS 1024
+
+// CUDA: grid stride looping
+#define CUDA_KERNEL_LOOP(i, n)                                 \
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
+       i += blockDim.x * gridDim.x)
+
+inline static int PADDLE_GET_BLOCKS(const int N) {
   return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
 }
 
 template <typename T>
 __global__ void PReluChannelWiseKernel(const T *input, const T *alpha,
-                                       T *output, int channel,
-                                       size_t spatial_size) {
-  size_t offset = blockIdx.x * spatial_size;
-  const T *in = input + offset;
-  T *out = output + offset;
-  T scale = alpha[blockIdx.x % channel];
-
-  for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) {
-    T x = in[i];
-    out[i] = (x > 0) ? x : scale * x;
+                                       T *output, size_t channel_num,
+                                       size_t plane_size, size_t numel) {
+  size_t index;
+  CUDA_KERNEL_LOOP(index, numel) {
+    size_t temp = index / plane_size;
+    size_t channel_index = temp % channel_num;
+    T scale = alpha[channel_index];
+    T x = input[index];
+    output[index] = (x > 0) ? x : scale * x;
   }
 }
 
 template <typename T>
 __global__ void PReluElementWiseKernel(const T *input, const T *alpha,
-                                       T *output, size_t spatial_size) {
-  size_t offset = blockIdx.x * spatial_size;
-  const T *in = input + offset;
-  const T *scale = alpha + offset;
-  T *out = output + offset;
-
-  for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) {
-    T x = in[i];
-    out[i] = (x > 0) ? x : scale[i] * x;
+                                       T *output, size_t spatial_size,
+                                       size_t numel) {
+  size_t index;
+  CUDA_KERNEL_LOOP(index, numel) {
+    size_t element_index = index % spatial_size;
+    T scale = alpha[element_index];
+    T x = input[index];
+    output[index] = (x > 0) ? x : scale * x;
   }
 }
 
 template <typename T>
 __global__ void PReluScalarKernel(const T *input, const T *alpha, T *output,
-                                  size_t spatial_size) {
-  size_t offset = blockIdx.x * spatial_size;
-  const T *in = input + offset;
-  T scale = *alpha;
-  T *out = output + offset;
-
-  for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) {
-    T x = in[i];
-    out[i] = (x > 0) ? x : scale * x;
+                                  size_t numel) {
+  T scale = alpha[0];
+  size_t index;
+  CUDA_KERNEL_LOOP(index, numel) {
+    T x = input[index];
+    output[index] = (x > 0) ? x : scale * x;
   }
 }
 
-template <typename T>
-static inline void PReluChannelWise(cudaStream_t stream, const T *input,
-                                    const T *alpha, T *output,
-                                    std::vector<int> input_shape) {
-  size_t unroll = input_shape[0] * input_shape[1];
-  size_t spatial_size = input_shape[2] * input_shape[3];
-  CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
-  PReluChannelWiseKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
-      input, alpha, output, input_shape[1], spatial_size);
-}
-
-template <typename T>
-static inline void PReluElementWise(cudaStream_t stream, const T *input,
-                                    const T *alpha, T *output,
-                                    std::vector<int> input_shape) {
-  size_t unroll = input_shape[0] * input_shape[1];
-  size_t spatial_size = input_shape[2] * input_shape[3];
-  CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
-  PReluElementWiseKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
-      input, alpha, output, spatial_size);
-}
-
-template <typename T>
-static inline void PReluScalar(cudaStream_t stream, const T *input,
-                               const T *alpha, T *output,
-                               std::vector<int> input_shape) {
-  size_t unroll = input_shape[0] * input_shape[1];
-  size_t spatial_size = input_shape[2] * input_shape[3];
-  CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
-  PReluScalarKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
-      input, alpha, output, spatial_size);
-}
-
 template <typename T>
 void PreluChannelWiseDirectCUDAFunctor<T>::operator()(
     cudaStream_t stream, const T *input, const T *alpha, T *output,
     std::vector<int> input_shape) {
-  size_t unroll = input_shape[0] * input_shape[1];
-  size_t spatial_size = input_shape[2] * input_shape[3];
-  CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
-  PReluChannelWiseKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
-      input, alpha, output, input_shape[1], spatial_size);
+  size_t plane_size = input_shape[2] * input_shape[3];
+  size_t spatial_size = input_shape[1] * plane_size;
+  size_t numel = input_shape[0] * spatial_size;
+  PReluChannelWiseKernel<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0,
+                           stream>>>(input, alpha, output, input_shape[1],
+                                     plane_size, numel);
 }
 
 template <typename T>
 void PreluElementWiseDirectCUDAFunctor<T>::operator()(
     cudaStream_t stream, const T *input, const T *alpha, T *output,
     std::vector<int> input_shape) {
-  size_t unroll = input_shape[0] * input_shape[1];
-  size_t spatial_size = input_shape[2] * input_shape[3];
-  CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
-  PReluElementWiseKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
-      input, alpha, output, spatial_size);
+  size_t plane_size = input_shape[2] * input_shape[3];
+  size_t spatial_size = input_shape[1] * plane_size;
+  size_t numel = input_shape[0] * spatial_size;
+  PReluElementWiseKernel<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0,
+                           stream>>>(input, alpha, output, spatial_size, numel);
 }
 
 template <typename T>
 void PreluScalarDirectCUDAFunctor<T>::operator()(cudaStream_t stream,
                                                  const T *input, const T *alpha,
                                                  T *output,
                                                  std::vector<int> input_shape) {
-  size_t unroll = input_shape[0] * input_shape[1];
-  size_t spatial_size = input_shape[2] * input_shape[3];
-  CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
-  PReluScalarKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
-      input, alpha, output, spatial_size);
+  size_t plane_size = input_shape[2] * input_shape[3];
+  size_t spatial_size = input_shape[1] * plane_size;
+  size_t numel = input_shape[0] * spatial_size;
+  PReluScalarKernel<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0, stream>>>(
+      input, alpha, output, numel);
 }
 
 template class PreluChannelWiseDirectCUDAFunctor<float>;
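
The rewritten kernels above drop the one-block-per-(batch, channel) launch and its CUDA_MAX_NUM_BLOCKS cap in favour of a grid-stride loop over every element, so each thread recovers its alpha index from the flat NCHW offset. A small sketch of that index arithmetic, assuming input_shape = {N, C, H, W}; the helper names are illustrative and not part of the commit:

#include <cstddef>

// channel mode: one alpha per channel.
inline size_t ChannelAlphaIndex(size_t index, size_t channel_num,
                                size_t plane_size /* H * W */) {
  return (index / plane_size) % channel_num;
}

// element mode: one alpha per (c, h, w) position, shared across the batch.
inline size_t ElementAlphaIndex(size_t index,
                                size_t spatial_size /* C * H * W */) {
  return index % spatial_size;
}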

paddle/fluid/operators/math/prelu.h

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@ class PreluScalarDirectCUDAFunctor {
   void operator()(cudaStream_t stream, const T *input, const T *alpha,
                   T *output, std::vector<int> input_shape);
 };
+
 #endif
 
 }  // namespace math

paddle/fluid/operators/prelu_op.cc

Lines changed: 15 additions & 4 deletions
@@ -42,10 +42,21 @@ class PReluOp : public framework::OperatorWithKernel {
                      "equal to the number of channels, should be %d",
                      x_dim[1]);
     } else if (mode == "element") {
-      PADDLE_ENFORCE(product(ctx->GetInputDim("Alpha")) == product(x_dim),
-                     "For element-wise mode, size of weight Alpha must be "
-                     "equal to the number of input, should be %d",
-                     product(x_dim));
+      auto alpha_dim = ctx->GetInputDim("Alpha");
+      auto alpha_rank = alpha_dim.size();
+      auto x_rank = x_dim.size();
+      size_t x_product = 1;
+      size_t alpha_product = 1;
+      PADDLE_ENFORCE_EQ(alpha_rank, x_rank,
+                        "For element-wise mode, rank of weight Alpha must be ",
+                        "equal to the rank of input.");
+      for (int64_t i = x_rank - 1; i > 0; i--) {
+        x_product *= x_dim[i];
+        alpha_product *= alpha_dim[i];
+      }
+      PADDLE_ENFORCE_EQ(x_product, alpha_product,
+                        "For element-wise mode, size of weight Alpha must be "
+                        "equal to the number of input.");
     } else {
       PADDLE_THROW("Unkown mode %s", mode);
     }
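
Under the new element-mode check above, Alpha must have the same rank as the input and the same number of elements per sample (the product of all dimensions except dim 0), so the weight no longer has to match the batch size. A hedged sketch of the rule with illustrative names, not part of the commit:

#include <cstdint>
#include <vector>

// True when an element-mode Alpha is compatible with x: equal rank and an
// equal element count over dims 1..rank-1.
bool ElementAlphaMatches(const std::vector<int64_t>& x_dim,
                         const std::vector<int64_t>& alpha_dim) {
  if (alpha_dim.size() != x_dim.size()) return false;
  int64_t x_product = 1, alpha_product = 1;
  for (size_t i = 1; i < x_dim.size(); ++i) {
    x_product *= x_dim[i];
    alpha_product *= alpha_dim[i];
  }
  return x_product == alpha_product;
}
// e.g. x_dim = {8, 3, 32, 32} with alpha_dim = {1, 3, 32, 32} passes;
// a rank-3 alpha_dim = {3, 32, 32} is rejected.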

paddle/fluid/operators/prelu_op.cu

Lines changed: 55 additions & 72 deletions
@@ -20,11 +20,19 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-static const int CUDA_NUM_THREADS = 1024;
-static const int CUDA_MAX_NUM_BLOCKS = 65535;
-
 using Tensor = framework::Tensor;
 
+#define CUDA_NUM_THREADS 1024
+
+// CUDA: grid stride looping
+#define CUDA_KERNEL_LOOP(i, n)                                 \
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
+       i += blockDim.x * gridDim.x)
+
+inline static int PADDLE_GET_BLOCKS(const int N) {
+  return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
+}
+
 template <typename DeviceContext, typename T>
 class CUDAPReluKernel : public framework::OpKernel<T> {
  public:
@@ -59,71 +67,47 @@ class CUDAPReluKernel : public framework::OpKernel<T> {
   }
 };
 
-namespace prelu {
-struct ElementWiseMode {};
-struct ChannelMode {};
-struct ScalarMode {};
-} /* namespace prelu */
-
-template <typename T, typename M>
-struct AlphaFunctor {
-  HOSTDEVICE inline T operator()(const T* alpha, size_t channel,
-                                 size_t spatial_size, size_t idx) const {}
-};
-
-template <typename T>
-struct AlphaFunctor<T, prelu::ElementWiseMode> {
-  HOSTDEVICE inline T operator()(const T* alpha, size_t channel,
-                                 size_t spatial_size, size_t idx) const {
-    return alpha[blockIdx.x * spatial_size + idx];
-  }
-};
+enum PRELU_MODE { Element, Channel, Scalar };
 
 template <typename T>
-struct AlphaFunctor<T, prelu::ChannelMode> {
-  HOSTDEVICE inline T operator()(const T* alpha, size_t channel,
-                                 size_t spatial_size, size_t idx) const {
-    return alpha[blockIdx.x % channel];
-  }
-};
-
-template <typename T>
-struct AlphaFunctor<T, prelu::ScalarMode> {
-  HOSTDEVICE inline T operator()(const T* alpha, size_t channel,
-                                 size_t spatial_size, size_t idx) const {
-    return alpha[0];
-  }
-};
-
-template <typename T, typename M>
-__global__ void PReluGradElementWiseKernel(const T* x_ptr, const T* y_ptr,
-                                           const T* alpha_ptr, const T* dy_ptr,
-                                           T* dx_ptr, T* dalpha_ptr,
-                                           size_t channel,
-                                           size_t spatial_size) {
-  size_t offset = blockIdx.x * spatial_size;
-  AlphaFunctor<T, M> alpha_func;
-
-  for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) {
-    T y = y_ptr[offset + i];
-    T x = x_ptr[offset + i];
-    T dy = dy_ptr[offset + i];
-    T alpha = alpha_func(alpha_ptr, channel, spatial_size, i);
-    if (dx_ptr != nullptr) dx_ptr[offset + i] = (y > 0) ? dy : alpha * dy;
-    if (dalpha_ptr != nullptr) dalpha_ptr[offset + i] = (x > 0) ? 0 : x * dy;
+__global__ void PReluOpGradKernel(const T* x_ptr, const T* y_ptr,
+                                  const T* alpha_ptr, const T* dy_ptr,
+                                  T* dx_ptr, T* dalpha_ptr, size_t channel_num,
+                                  size_t plane_size, size_t spatial_size,
+                                  size_t numel, PRELU_MODE mode) {
+  size_t index;
+  CUDA_KERNEL_LOOP(index, numel) {
+    T scale;
+    if (mode == Element) {
+      size_t element_index = index % spatial_size;
+      scale = alpha_ptr[element_index];
+    } else if (mode == Channel) {
+      size_t temp = index / plane_size;
+      size_t channel_index = temp % channel_num;
+      scale = alpha_ptr[channel_index];
+    } else {
+      scale = alpha_ptr[0];
+    }
+    T x = x_ptr[index];
+    T dy = dy_ptr[index];
+    if (dx_ptr != nullptr) dx_ptr[index] = (x > 0) ? dy : scale * dy;
+    if (dalpha_ptr != nullptr) dalpha_ptr[index] = (x > 0) ? 0 : x * dy;
   }
 }
 
-template <typename T, typename M>
-class PreluGradElementwiseFunctor {
+template <typename T>
+class PreluOpGradFunctor {
  public:
   void operator()(cudaStream_t stream, const T* x, const T* y, const T* alpha,
-                  const T* dy, T* dx, T* dalpha, std::vector<int> input_shape) {
-    size_t unroll = input_shape[0] * input_shape[1];
-    size_t spatial_size = input_shape[2] * input_shape[3];
-    CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
-    PReluGradElementWiseKernel<T, M><<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
-        x, y, alpha, dy, dx, dalpha, input_shape[1], spatial_size);
+                  const T* dy, T* dx, T* dalpha, std::vector<int> input_shape,
+                  PRELU_MODE mode) {
+    size_t plane_size = input_shape[2] * input_shape[3];
+    size_t spatial_size = plane_size * input_shape[1];
+    size_t numel = spatial_size * input_shape[0];
+    PReluOpGradKernel<
+        T><<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0, stream>>>(
+        x, y, alpha, dy, dx, dalpha, input_shape[1], plane_size, spatial_size,
+        numel, mode);
   }
 };
 
@@ -162,33 +146,32 @@ class CUDAPReluGradKernel : public framework::OpKernel<T> {
 
     T* dalpha_tmp_ptr;
     Tensor dalpha_tmp;
-    if (mode == "element" || dalpha_ptr == nullptr) {
+    if (dalpha_ptr == nullptr) {
      dalpha_tmp_ptr = dalpha_ptr;
     } else {
       auto& dev_ctx = context.template device_context<DeviceContext>();
       dalpha_tmp = context.AllocateTmpTensor<T, DeviceContext>(dim, dev_ctx);
       dalpha_tmp_ptr = dalpha_tmp.mutable_data<T>(context.GetPlace());
     }
 
+    PRELU_MODE m;
     if (mode == "element") {
-      PreluGradElementwiseFunctor<T, prelu::ElementWiseMode> prelu_grad;
-      prelu_grad(stream, x_ptr, y_ptr, alpha_ptr, dy_ptr, dx_ptr,
-                 dalpha_tmp_ptr, input_shape);
+      m = Element;
    } else if (mode == "channel") {
-      PreluGradElementwiseFunctor<T, prelu::ChannelMode> prelu_grad;
-      prelu_grad(stream, x_ptr, y_ptr, alpha_ptr, dy_ptr, dx_ptr,
-                 dalpha_tmp_ptr, input_shape);
+      m = Channel;
    } else {
-      PreluGradElementwiseFunctor<T, prelu::ScalarMode> prelu_grad;
-      prelu_grad(stream, x_ptr, y_ptr, alpha_ptr, dy_ptr, dx_ptr,
-                 dalpha_tmp_ptr, input_shape);
+      m = Scalar;
     }
+    PreluOpGradFunctor<T> prelu_grad;
+    prelu_grad(stream, x_ptr, y_ptr, alpha_ptr, dy_ptr, dx_ptr, dalpha_tmp_ptr,
+               input_shape, m);
 
-    if (mode == "element" || dalpha_tmp_ptr == nullptr) return;
+    if (dalpha_tmp_ptr == nullptr) return;
 
     std::vector<int> reduce_dims;
     for (size_t i = 0; i < input_shape.size(); i++) {
       if (mode == "channel" && i == 1) continue;
+      if (mode == "element" && i != 0) continue;
       reduce_dims.push_back(i);
     }
 
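
For the alpha gradient, the grad kernel above writes a full-sized dalpha_tmp and the reduction that follows now depends on the mode: channel mode keeps only the channel axis, element mode sums only over the batch axis, and scalar mode reduces every axis. A sketch of that selection mirroring the loop in the diff, assuming a 4-D NCHW input; the function name is illustrative, not part of the commit:

#include <cstddef>
#include <string>
#include <vector>

std::vector<int> DalphaReduceDims(const std::string& mode, size_t rank) {
  std::vector<int> reduce_dims;
  for (size_t i = 0; i < rank; ++i) {
    if (mode == "channel" && i == 1) continue;  // keep the channel axis
    if (mode == "element" && i != 0) continue;  // sum only over the batch axis
    reduce_dims.push_back(static_cast<int>(i));
  }
  return reduce_dims;  // scalar mode: every axis is reduced
}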