@@ -16,8 +16,9 @@ void ResizeNearestNeighborLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& b
1616 this ->output_height = this ->layer_param_ .resize_nearest_neighbor_param ().output_height ();
1717 this ->output_width = this ->layer_param_ .resize_nearest_neighbor_param ().output_width ();
1818 this ->half_pixel_centers = this ->layer_param_ .resize_nearest_neighbor_param ().half_pixel_centers ();
19- CHECK (!(this ->align_corners && this ->half_pixel_centers )) <<
20- " If half_pixel_centers is True, align_corners must be False." ;
19+ this ->half_pixel_onnx = this ->layer_param_ .resize_nearest_neighbor_param ().half_pixel_onnx ();
20+ CHECK_LE ((this ->align_corners + this ->half_pixel_centers + this ->half_pixel_onnx ), 1 ) <<
21+ " Maximum one Flag in align_corners, half_pixel_center or half_pixel_onnx could be True." ;
2122}
2223
2324template <typename Dtype>
@@ -77,6 +78,7 @@ void ResizeNearestNeighborLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>&
7778
7879 const bool align_corners = this ->align_corners ;
7980 const bool half_pixel_centers = this ->half_pixel_centers ;
81+ const bool half_pixel_onnx = this ->half_pixel_onnx ;
8082
8183 const Dtype* bottom_data = bottom[0 ]->cpu_data ();
8284 Dtype* top_data = top[0 ]->mutable_cpu_data ();
@@ -97,36 +99,50 @@ void ResizeNearestNeighborLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>&
9799 for (int b = 0 ; b < batch_size; b++) {
98100 for (int h = 0 ; h < output_height; h++) {
99101 int in_h;
100- if (!half_pixel_centers )
102+ if (half_pixel_onnx )
101103 {
102- in_h = std::min ((align_corners)
103- ? static_cast <int >( roundf (h * height_scale))
104- : static_cast < int >( floorf (h * height_scale) ),
105- input_height - 1 );
104+ in_h = std::max ( std:: min (static_cast < int >(
105+ floorf (( static_cast <float >(h) + 0 . 5f ) * height_scale - 0.5 )),
106+ input_height - 1 ),
107+ 0 );
106108 }
107- else // if (half_pixel_centers)
109+ else if (half_pixel_centers)
108110 {
109111 in_h = std::max (std::min (static_cast <int >(
110112 floorf ((static_cast <float >(h) + 0 .5f ) * height_scale)),
111113 input_height - 1 ),
112114 0 );
113115 }
116+ else
117+ {
118+ in_h = std::min ((align_corners)
119+ ? static_cast <int >(roundf (h * height_scale))
120+ : static_cast <int >(floorf (h * height_scale)),
121+ input_height - 1 );
122+ }
114123 for (int w = 0 ; w < output_width; w++) {
115124 int in_w;
116- if (!half_pixel_centers )
125+ if (half_pixel_onnx )
117126 {
118- in_w = std::min ((align_corners)
119- ? static_cast <int >( roundf (w * width_scale))
120- : static_cast < int >( floorf (w * width_scale) ),
121- input_width - 1 );
127+ in_w = std::max ( std:: min (static_cast < int >(
128+ floorf (( static_cast <float >(w) + 0 . 5f ) * width_scale - 0.5 )),
129+ input_width - 1 ),
130+ 0 );
122131 }
123- else // if (half_pixel_centers)
132+ else if (half_pixel_centers)
124133 {
125134 in_w = std::max (std::min (static_cast <int >(
126135 floorf ((static_cast <float >(w) + 0 .5f ) * width_scale)),
127136 input_width - 1 ),
128137 0 );
129138 }
139+ else
140+ {
141+ in_w = std::min ((align_corners)
142+ ? static_cast <int >(roundf (w * width_scale))
143+ : static_cast <int >(floorf (w * width_scale)),
144+ input_width - 1 );
145+ }
130146 for (int c = 0 ; c < channels; c++) {
131147 const int input_index = ((b*input_height + in_h)*input_width + in_w)*channels + c;
132148 const int output_index = ((b*output_height + h)*output_width + w)*channels + c;
@@ -147,6 +163,7 @@ void ResizeNearestNeighborLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>&
147163
148164 const bool align_corners = this ->align_corners ;
149165 const bool half_pixel_centers = this ->half_pixel_centers ;
166+ const bool half_pixel_onnx = this ->half_pixel_onnx ;
150167
151168 const Dtype* bottom_data = bottom[0 ]->cpu_data ();
152169 Dtype* top_data = top[0 ]->mutable_cpu_data ();
@@ -156,41 +173,58 @@ void ResizeNearestNeighborLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>&
156173 const float width_scale =
157174 CalculateResizeScale (output_width, input_width, align_corners);
158175
176+ // LOG(INFO)<<height_scale<<" "<<width_scale<<std::endl;
177+
159178 // it implies NCHW data format
160179 for (int b = 0 ; b < batch_size; b++) {
161180 for (int c = 0 ; c < channels; c++) {
162181 for (int h = 0 ; h < output_height; h++) {
163182 int in_h;
164- if (!half_pixel_centers )
183+ if (half_pixel_onnx )
165184 {
166- in_h = std::min ((align_corners)
167- ? static_cast <int >( roundf (h * height_scale))
168- : static_cast < int >( floorf (h * height_scale) ),
169- input_height - 1 );
185+ in_h = std::max ( std:: min (static_cast < int >(
186+ floorf (( static_cast <float >(h) + 0 . 5f ) * height_scale - 0.5 )),
187+ input_height - 1 ),
188+ 0 );
170189 }
171- else // if (half_pixel_centers)
190+ else if (half_pixel_centers)
172191 {
173192 in_h = std::max (std::min (static_cast <int >(
174193 floorf ((static_cast <float >(h) + 0 .5f ) * height_scale)),
175194 input_height - 1 ),
176195 0 );
177196 }
197+ else
198+ {
199+ in_h = std::min ((align_corners)
200+ ? static_cast <int >(roundf (h * height_scale))
201+ : static_cast <int >(floorf (h * height_scale)),
202+ input_height - 1 );
203+ }
178204 for (int w = 0 ; w < output_width; w++) {
179205 int in_w;
180- if (!half_pixel_centers )
206+ if (half_pixel_onnx )
181207 {
182- in_w = std::min ((align_corners)
183- ? static_cast <int >( roundf (w * width_scale))
184- : static_cast < int >( floorf (w * width_scale) ),
185- input_width - 1 );
208+ in_w = std::max ( std:: min (static_cast < int >(
209+ floorf (( static_cast <float >(w) + 0 . 5f ) * width_scale - 0.5 )),
210+ input_width - 1 ),
211+ 0 );
186212 }
187- else // if (half_pixel_centers)
213+ else if (half_pixel_centers)
188214 {
189215 in_w = std::max (std::min (static_cast <int >(
190216 floorf ((static_cast <float >(w) + 0 .5f ) * width_scale)),
191217 input_width - 1 ),
192218 0 );
193219 }
220+ else
221+ {
222+ in_w = std::min ((align_corners)
223+ ? static_cast <int >(roundf (w * width_scale))
224+ : static_cast <int >(floorf (w * width_scale)),
225+ input_width - 1 );
226+ }
227+ // LOG(INFO)<<w<<" "<<in_w<<std::endl;
194228 const int input_index = ((b*channels + c)*input_height + in_h)*input_width + in_w;
195229 const int output_index = ((b*channels + c)*output_height + h)*output_width + w;
196230 top_data[output_index] = bottom_data[input_index];
0 commit comments