Skip to content

Commit fef2faa

Browse files
committed
limit CUDA kernel parallel threads max number to 4096. test=develop
1 parent 34bfae2 commit fef2faa

File tree

2 files changed

+34
-19
lines changed

2 files changed

+34
-19
lines changed

paddle/fluid/operators/interpolate_op.cu

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ __global__ void KeNearestNeighborInterpFw(
2626
const size_t num_channels, const float ratio_h, const float ratio_w) {
2727
int nthreads = output_h * output_w;
2828
int tid = blockIdx.x * blockDim.x + threadIdx.x;
29-
if (tid < nthreads) {
29+
int stride = blockDim.x * gridDim.x;
30+
for (; tid < nthreads; tid += stride) {
3031
int out_id_h = tid / output_w;
3132
int out_id_w = tid % output_w;
3233
int in_img_size = input_w / num_channels;
@@ -52,7 +53,8 @@ __global__ void KeNearestNeighborInterpBw(
5253
const size_t num_channels, const float ratio_h, const float ratio_w) {
5354
int nthreads = output_h * output_w;
5455
int tid = blockIdx.x * blockDim.x + threadIdx.x;
55-
if (tid < nthreads) {
56+
int stride = blockDim.x * gridDim.x;
57+
for (; tid < nthreads; tid += stride) {
5658
int out_id_h = tid / output_w;
5759
int out_id_w = tid % output_w;
5860
int in_img_size = input_w / num_channels;
@@ -80,7 +82,8 @@ __global__ void KeBilinearInterpFw(
8082
const size_t num_channels, const float ratio_h, const float ratio_w) {
8183
int nthreads = output_h * output_w;
8284
int tid = blockIdx.x * blockDim.x + threadIdx.x;
83-
if (tid < nthreads) {
85+
int stride = blockDim.x * gridDim.x;
86+
for (; tid < nthreads; tid += stride) {
8487
int out_id_h = tid / output_w;
8588
int out_id_w = tid % output_w;
8689
int in_img_size = input_w / num_channels;
@@ -118,7 +121,8 @@ __global__ void KeBilinearInterpBw(
118121
const size_t num_channels, const T ratio_h, const T ratio_w) {
119122
int nthreads = output_h * output_w;
120123
int tid = blockIdx.x * blockDim.x + threadIdx.x;
121-
if (tid < nthreads) {
124+
int stride = blockDim.x * gridDim.x;
125+
for (; tid < nthreads; tid += stride) {
122126
int out_id_h = tid / output_w;
123127
int out_id_w = tid % output_w;
124128
int in_img_size = input_w / num_channels;
@@ -194,17 +198,18 @@ class InterpolateOpCUDAKernel : public framework::OpKernel<T> {
194198
return;
195199
}
196200

197-
int threadNum = n * out_chw;
198-
int blocks = (threadNum + 1024 - 1) / 1024;
201+
int pixelNum = n * out_chw;
202+
int grid_dim = (pixelNum + 512 - 1) / 512;
203+
grid_dim = grid_dim > 8 ? 8 : grid_dim;
199204

200205
if ("nearest" == interp_method) {
201206
KeNearestNeighborInterpFw<
202-
T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>(
207+
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
203208
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
204209
out_chw, c, ratio_h, ratio_w);
205210
} else if ("bilinear" == interp_method) {
206211
KeBilinearInterpFw<
207-
T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>(
212+
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
208213
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
209214
out_chw, c, ratio_h, ratio_w);
210215
}
@@ -257,17 +262,18 @@ class InterpolateGradOpCUDAKernel : public framework::OpKernel<T> {
257262
return;
258263
}
259264

260-
int threadNum = n * out_chw;
261-
int blocks = (threadNum + 1024 - 1) / 1024;
265+
int pixelNum = n * out_chw;
266+
int grid_dim = (pixelNum + 512 - 1) / 512;
267+
grid_dim = grid_dim > 8 ? 8 : grid_dim;
262268

263269
if ("nearest" == interp_method) {
264270
KeNearestNeighborInterpBw<
265-
T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>(
271+
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
266272
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h,
267273
out_w, n, out_chw, c, ratio_h, ratio_w);
268274
} else if ("bilinear" == interp_method) {
269275
KeBilinearInterpBw<
270-
T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>(
276+
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
271277
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h,
272278
out_w, n, out_chw, c, ratio_h, ratio_w);
273279
}

python/paddle/fluid/tests/unittests/test_interpolate_op.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -167,13 +167,13 @@ def init_test_case(self):
167167
self.out_size = np.array([65, 129]).astype("int32")
168168

169169

170-
# class TestBilinearInterpBigScale(TestInterpolateOp):
171-
# def init_test_case(self):
172-
# self.interp_method = 'bilinear'
173-
# self.input_shape = [32, 16, 128, 64]
174-
# self.out_h = 200
175-
# self.out_w = 100
176-
# self.out_size = np.array([201, 101]).astype('int32')
170+
class TestBilinearInterpBigScale(TestInterpolateOp):
171+
def init_test_case(self):
172+
self.interp_method = 'bilinear'
173+
self.input_shape = [4, 4, 64, 32]
174+
self.out_h = 100
175+
self.out_w = 50
176+
self.out_size = np.array([101, 51]).astype('int32')
177177

178178

179179
class TestInterpolateOpUint8(OpTest):
@@ -273,6 +273,15 @@ def init_test_case(self):
273273
self.out_size = np.array([65, 129]).astype("int32")
274274

275275

276+
class TestNearestNeighborInterpBigScale(TestInterpolateOp):
277+
def init_test_case(self):
278+
self.interp_method = 'nearest'
279+
self.input_shape = [4, 4, 64, 32]
280+
self.out_h = 100
281+
self.out_w = 50
282+
self.out_size = np.array([101, 51]).astype('int32')
283+
284+
276285
class TestNearestNeighborInterpCase1Uint8(TestInterpolateOpUint8):
277286
def init_test_case(self):
278287
self.interp_method = 'nearest'

0 commit comments

Comments
 (0)