Skip to content

Commit e348d4c

Browse files
committed
imlement onnx style half pixel floor for resize
1 parent 4117ce0 commit e348d4c

File tree

3 files changed

+62
-26
lines changed

3 files changed

+62
-26
lines changed

include/caffe/layers/resize_nearest_neighbor_layer.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ class ResizeNearestNeighborLayer : public Layer<Dtype> {
4949
bool align_corners;
5050
string data_format;
5151
bool half_pixel_centers;
52+
bool half_pixel_onnx;
5253
};
5354

5455
} // namespace caffe

src/caffe/layers/resize_nearest_neighbor_layer.cpp

Lines changed: 60 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@ void ResizeNearestNeighborLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& b
1616
this->output_height = this->layer_param_.resize_nearest_neighbor_param().output_height();
1717
this->output_width = this->layer_param_.resize_nearest_neighbor_param().output_width();
1818
this->half_pixel_centers = this->layer_param_.resize_nearest_neighbor_param().half_pixel_centers();
19-
CHECK(!(this->align_corners && this->half_pixel_centers)) <<
20-
"If half_pixel_centers is True, align_corners must be False.";
19+
this->half_pixel_onnx = this->layer_param_.resize_nearest_neighbor_param().half_pixel_onnx();
20+
CHECK_LE((this->align_corners + this->half_pixel_centers + this->half_pixel_onnx), 1) <<
21+
"Maximum one Flag in align_corners, half_pixel_center or half_pixel_onnx could be True.";
2122
}
2223

2324
template <typename Dtype>
@@ -77,6 +78,7 @@ void ResizeNearestNeighborLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>&
7778

7879
const bool align_corners = this->align_corners;
7980
const bool half_pixel_centers = this->half_pixel_centers;
81+
const bool half_pixel_onnx = this->half_pixel_onnx;
8082

8183
const Dtype* bottom_data = bottom[0]->cpu_data();
8284
Dtype* top_data = top[0]->mutable_cpu_data();
@@ -97,36 +99,50 @@ void ResizeNearestNeighborLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>&
9799
for (int b = 0; b < batch_size; b++) {
98100
for (int h = 0; h < output_height; h++) {
99101
int in_h;
100-
if (!half_pixel_centers)
102+
if (half_pixel_onnx)
101103
{
102-
in_h = std::min((align_corners)
103-
? static_cast<int>(roundf(h * height_scale))
104-
: static_cast<int>(floorf(h * height_scale)),
105-
input_height - 1);
104+
in_h = std::max(std::min(static_cast<int>(
105+
floorf((static_cast<float>(h) + 0.5f) * height_scale - 0.5)),
106+
input_height - 1),
107+
0);
106108
}
107-
else //if (half_pixel_centers)
109+
else if (half_pixel_centers)
108110
{
109111
in_h = std::max(std::min(static_cast<int>(
110112
floorf((static_cast<float>(h) + 0.5f) * height_scale)),
111113
input_height - 1),
112114
0);
113115
}
116+
else
117+
{
118+
in_h = std::min((align_corners)
119+
? static_cast<int>(roundf(h * height_scale))
120+
: static_cast<int>(floorf(h * height_scale)),
121+
input_height - 1);
122+
}
114123
for (int w = 0; w < output_width; w++) {
115124
int in_w;
116-
if (!half_pixel_centers)
125+
if (half_pixel_onnx)
117126
{
118-
in_w = std::min((align_corners)
119-
? static_cast<int>(roundf(w * width_scale))
120-
: static_cast<int>(floorf(w * width_scale)),
121-
input_width - 1);
127+
in_w = std::max(std::min(static_cast<int>(
128+
floorf((static_cast<float>(w) + 0.5f) * width_scale - 0.5)),
129+
input_width - 1),
130+
0);
122131
}
123-
else //if (half_pixel_centers)
132+
else if (half_pixel_centers)
124133
{
125134
in_w = std::max(std::min(static_cast<int>(
126135
floorf((static_cast<float>(w) + 0.5f) * width_scale)),
127136
input_width - 1),
128137
0);
129138
}
139+
else
140+
{
141+
in_w = std::min((align_corners)
142+
? static_cast<int>(roundf(w * width_scale))
143+
: static_cast<int>(floorf(w * width_scale)),
144+
input_width - 1);
145+
}
130146
for(int c = 0; c < channels; c++) {
131147
const int input_index = ((b*input_height + in_h)*input_width + in_w)*channels + c;
132148
const int output_index = ((b*output_height + h)*output_width + w)*channels + c;
@@ -147,6 +163,7 @@ void ResizeNearestNeighborLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>&
147163

148164
const bool align_corners = this->align_corners;
149165
const bool half_pixel_centers = this->half_pixel_centers;
166+
const bool half_pixel_onnx = this->half_pixel_onnx;
150167

151168
const Dtype* bottom_data = bottom[0]->cpu_data();
152169
Dtype* top_data = top[0]->mutable_cpu_data();
@@ -156,41 +173,58 @@ void ResizeNearestNeighborLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>&
156173
const float width_scale =
157174
CalculateResizeScale(output_width, input_width, align_corners);
158175

176+
//LOG(INFO)<<height_scale<<" "<<width_scale<<std::endl;
177+
159178
//it implies NCHW data format
160179
for (int b = 0; b < batch_size; b++) {
161180
for(int c = 0; c < channels; c++) {
162181
for (int h = 0; h < output_height; h++) {
163182
int in_h;
164-
if (!half_pixel_centers)
183+
if (half_pixel_onnx)
165184
{
166-
in_h = std::min((align_corners)
167-
? static_cast<int>(roundf(h * height_scale))
168-
: static_cast<int>(floorf(h * height_scale)),
169-
input_height - 1);
185+
in_h = std::max(std::min(static_cast<int>(
186+
floorf((static_cast<float>(h) + 0.5f) * height_scale - 0.5)),
187+
input_height - 1),
188+
0);
170189
}
171-
else //if (half_pixel_centers)
190+
else if (half_pixel_centers)
172191
{
173192
in_h = std::max(std::min(static_cast<int>(
174193
floorf((static_cast<float>(h) + 0.5f) * height_scale)),
175194
input_height - 1),
176195
0);
177196
}
197+
else
198+
{
199+
in_h = std::min((align_corners)
200+
? static_cast<int>(roundf(h * height_scale))
201+
: static_cast<int>(floorf(h * height_scale)),
202+
input_height - 1);
203+
}
178204
for (int w = 0; w < output_width; w++) {
179205
int in_w;
180-
if (!half_pixel_centers)
206+
if (half_pixel_onnx)
181207
{
182-
in_w = std::min((align_corners)
183-
? static_cast<int>(roundf(w * width_scale))
184-
: static_cast<int>(floorf(w * width_scale)),
185-
input_width - 1);
208+
in_w = std::max(std::min(static_cast<int>(
209+
floorf((static_cast<float>(w) + 0.5f) * width_scale - 0.5)),
210+
input_width - 1),
211+
0);
186212
}
187-
else //if (half_pixel_centers)
213+
else if (half_pixel_centers)
188214
{
189215
in_w = std::max(std::min(static_cast<int>(
190216
floorf((static_cast<float>(w) + 0.5f) * width_scale)),
191217
input_width - 1),
192218
0);
193219
}
220+
else
221+
{
222+
in_w = std::min((align_corners)
223+
? static_cast<int>(roundf(w * width_scale))
224+
: static_cast<int>(floorf(w * width_scale)),
225+
input_width - 1);
226+
}
227+
//LOG(INFO)<<w<<" "<<in_w<<std::endl;
194228
const int input_index = ((b*channels + c)*input_height + in_h)*input_width + in_w;
195229
const int output_index = ((b*channels + c)*output_height + h)*output_width + w;
196230
top_data[output_index] = bottom_data[input_index];

src/caffe/proto/caffe.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3010,6 +3010,7 @@ message ResizeNearestNeighborParameter {
30103010
optional float scale_height = 5 [default = 1];
30113011
optional float scale_width = 6 [default = 1];
30123012
optional bool half_pixel_centers = 7 [default = false];
3013+
optional bool half_pixel_onnx = 8 [default = false]; //ONNX style half_pixel, only implement nearest_mode=floor case now
30133014
}
30143015

30153016
message GatherParameter{

0 commit comments

Comments
 (0)