@@ -18,6 +18,11 @@ void ResizeBilinearLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
1818 half_pixel_centers_ = this ->layer_param_ .resize_bilinear_param ().half_pixel_centers ();
1919 CHECK (!(align_corners_ && half_pixel_centers_)) <<
2020 " If half_pixel_centers is True, align_corners must be False." ;
21+ pytorch_half_pixel_ = this ->layer_param_ .resize_bilinear_param ().pytorch_half_pixel ();
22+ CHECK (!(align_corners_ && pytorch_half_pixel_)) <<
23+ " If pytorch_half_pixel_ is True, align_corners must be False." ;
24+ CHECK (!(half_pixel_centers_ && pytorch_half_pixel_)) <<
25+ " If pytorch_half_pixel_ is True, half_pixel_centers_ must be False." ;
2126}
2227
2328template <typename Dtype>
@@ -70,14 +75,21 @@ void ResizeBilinearLayer<Dtype>::compute_interpolation_weights(const int out_siz
7075 interpolation[out_size].upper = 0 ;
7176 for (int i = out_size - 1 ; i >= 0 ; --i) {
7277 float in;
73- if (!half_pixel_centers_ )
78+ if (align_corners_ )
7479 {
7580 in = static_cast <float >(i) * scale;
7681 }
7782 else // if (half_pixel_centers_)
7883 {
79- in = (static_cast <float >(i) + 0 .5f ) * scale - 0 .5f ;
80- // ref: https://github.com/tensorflow/tensorflow/blob/r1.15/tensorflow/core/kernels/image_resizer_state.h#L50
84+ if (half_pixel_centers_ || out_size > 1 ) {
85+ in = (static_cast <float >(i) + 0 .5f ) * scale - 0 .5f ;
86+ // ref: https://github.com/tensorflow/tensorflow/blob/r1.15/tensorflow/core/kernels/image_resizer_state.h#L50
87+ }else {
88+ // pytorch_half_pixel_ && out_size <= 1
89+ in = -0 .5f ;
90+ // ref: https://github.com/onnx/onnx/blob/master/docs/Operators.md#Resize
91+ // https://github.com/onnx/onnx/blob/master/onnx/backend/test/case/node/resize.py#L132
92+ }
8193 }
8294 const float in_f = std::floor (in);
8395 interpolation[i].lower =
0 commit comments