@@ -38,10 +38,10 @@ __global__ void GPUROIPoolForward(
38
38
int index = blockIdx .x * blockDim .x + threadIdx .x ;
39
39
int offset = blockDim .x * gridDim .x ;
40
40
for (size_t i = index; i < nthreads; i += offset) {
41
- int pw = index % pooled_width;
42
- int ph = (index / pooled_width) % pooled_height;
43
- int c = (index / pooled_width / pooled_height) % channels;
44
- int n = index / pooled_width / pooled_height / channels;
41
+ int pw = i % pooled_width;
42
+ int ph = (i / pooled_width) % pooled_height;
43
+ int c = (i / pooled_width / pooled_height) % channels;
44
+ int n = i / pooled_width / pooled_height / channels;
45
45
46
46
const int64_t * offset_input_rois = input_rois + n * kROISize ;
47
47
int roi_batch_ind = roi_batch_id_data[n];
@@ -52,14 +52,19 @@ __global__ void GPUROIPoolForward(
52
52
53
53
int roi_width = max (roi_end_w - roi_start_w + 1 , 1 );
54
54
int roi_height = max (roi_end_h - roi_start_h + 1 , 1 );
55
- T bin_size_h = static_cast <T>(roi_height) / static_cast <T>(pooled_height);
56
- T bin_size_w = static_cast <T>(roi_width) / static_cast <T>(pooled_width);
57
-
58
- int hstart = static_cast <int >(floor (static_cast <T>(ph) * bin_size_h));
59
- int wstart = static_cast <int >(floor (static_cast <T>(pw) * bin_size_w));
60
- int hend = static_cast <int >(ceil (static_cast <T>(ph + 1 ) * bin_size_h));
61
- int wend = static_cast <int >(ceil (static_cast <T>(pw + 1 ) * bin_size_w));
62
55
56
+ int hstart = static_cast <int >(floor (static_cast <double >(ph) *
57
+ static_cast <double >(roi_height) /
58
+ static_cast <double >(pooled_height)));
59
+ int wstart = static_cast <int >(floor (static_cast <double >(pw) *
60
+ static_cast <double >(roi_width) /
61
+ static_cast <double >(pooled_width)));
62
+ int hend = static_cast <int >(ceil (static_cast <double >(ph + 1 ) *
63
+ static_cast <double >(roi_height) /
64
+ static_cast <double >(pooled_height)));
65
+ int wend = static_cast <int >(ceil (static_cast <double >(pw + 1 ) *
66
+ static_cast <double >(roi_width) /
67
+ static_cast <double >(pooled_width)));
63
68
hstart = min (max (hstart + roi_start_h, 0 ), height);
64
69
hend = min (max (hend + roi_start_h, 0 ), height);
65
70
wstart = min (max (wstart + roi_start_w, 0 ), width);
@@ -79,9 +84,9 @@ __global__ void GPUROIPoolForward(
79
84
}
80
85
}
81
86
}
82
- output_data[index ] = maxval;
87
+ output_data[i ] = maxval;
83
88
if (argmax_data) {
84
- argmax_data[index ] = maxidx;
89
+ argmax_data[i ] = maxidx;
85
90
}
86
91
}
87
92
}
@@ -96,10 +101,10 @@ __global__ void GPUROIPoolBackward(
96
101
int index = blockIdx .x * blockDim .x + threadIdx .x ;
97
102
int offset = blockDim .x * gridDim .x ;
98
103
for (int i = index; i < nthreads; i += offset) {
99
- int pw = index % pooled_width;
100
- int ph = (index / pooled_width) % pooled_height;
101
- int c = (index / pooled_width / pooled_height) % channels;
102
- int n = index / pooled_width / pooled_height / channels;
104
+ int pw = i % pooled_width;
105
+ int ph = (i / pooled_width) % pooled_height;
106
+ int c = (i / pooled_width / pooled_height) % channels;
107
+ int n = i / pooled_width / pooled_height / channels;
103
108
104
109
int roi_batch_ind = roi_batch_id_data[n];
105
110
int input_offset = (roi_batch_ind * channels + c) * height * width;
@@ -138,6 +143,7 @@ class GPUROIPoolOpKernel : public framework::OpKernel<T> {
138
143
int width = in_dims[3 ];
139
144
140
145
int rois_num = rois->dims ()[0 ];
146
+
141
147
if (rois_num == 0 ) return ;
142
148
143
149
int output_size = out->numel ();
0 commit comments