@@ -24,7 +24,7 @@ using LoDTensor = framework::LoDTensor;
24
24
static constexpr int kROISize = 4 ;
25
25
26
26
template <class T >
27
- void pre_calc_for_bilinear_interpolate (
27
+ void PreCalcForBilinearInterpolate (
28
28
const platform::DeviceContext& ctx, const int height, const int width,
29
29
const int pooled_height, const int pooled_width, const int iy_upper,
30
30
const int ix_upper, T roi_ymin, T roi_xmin, T bin_size_h, T bin_size_w,
@@ -53,12 +53,8 @@ void pre_calc_for_bilinear_interpolate(
53
53
pre_calc_index += 1 ;
54
54
continue ;
55
55
}
56
- if (y <= 0 ) {
57
- y = 0 ;
58
- }
59
- if (x <= 0 ) {
60
- x = 0 ;
61
- }
56
+ y = y <= 0 ? 0 : y;
57
+ x = x <= 0 ? 0 : x;
62
58
63
59
int y_low = static_cast <int >(y);
64
60
int x_low = static_cast <int >(x);
@@ -104,12 +100,8 @@ void bilinear_interpolate_gradient(const int height, const int width, T y, T x,
104
100
x_low = x_high = y_low = y_high = -1 ;
105
101
return ;
106
102
}
107
- if (y <= 0 ) {
108
- y = 0 ;
109
- }
110
- if (x <= 0 ) {
111
- x = 0 ;
112
- }
103
+ y = y <= 0 ? 0 : y;
104
+ x = x <= 0 ? 0 : x;
113
105
y_low = static_cast <int >(y);
114
106
x_low = static_cast <int >(x);
115
107
if (y_low >= height - 1 ) {
@@ -139,7 +131,6 @@ void bilinear_interpolate_gradient(const int height, const int width, T y, T x,
139
131
*(batch_grad_data + y_high * width + x_low) += diff3;
140
132
*(batch_grad_data + y_high * width + x_high) += diff4;
141
133
}
142
- return ;
143
134
}
144
135
145
136
template <typename DeviceContext, typename T>
@@ -214,7 +205,7 @@ class CPUROIAlignOpKernel : public framework::OpKernel<T> {
214
205
pre_pos.Resize ({pre_size, kROISize });
215
206
pre_w.Resize ({pre_size, kROISize });
216
207
217
- pre_calc_for_bilinear_interpolate (
208
+ PreCalcForBilinearInterpolate (
218
209
dev_ctx, height, width, pooled_height, pooled_width, roi_bin_grid_h,
219
210
roi_bin_grid_w, roi_ymin, roi_xmin, bin_size_h, bin_size_w,
220
211
roi_bin_grid_h, roi_bin_grid_w, &pre_pos, &pre_w);
@@ -245,7 +236,6 @@ class CPUROIAlignOpKernel : public framework::OpKernel<T> {
245
236
}
246
237
rois_data += roi_stride[0 ];
247
238
}
248
- return ;
249
239
}
250
240
};
251
241
@@ -264,79 +254,78 @@ class CPUROIAlignGradOpKernel : public framework::OpKernel<T> {
264
254
auto spatial_scale = ctx.Attr <float >(" spatial_scale" );
265
255
auto sampling_ratio = ctx.Attr <int >(" sampling_ratio" );
266
256
auto in_dims = in->dims ();
267
- if (in_grad) {
268
- int channels = in_dims[1 ];
269
- int height = in_dims[2 ];
270
- int width = in_dims[3 ];
271
- int rois_num = rois->dims ()[0 ];
272
- Tensor roi_batch_id_list;
273
- roi_batch_id_list.Resize ({rois_num});
274
- int * roi_batch_id_data =
275
- roi_batch_id_list.mutable_data <int >(ctx.GetPlace ());
257
+ if (!in_grad) {
258
+ return ;
259
+ }
260
+ int channels = in_dims[1 ];
261
+ int height = in_dims[2 ];
262
+ int width = in_dims[3 ];
263
+ int rois_num = rois->dims ()[0 ];
264
+ Tensor roi_batch_id_list;
265
+ roi_batch_id_list.Resize ({rois_num});
266
+ int * roi_batch_id_data =
267
+ roi_batch_id_list.mutable_data <int >(ctx.GetPlace ());
276
268
277
- auto rois_lod = rois->lod ().back ();
278
- int rois_batch_size = rois_lod.size () - 1 ;
279
- for (int n = 0 ; n < rois_batch_size; ++n) {
280
- for (size_t i = rois_lod[n]; i < rois_lod[n + 1 ]; ++i) {
281
- roi_batch_id_data[i] = n;
282
- }
269
+ auto rois_lod = rois->lod ().back ();
270
+ int rois_batch_size = rois_lod.size () - 1 ;
271
+ for (int n = 0 ; n < rois_batch_size; ++n) {
272
+ for (size_t i = rois_lod[n]; i < rois_lod[n + 1 ]; ++i) {
273
+ roi_batch_id_data[i] = n;
283
274
}
275
+ }
284
276
285
- const T* rois_data = rois->data <T>();
286
- const T* out_grad_data = out_grad->data <T>();
287
- T* in_grad_data = in_grad->mutable_data <T>(ctx.GetPlace ());
277
+ const T* rois_data = rois->data <T>();
278
+ const T* out_grad_data = out_grad->data <T>();
279
+ T* in_grad_data = in_grad->mutable_data <T>(ctx.GetPlace ());
288
280
289
- auto in_stride = framework::stride (in->dims ());
290
- auto roi_stride = framework::stride (rois->dims ());
291
- auto out_stride = framework::stride (out_grad->dims ());
281
+ auto in_stride = framework::stride (in->dims ());
282
+ auto roi_stride = framework::stride (rois->dims ());
283
+ auto out_stride = framework::stride (out_grad->dims ());
292
284
293
- for (int n = 0 ; n < rois_num; ++n) {
294
- int roi_batch_idx = roi_batch_id_data[n];
295
- T roi_xmin = rois_data[0 ] * spatial_scale;
296
- T roi_ymin = rois_data[1 ] * spatial_scale;
297
- T roi_xmax = rois_data[2 ] * spatial_scale;
298
- T roi_ymax = rois_data[3 ] * spatial_scale;
299
- T roi_width = std::max (roi_xmax - roi_xmin, static_cast <T>(1 .));
300
- T roi_height = std::max (roi_ymax - roi_ymin, static_cast <T>(1 .));
301
- T bin_size_h =
302
- static_cast <T>(roi_height) / static_cast <T>(pooled_height);
303
- T bin_size_w = static_cast <T>(roi_width) / static_cast <T>(pooled_width);
304
- for (int c = 0 ; c < channels; ++c) {
305
- T* batch_grad_data =
306
- in_grad_data + roi_batch_idx * in_stride[0 ] + c * in_stride[1 ];
307
- const T* batch_out_grad_data =
308
- out_grad_data + n * out_stride[0 ] + c * out_stride[1 ];
309
- for (int ph = 0 ; ph < pooled_height; ++ph) {
310
- for (int pw = 0 ; pw < pooled_width; ++pw) {
311
- int pool_index = ph * pooled_width + pw;
312
- T out_grad_this_bin = batch_out_grad_data[pool_index];
313
- int roi_bin_grid_h = (sampling_ratio > 0 )
314
- ? sampling_ratio
315
- : ceil (roi_height / pooled_height);
316
- int roi_bin_grid_w = (sampling_ratio > 0 )
317
- ? sampling_ratio
318
- : ceil (roi_width / pooled_width);
319
- T count = roi_bin_grid_h * roi_bin_grid_w;
320
- for (int iy = 0 ; iy < roi_bin_grid_h; iy++) {
321
- const T y = roi_ymin + ph * bin_size_h +
322
- static_cast <T>(iy + .5f ) * bin_size_h /
323
- static_cast <T>(roi_bin_grid_h);
324
- for (int ix = 0 ; ix < roi_bin_grid_w; ix++) {
325
- const T x = roi_xmin + pw * bin_size_w +
326
- static_cast <T>(ix + .5f ) * bin_size_w /
327
- static_cast <T>(roi_bin_grid_w);
328
- bilinear_interpolate_gradient (height, width, y, x,
329
- out_grad_this_bin, count,
330
- batch_grad_data);
331
- }
285
+ for (int n = 0 ; n < rois_num; ++n) {
286
+ int roi_batch_idx = roi_batch_id_data[n];
287
+ T roi_xmin = rois_data[0 ] * spatial_scale;
288
+ T roi_ymin = rois_data[1 ] * spatial_scale;
289
+ T roi_xmax = rois_data[2 ] * spatial_scale;
290
+ T roi_ymax = rois_data[3 ] * spatial_scale;
291
+ T roi_width = std::max (roi_xmax - roi_xmin, static_cast <T>(1 .));
292
+ T roi_height = std::max (roi_ymax - roi_ymin, static_cast <T>(1 .));
293
+ T bin_size_h = static_cast <T>(roi_height) / static_cast <T>(pooled_height);
294
+ T bin_size_w = static_cast <T>(roi_width) / static_cast <T>(pooled_width);
295
+ for (int c = 0 ; c < channels; ++c) {
296
+ T* batch_grad_data =
297
+ in_grad_data + roi_batch_idx * in_stride[0 ] + c * in_stride[1 ];
298
+ const T* batch_out_grad_data =
299
+ out_grad_data + n * out_stride[0 ] + c * out_stride[1 ];
300
+ for (int ph = 0 ; ph < pooled_height; ++ph) {
301
+ for (int pw = 0 ; pw < pooled_width; ++pw) {
302
+ int pool_index = ph * pooled_width + pw;
303
+ T out_grad_this_bin = batch_out_grad_data[pool_index];
304
+ int roi_bin_grid_h = (sampling_ratio > 0 )
305
+ ? sampling_ratio
306
+ : ceil (roi_height / pooled_height);
307
+ int roi_bin_grid_w = (sampling_ratio > 0 )
308
+ ? sampling_ratio
309
+ : ceil (roi_width / pooled_width);
310
+ T count = roi_bin_grid_h * roi_bin_grid_w;
311
+ for (int iy = 0 ; iy < roi_bin_grid_h; iy++) {
312
+ const T y = roi_ymin + ph * bin_size_h +
313
+ static_cast <T>(iy + .5f ) * bin_size_h /
314
+ static_cast <T>(roi_bin_grid_h);
315
+ for (int ix = 0 ; ix < roi_bin_grid_w; ix++) {
316
+ const T x = roi_xmin + pw * bin_size_w +
317
+ static_cast <T>(ix + .5f ) * bin_size_w /
318
+ static_cast <T>(roi_bin_grid_w);
319
+ bilinear_interpolate_gradient (height, width, y, x,
320
+ out_grad_this_bin, count,
321
+ batch_grad_data);
332
322
}
333
323
}
334
324
}
335
325
}
336
- rois_data += roi_stride[0 ];
337
326
}
327
+ rois_data += roi_stride[0 ];
338
328
}
339
- return ;
340
329
}
341
330
};
342
331
} // namespace operators
0 commit comments