Skip to content

Commit 9d3dd63

Browse files
committed
add edge case for onnx model conversion; refine format
1 parent 8757628 commit 9d3dd63

File tree

1 file changed

+33
-16
lines changed

1 file changed

+33
-16
lines changed

src/caffe/layers/mirror_pad_layer.cpp

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,25 +9,25 @@ namespace caffe {
99
using namespace std;
1010
template <typename Dtype>
1111
void MirrorPadLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype> *> &bottom,
12-
const vector<Blob<Dtype> *> &top) {
12+
const vector<Blob<Dtype> *> &top) {
1313
const MirrorPadParameter &mirror_pad_param =
1414
this->layer_param_.mirror_pad_param();
1515
constant_values_ = mirror_pad_param.constant_values();
1616
mode_ = mirror_pad_param.mode();
1717
paddings_.clear();
1818
std::copy(mirror_pad_param.paddings().begin(),
19-
mirror_pad_param.paddings().end(), std::back_inserter(paddings_));
19+
mirror_pad_param.paddings().end(), std::back_inserter(paddings_));
2020
int pad_dim = paddings_.size();
2121
CHECK_EQ(pad_dim % 2, 0)
22-
<< "Paddings for each dimension should have 2 values!";
22+
<< "Paddings for each dimension should have 2 values!";
2323
CHECK_EQ(pad_dim / 2, bottom[0]->num_axes())
24-
<< "Paddings' num should be 2 times of bottom dimension!";
24+
<< "Paddings' num should be 2 times of bottom dimension!";
2525
// CHECK_LE(bottom[0]->num_axes(), 4) << "Not support more than 4D paddings!";
2626
}
2727

2828
template <typename Dtype>
2929
void MirrorPadLayer<Dtype>::Reshape(const vector<Blob<Dtype> *> &bottom,
30-
const vector<Blob<Dtype> *> &top) {
30+
const vector<Blob<Dtype> *> &top) {
3131
int num_top_axes = bottom[0]->num_axes();
3232
std::vector<int> shape(num_top_axes, 1);
3333
shape = bottom[0]->shape();
@@ -51,7 +51,7 @@ MirrorPadLayer<Dtype>::indices(int offset, const vector<int> &shape) const {
5151

5252
template <typename Dtype>
5353
inline int MirrorPadLayer<Dtype>::offset(const vector<int> &indices,
54-
const vector<int> &shape) const {
54+
const vector<int> &shape) const {
5555
int offset = 0;
5656
for (int i = 0; i < shape.size(); ++i) {
5757
offset *= shape[i];
@@ -62,7 +62,7 @@ inline int MirrorPadLayer<Dtype>::offset(const vector<int> &indices,
6262

6363
template <typename Dtype>
6464
void MirrorPadLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype> *> &bottom,
65-
const vector<Blob<Dtype> *> &top) {
65+
const vector<Blob<Dtype> *> &top) {
6666
const Dtype *bottom_data = bottom[0]->cpu_data();
6767
Dtype *top_data = top[0]->mutable_cpu_data();
6868
auto bottom_shape = bottom[0]->shape();
@@ -87,17 +87,16 @@ void MirrorPadLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype> *> &bottom,
8787
for (int position = 0; position < top[0]->count() / strides; position++) {
8888
for (int j = 1; j <= paddings_[2 * i]; j++) {
8989
copy_n(top_data + position*strides + inner_strides * (paddings_[2 * i] + j),
90-
inner_strides,
91-
top_data + position*strides + inner_strides * (paddings_[2 * i] - j));
90+
inner_strides,
91+
top_data + position*strides + inner_strides * (paddings_[2 * i] - j));
9292
}
9393
for (int j = 1; j <= paddings_[2 * i + 1]; j++) {
9494
copy_n(top_data + position*strides + inner_strides * (bottom_shape[i] + paddings_[2 * i] - 1 - j),
95-
inner_strides,
96-
top_data + position*strides + inner_strides * (bottom_shape[i] + paddings_[2 * i] - 1 + j));
95+
inner_strides,
96+
top_data + position*strides + inner_strides * (bottom_shape[i] + paddings_[2 * i] - 1 + j));
9797
}
9898
}
9999
}
100-
101100
} else if (mode_ == "SYMMETRIC") {
102101
strides = 1;
103102
for (int i = top_shape.size() - 1; i >= 0; i--) {
@@ -106,13 +105,31 @@ void MirrorPadLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype> *> &bottom,
106105
for (int position = 0; position < top[0]->count() / strides; position++) {
107106
for (int j = 0; j < paddings_[2 * i]; j++) {
108107
copy_n(top_data + position*strides + inner_strides * (paddings_[2 * i] + j),
109-
inner_strides,
110-
top_data + position*strides + inner_strides * (paddings_[2 * i] - j - 1));
108+
inner_strides,
109+
top_data + position*strides + inner_strides * (paddings_[2 * i] - j - 1));
111110
}
112111
for (int j = 0; j < paddings_[2 * i + 1]; j++) {
113112
copy_n(top_data + position*strides + inner_strides * (bottom_shape[i] + paddings_[2 * i] - 1 - j),
114-
inner_strides,
115-
top_data + position*strides + inner_strides * (bottom_shape[i] + paddings_[2 * i] + j));
113+
inner_strides,
114+
top_data + position*strides + inner_strides * (bottom_shape[i] + paddings_[2 * i] + j));
115+
}
116+
}
117+
}
118+
} else if (mode_ == "EDGE") {
119+
strides = 1;
120+
for (int i = top_shape.size() - 1; i >= 0; i--) {
121+
int inner_strides = strides;
122+
strides *= top_shape[i];
123+
for (int position = 0; position < top[0]->count() / strides; position++) {
124+
for (int j = 0; j < paddings_[2 * i]; j++) {
125+
copy_n(top_data + position*strides + inner_strides * (paddings_[2 * i]),
126+
inner_strides,
127+
top_data + position*strides + inner_strides * (paddings_[2 * i] - j - 1));
128+
}
129+
for (int j = 0; j < paddings_[2 * i + 1]; j++) {
130+
copy_n(top_data + position*strides + inner_strides * (bottom_shape[i] + paddings_[2 * i] - 1),
131+
inner_strides,
132+
top_data + position*strides + inner_strides * (bottom_shape[i] + paddings_[2 * i] + j));
116133
}
117134
}
118135
}

0 commit comments

Comments
 (0)