@@ -26,7 +26,8 @@ __global__ void KeNearestNeighborInterpFw(
26
26
const size_t num_channels, const float ratio_h, const float ratio_w) {
27
27
int nthreads = output_h * output_w;
28
28
int tid = blockIdx .x * blockDim .x + threadIdx .x ;
29
- if (tid < nthreads) {
29
+ int stride = blockDim .x * gridDim .x ;
30
+ for (; tid < nthreads; tid += stride) {
30
31
int out_id_h = tid / output_w;
31
32
int out_id_w = tid % output_w;
32
33
int in_img_size = input_w / num_channels;
@@ -52,7 +53,8 @@ __global__ void KeNearestNeighborInterpBw(
52
53
const size_t num_channels, const float ratio_h, const float ratio_w) {
53
54
int nthreads = output_h * output_w;
54
55
int tid = blockIdx .x * blockDim .x + threadIdx .x ;
55
- if (tid < nthreads) {
56
+ int stride = blockDim .x * gridDim .x ;
57
+ for (; tid < nthreads; tid += stride) {
56
58
int out_id_h = tid / output_w;
57
59
int out_id_w = tid % output_w;
58
60
int in_img_size = input_w / num_channels;
@@ -80,7 +82,8 @@ __global__ void KeBilinearInterpFw(
80
82
const size_t num_channels, const float ratio_h, const float ratio_w) {
81
83
int nthreads = output_h * output_w;
82
84
int tid = blockIdx .x * blockDim .x + threadIdx .x ;
83
- if (tid < nthreads) {
85
+ int stride = blockDim .x * gridDim .x ;
86
+ for (; tid < nthreads; tid += stride) {
84
87
int out_id_h = tid / output_w;
85
88
int out_id_w = tid % output_w;
86
89
int in_img_size = input_w / num_channels;
@@ -118,7 +121,8 @@ __global__ void KeBilinearInterpBw(
118
121
const size_t num_channels, const T ratio_h, const T ratio_w) {
119
122
int nthreads = output_h * output_w;
120
123
int tid = blockIdx .x * blockDim .x + threadIdx .x ;
121
- if (tid < nthreads) {
124
+ int stride = blockDim .x * gridDim .x ;
125
+ for (; tid < nthreads; tid += stride) {
122
126
int out_id_h = tid / output_w;
123
127
int out_id_w = tid % output_w;
124
128
int in_img_size = input_w / num_channels;
@@ -194,17 +198,18 @@ class InterpolateOpCUDAKernel : public framework::OpKernel<T> {
194
198
return ;
195
199
}
196
200
197
- int threadNum = n * out_chw;
198
- int blocks = (threadNum + 1024 - 1 ) / 1024 ;
201
+ int pixelNum = n * out_chw;
202
+ int grid_dim = (pixelNum + 512 - 1 ) / 512 ;
203
+ grid_dim = grid_dim > 8 ? 8 : grid_dim;
199
204
200
205
if (" nearest" == interp_method) {
201
206
KeNearestNeighborInterpFw<
202
- T><<<blocks, 1024 , 0 , ctx.cuda_device_context().stream()>>> (
207
+ T><<<grid_dim, 512 , 0 , ctx.cuda_device_context().stream()>>> (
203
208
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
204
209
out_chw, c, ratio_h, ratio_w);
205
210
} else if (" bilinear" == interp_method) {
206
211
KeBilinearInterpFw<
207
- T><<<blocks, 1024 , 0 , ctx.cuda_device_context().stream()>>> (
212
+ T><<<grid_dim, 512 , 0 , ctx.cuda_device_context().stream()>>> (
208
213
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
209
214
out_chw, c, ratio_h, ratio_w);
210
215
}
@@ -257,17 +262,18 @@ class InterpolateGradOpCUDAKernel : public framework::OpKernel<T> {
257
262
return ;
258
263
}
259
264
260
- int threadNum = n * out_chw;
261
- int blocks = (threadNum + 1024 - 1 ) / 1024 ;
265
+ int pixelNum = n * out_chw;
266
+ int grid_dim = (pixelNum + 512 - 1 ) / 512 ;
267
+ grid_dim = grid_dim > 8 ? 8 : grid_dim;
262
268
263
269
if (" nearest" == interp_method) {
264
270
KeNearestNeighborInterpBw<
265
- T><<<blocks, 1024 , 0 , ctx.cuda_device_context().stream()>>> (
271
+ T><<<grid_dim, 512 , 0 , ctx.cuda_device_context().stream()>>> (
266
272
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h,
267
273
out_w, n, out_chw, c, ratio_h, ratio_w);
268
274
} else if (" bilinear" == interp_method) {
269
275
KeBilinearInterpBw<
270
- T><<<blocks, 1024 , 0 , ctx.cuda_device_context().stream()>>> (
276
+ T><<<grid_dim, 512 , 0 , ctx.cuda_device_context().stream()>>> (
271
277
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h,
272
278
out_w, n, out_chw, c, ratio_h, ratio_w);
273
279
}
0 commit comments