Skip to content

Commit 67b8a30

Browse files
author
baiyf
authored
Merge pull request #10700 from baiyfbupt/develop
fix roi_pool op bug
2 parents ebc7303 + 1d7f91e commit 67b8a30

File tree

1 file changed

+23
-17
lines changed

1 file changed

+23
-17
lines changed

paddle/fluid/operators/roi_pool_op.cu

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,10 @@ __global__ void GPUROIPoolForward(
3838
int index = blockIdx.x * blockDim.x + threadIdx.x;
3939
int offset = blockDim.x * gridDim.x;
4040
for (size_t i = index; i < nthreads; i += offset) {
41-
int pw = index % pooled_width;
42-
int ph = (index / pooled_width) % pooled_height;
43-
int c = (index / pooled_width / pooled_height) % channels;
44-
int n = index / pooled_width / pooled_height / channels;
41+
int pw = i % pooled_width;
42+
int ph = (i / pooled_width) % pooled_height;
43+
int c = (i / pooled_width / pooled_height) % channels;
44+
int n = i / pooled_width / pooled_height / channels;
4545

4646
const int64_t* offset_input_rois = input_rois + n * kROISize;
4747
int roi_batch_ind = roi_batch_id_data[n];
@@ -52,14 +52,19 @@ __global__ void GPUROIPoolForward(
5252

5353
int roi_width = max(roi_end_w - roi_start_w + 1, 1);
5454
int roi_height = max(roi_end_h - roi_start_h + 1, 1);
55-
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
56-
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
57-
58-
int hstart = static_cast<int>(floor(static_cast<T>(ph) * bin_size_h));
59-
int wstart = static_cast<int>(floor(static_cast<T>(pw) * bin_size_w));
60-
int hend = static_cast<int>(ceil(static_cast<T>(ph + 1) * bin_size_h));
61-
int wend = static_cast<int>(ceil(static_cast<T>(pw + 1) * bin_size_w));
6255

56+
int hstart = static_cast<int>(floor(static_cast<double>(ph) *
57+
static_cast<double>(roi_height) /
58+
static_cast<double>(pooled_height)));
59+
int wstart = static_cast<int>(floor(static_cast<double>(pw) *
60+
static_cast<double>(roi_width) /
61+
static_cast<double>(pooled_width)));
62+
int hend = static_cast<int>(ceil(static_cast<double>(ph + 1) *
63+
static_cast<double>(roi_height) /
64+
static_cast<double>(pooled_height)));
65+
int wend = static_cast<int>(ceil(static_cast<double>(pw + 1) *
66+
static_cast<double>(roi_width) /
67+
static_cast<double>(pooled_width)));
6368
hstart = min(max(hstart + roi_start_h, 0), height);
6469
hend = min(max(hend + roi_start_h, 0), height);
6570
wstart = min(max(wstart + roi_start_w, 0), width);
@@ -79,9 +84,9 @@ __global__ void GPUROIPoolForward(
7984
}
8085
}
8186
}
82-
output_data[index] = maxval;
87+
output_data[i] = maxval;
8388
if (argmax_data) {
84-
argmax_data[index] = maxidx;
89+
argmax_data[i] = maxidx;
8590
}
8691
}
8792
}
@@ -96,10 +101,10 @@ __global__ void GPUROIPoolBackward(
96101
int index = blockIdx.x * blockDim.x + threadIdx.x;
97102
int offset = blockDim.x * gridDim.x;
98103
for (int i = index; i < nthreads; i += offset) {
99-
int pw = index % pooled_width;
100-
int ph = (index / pooled_width) % pooled_height;
101-
int c = (index / pooled_width / pooled_height) % channels;
102-
int n = index / pooled_width / pooled_height / channels;
104+
int pw = i % pooled_width;
105+
int ph = (i / pooled_width) % pooled_height;
106+
int c = (i / pooled_width / pooled_height) % channels;
107+
int n = i / pooled_width / pooled_height / channels;
103108

104109
int roi_batch_ind = roi_batch_id_data[n];
105110
int input_offset = (roi_batch_ind * channels + c) * height * width;
@@ -138,6 +143,7 @@ class GPUROIPoolOpKernel : public framework::OpKernel<T> {
138143
int width = in_dims[3];
139144

140145
int rois_num = rois->dims()[0];
146+
141147
if (rois_num == 0) return;
142148

143149
int output_size = out->numel();

0 commit comments

Comments
 (0)