Skip to content

Commit 0ae21a3

Browse files
committed
optimize and reformat unstack layer
1 parent c8e6052 commit 0ae21a3

File tree

2 files changed

+48
-76
lines changed

2 files changed

+48
-76
lines changed

include/caffe/layers/unstack_layer.hpp

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -8,40 +8,35 @@
88
#include "caffe/proto/caffe.pb.h"
99

1010
namespace caffe {
11-
template <typename Dtype>
12-
class UnstackLayer : public Layer<Dtype> {
13-
public:
14-
explicit UnstackLayer(const LayerParameter& param)
15-
: Layer<Dtype>(param) {}
16-
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
17-
const vector<Blob<Dtype>*>& top);
18-
virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
19-
const vector<Blob<Dtype>*>& top);
20-
21-
virtual inline const char* type() const { return "Unstack"; }
11+
template <typename Dtype> class UnstackLayer : public Layer<Dtype> {
12+
public:
13+
explicit UnstackLayer(const LayerParameter &param) : Layer<Dtype>(param) {}
14+
virtual void LayerSetUp(const vector<Blob<Dtype> *> &bottom,
15+
const vector<Blob<Dtype> *> &top);
16+
virtual void Reshape(const vector<Blob<Dtype> *> &bottom,
17+
const vector<Blob<Dtype> *> &top);
18+
19+
virtual inline const char *type() const { return "Unstack"; }
2220
virtual inline int ExactNumBottomBlobs() const { return 1; }
2321
virtual inline int MinTopBlobs() const { return 1; }
2422

25-
protected:
26-
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
27-
const vector<Blob<Dtype>*>& top);
28-
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
29-
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
30-
NOT_IMPLEMENTED;
31-
}
32-
//virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
23+
protected:
24+
virtual void Forward_cpu(const vector<Blob<Dtype> *> &bottom,
25+
const vector<Blob<Dtype> *> &top);
26+
virtual void Backward_cpu(const vector<Blob<Dtype> *> &top,
27+
const vector<bool> &propagate_down,
28+
const vector<Blob<Dtype> *> &bottom) {
29+
NOT_IMPLEMENTED;
30+
}
31+
// virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
3332
// const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
34-
//virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
33+
// virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
3534
// const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
3635

37-
int count_;
38-
int num_unstack_;
39-
int unstack_size_;
4036
int unstack_axis_;
4137
int unstack_num_;
4238
};
4339

44-
} // namespace caffe
45-
46-
#endif // CAFFE_UNSTACK_LAYER_HPP_
40+
} // namespace caffe
4741

42+
#endif // CAFFE_UNSTACK_LAYER_HPP_

src/caffe/layers/unstack_layer.cpp

Lines changed: 27 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -7,70 +7,47 @@
77
namespace caffe {
88

99
template <typename Dtype>
10-
void UnstackLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
11-
const vector<Blob<Dtype>*>& top) {
12-
const UnstackParameter& unstack_param = this->layer_param_.unstack_param();
10+
void UnstackLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype> *> &bottom,
11+
const vector<Blob<Dtype> *> &top) {
12+
const UnstackParameter &unstack_param = this->layer_param_.unstack_param();
1313
unstack_axis_ = bottom[0]->CanonicalAxisIndex(unstack_param.axis());
14-
const int num = unstack_param.num();
15-
if (num != 0) {
16-
CHECK_EQ(num, bottom[0]->shape(unstack_axis_))
17-
<< "num should equal to the shape in axis!";
18-
}
19-
unstack_num_ = bottom[0]->shape(unstack_axis_);
20-
CHECK_EQ(unstack_num_, top.size())<< "Number of top blobs ("
21-
<< top.size() << ") should euqal to "<< "shape in axis ("
22-
<< unstack_num_ << ")";
14+
unstack_num_ = unstack_param.num();
15+
if (unstack_num_ == 0)
16+
unstack_num_ = bottom[0]->shape(unstack_axis_);
2317
}
2418

2519
template <typename Dtype>
26-
void UnstackLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
27-
const vector<Blob<Dtype>*>& top) {
28-
const int num_axes = bottom[0]->num_axes();
29-
//const UnstackParameter& unstack_param = this->layer_param_.unstack_param();
30-
vector<int> bottom_shape = bottom[0]->shape();
20+
void UnstackLayer<Dtype>::Reshape(const vector<Blob<Dtype> *> &bottom,
21+
const vector<Blob<Dtype> *> &top) {
22+
// const UnstackParameter& unstack_param = this->layer_param_.unstack_param();
3123
vector<int> top_shape = bottom[0]->shape();
32-
top_shape.resize(num_axes - 1);
33-
num_unstack_ = bottom[0]->count(0, unstack_axis_);
34-
unstack_size_ = bottom[0]->count(unstack_axis_ + 1);
35-
int count = 0;
36-
for (int i = unstack_axis_; i < num_axes - 1; ++i) {
37-
top_shape[i] = bottom_shape[i + 1];
38-
}
39-
for (int i = 0; i < top.size(); ++i) {
40-
top[i]->Reshape(top_shape);
41-
count += top[i]->count();
42-
}
43-
CHECK_EQ(count, bottom[0]->count());
44-
if (top.size() == 1) {
45-
top[0]->ShareData(*bottom[0]);
46-
top[0]->ShareDiff(*bottom[0]);
47-
}
24+
top_shape.erase(top_shape.begin() + unstack_axis_);
25+
for (int i = 0; i < top.size(); ++i)
26+
top[i]->Reshape(top_shape);
4827
}
4928

5029
template <typename Dtype>
51-
void UnstackLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
52-
const vector<Blob<Dtype>*>& top) {
53-
if (top.size() == 1) { return; }
54-
int offset_unstack_axis = 0;
55-
const Dtype* bottom_data = bottom[0]->cpu_data();
56-
const int bottom_unstack_axis = bottom[0]->shape(unstack_axis_);
30+
void UnstackLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype> *> &bottom,
31+
const vector<Blob<Dtype> *> &top) {
32+
const Dtype *bottom_data = bottom[0]->cpu_data();
33+
const int size_unstack_axis = bottom[0]->shape(unstack_axis_);
34+
vector<int> bottom_shape = bottom[0]->shape();
35+
int strides = 1;
36+
for (int i = unstack_axis_ + 1; i < bottom_shape.size(); i++)
37+
strides *= bottom_shape[i];
38+
int num_unstack_ = bottom[0]->count() / strides / size_unstack_axis;
39+
// num_unstack_ /= size_unstack_axis;
5740
for (int i = 0; i < top.size(); ++i) {
58-
Dtype* top_data = top[i]->mutable_cpu_data();
41+
Dtype *top_data = top[i]->mutable_cpu_data();
5942
for (int n = 0; n < num_unstack_; ++n) {
60-
const int top_offset = n * unstack_size_;
61-
const int bottom_offset =
62-
(n * bottom_unstack_axis + offset_unstack_axis) * unstack_size_;
63-
caffe_copy(unstack_size_,
64-
bottom_data + bottom_offset, top_data + top_offset);
43+
const int top_offset = n * strides;
44+
const int bottom_offset = (n * size_unstack_axis + i) * strides;
45+
caffe_copy(strides, bottom_data + bottom_offset, top_data + top_offset);
6546
}
66-
offset_unstack_axis += 1;
6747
}
6848
}
6949

70-
71-
7250
INSTANTIATE_CLASS(UnstackLayer);
7351
REGISTER_LAYER_CLASS(Unstack);
7452

75-
} // namespace caffe
76-
53+
} // namespace caffe

0 commit comments

Comments
 (0)