Skip to content

Commit 60524b7

Browse files
committed
add mode pytorch_half_pixel to ResizeBilinear
1 parent 2b90b20 commit 60524b7

File tree

3 files changed

+23
-3
lines changed

3 files changed

+23
-3
lines changed

include/caffe/layers/resize_bilinear_layer.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class ResizeBilinearLayer : public Layer<Dtype> {
5151
bool align_corners_;
5252
string data_format_;
5353
bool half_pixel_centers_;
54+
bool pytorch_half_pixel_;
5455

5556
// Compute the interpolation indices only once.
5657
struct CachedInterpolation {

src/caffe/layers/resize_bilinear_layer.cpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@ void ResizeBilinearLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
1818
half_pixel_centers_ = this->layer_param_.resize_bilinear_param().half_pixel_centers();
1919
CHECK(!(align_corners_ && half_pixel_centers_)) <<
2020
"If half_pixel_centers is True, align_corners must be False.";
21+
pytorch_half_pixel_ = this->layer_param_.resize_bilinear_param().pytorch_half_pixel();
22+
CHECK(!(align_corners_ && pytorch_half_pixel_)) <<
23+
"If pytorch_half_pixel_ is True, align_corners must be False.";
24+
CHECK(!(half_pixel_centers_ && pytorch_half_pixel_)) <<
25+
"If pytorch_half_pixel_ is True, half_pixel_centers_ must be False.";
2126
}
2227

2328
template <typename Dtype>
@@ -70,14 +75,21 @@ void ResizeBilinearLayer<Dtype>::compute_interpolation_weights(const int out_siz
7075
interpolation[out_size].upper = 0;
7176
for (int i = out_size - 1; i >= 0; --i) {
7277
float in;
73-
if (!half_pixel_centers_)
78+
if (align_corners_)
7479
{
7580
in = static_cast<float>(i) * scale;
7681
}
7782
else //if (half_pixel_centers_)
7883
{
79-
in = (static_cast<float>(i) + 0.5f) * scale - 0.5f;
80-
// ref: https://github.com/tensorflow/tensorflow/blob/r1.15/tensorflow/core/kernels/image_resizer_state.h#L50
84+
if (half_pixel_centers_ || out_size > 1) {
85+
in = (static_cast<float>(i) + 0.5f) * scale - 0.5f;
86+
// ref: https://github.com/tensorflow/tensorflow/blob/r1.15/tensorflow/core/kernels/image_resizer_state.h#L50
87+
}else {
88+
// pytorch_half_pixel_ && out_size <= 1
89+
in = -0.5f;
90+
// ref: https://github.com/onnx/onnx/blob/master/docs/Operators.md#Resize
91+
// https://github.com/onnx/onnx/blob/master/onnx/backend/test/case/node/resize.py#L132
92+
}
8193
}
8294
const float in_f = std::floor(in);
8395
interpolation[i].lower =

src/caffe/proto/caffe.proto

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -674,6 +674,7 @@ message LayerParameter {
674674
optional RNNv2Parameter rnn_v2_param = 272;
675675
optional CountNonzeroParameter count_nonzero_param = 273;
676676
optional NotEqualParameter not_equal_param = 274;
677+
optional PointNetParameter point_net_param = 275;
677678

678679
//ONNX related
679680
optional NonMaxSuppressionParameter non_max_suppression_param = 271;
@@ -3110,6 +3111,7 @@ message ResizeBilinearParameter {
31103111
optional float scale_height = 5 [default = 1];
31113112
optional float scale_width = 6 [default = 1];
31123113
optional bool half_pixel_centers = 7 [default = false];
3114+
optional bool pytorch_half_pixel = 8 [default = false];
31133115
}
31143116

31153117
message ReduceSumParameter {
@@ -3304,6 +3306,7 @@ message MatMulParameter {
33043306

33053307
message GatherV2Parameter {
33063308
optional int32 axis = 1[default = 0];
3309+
optional bool batch_flag = 2[default = false];
33073310
}
33083311

33093312
message ScaledTanHParameter {
@@ -3382,3 +3385,7 @@ message CountNonzeroParameter {
33823385
message NotEqualParameter {
33833386
optional float comparand = 1;
33843387
}
3388+
3389+
message PointNetParameter {
3390+
optional uint32 n_sample_point = 1;
3391+
}

0 commit comments

Comments
 (0)