Skip to content

Commit 0c45caf

Browse files
committed
Implementation of tf.batch_to_space_nd
1 parent c37534a commit 0c45caf

File tree

2 files changed

+169
-0
lines changed

2 files changed

+169
-0
lines changed
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#ifndef CAFFE_BATCHTOSPACEND_LAYER_HPP_
#define CAFFE_BATCHTOSPACEND_LAYER_HPP_

#include <vector>

#include "caffe/blob.hpp"
#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"

namespace caffe {

/**
 * @brief Moves data from the batch dimension back into spatial blocks and
 *        then crops the result — a CPU port of TensorFlow's
 *        tf.batch_to_space_nd. Takes exactly one bottom blob and produces
 *        exactly one top blob.
 */
template <typename Dtype>
class BatchToSpaceNDLayer : public Layer<Dtype> {
 public:
  explicit BatchToSpaceNDLayer(const LayerParameter& param)
      : Layer<Dtype>(param) {}
  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);
  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);

  virtual inline const char* type() const { return "BatchToSpaceND"; }
  virtual inline int ExactNumBottomBlobs() const { return 1; }
  virtual inline int ExactNumTopBlobs() const { return 1; }

 protected:
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);
  /// @brief Not implemented (non-differentiable function)
  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
    NOT_IMPLEMENTED;
  }

 private:
  // Decompose a flat row-major offset into one index per dimension of shape.
  inline vector<int> indices(int offset, const vector<int> & shape) const;
  // Fold per-dimension indices back into a flat row-major offset.
  inline int offset(const vector<int>& indices, const vector<int> & shape) const;

  vector<int> block_shape_;  // block size for each spatial dimension
  vector<int> crops_;        // two entries (begin, end) per spatial dimension
};

}  // namespace caffe

#endif  // CAFFE_BATCHTOSPACEND_LAYER_HPP_
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
#include <algorithm>
2+
#include <functional>
3+
#include <utility>
4+
#include <vector>
5+
#include <numeric>
6+
7+
#include "caffe/layers/batch_to_space_nd_layer.hpp"
8+
// implementation of https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd
9+
namespace caffe {
10+
using namespace std;
11+
12+
template <typename Dtype>
13+
void BatchToSpaceNDLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
14+
const vector<Blob<Dtype>*>& top) {
15+
const BatchToSpaceNDParameter& batch_to_space_nd_param = this->layer_param_.batch_to_space_nd_param();
16+
for(auto i : batch_to_space_nd_param.block_shape())
17+
block_shape_.push_back(i);
18+
for(auto i : batch_to_space_nd_param.crops())
19+
crops_.push_back(i);
20+
}
21+
22+
template <typename Dtype>
23+
void BatchToSpaceNDLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
24+
const vector<Blob<Dtype>*>& top) {
25+
auto shape = bottom[0]->shape();
26+
for(auto i = 0; i < block_shape_.size(); i++){
27+
shape[0] /= block_shape_[i];
28+
shape[i+1] *= block_shape_[i];
29+
shape[i+1] -= crops_[2*i] + crops_[2*i+1];
30+
}
31+
top[0]->Reshape(shape);
32+
}
33+
34+
template <typename Dtype>
35+
inline vector<int> BatchToSpaceNDLayer<Dtype>::indices(int offset, const vector<int> & shape) const {
36+
vector<int> indices(shape.size());
37+
int r = offset;
38+
for(int i = shape.size()-1; i>=0; i--){
39+
indices[i] = r % shape[i];
40+
r /= shape[i];
41+
}
42+
return indices;
43+
}
44+
45+
template <typename Dtype>
46+
inline int BatchToSpaceNDLayer<Dtype>::offset(const vector<int>& indices, const vector<int> & shape) const {
47+
int offset = 0;
48+
for (int i = 0; i < shape.size(); ++i) {
49+
offset *= shape[i];
50+
offset += indices[i];
51+
}
52+
return offset;
53+
}
54+
55+
template <typename Dtype>
56+
void BatchToSpaceNDLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
57+
const vector<Blob<Dtype>*>& top) {
58+
const Dtype* bottom_data = bottom[0]->cpu_data();
59+
Dtype* top_data = top[0]->mutable_cpu_data();
60+
vector<int> bottom_shape = bottom[0]->shape();
61+
// 1. Reshape input to reshaped of shape:
62+
// [block_shape[0], ..., block_shape[M-1], batch / prod(block_shape), input_shape[1], ..., input_shape[N-1]]
63+
// Permute dimensions of reshaped_padded to produce permuted_reshaped_padded of shape:
64+
// block_shape + [batch] + [padded_shape[1] / block_shape[0], ..., padded_shape[M] / block_shape[M-1]] + remaining_shape
65+
vector<Dtype> bottom_temp(bottom[0]->count());
66+
vector<int> bottom_temp_shape = bottom_shape;
67+
68+
bottom_temp_shape.insert(bottom_temp_shape.begin(), block_shape_.begin(), block_shape_.end());
69+
for(auto i : block_shape_)
70+
bottom_temp_shape[block_shape_.size()] /= i;
71+
// 2. Permute dimensions of reshaped to produce permuted of shape [batch / prod(block_shape),
72+
// input_shape[1], block_shape[0], ..., input_shape[M], block_shape[M-1],
73+
// input_shape[M+1], ..., input_shape[N-1]]
74+
vector<int> permuted_shape = bottom_temp_shape;
75+
vector<int> permuted_order(bottom_temp_shape.size());
76+
iota(permuted_order.begin(), permuted_order.end(), 0);
77+
for(int i=0; i<block_shape_.size(); i++){
78+
permuted_shape[2*i+1] = bottom_temp_shape[i+block_shape_.size()+1];
79+
permuted_shape[2*i+2] = bottom_temp_shape[i];
80+
permuted_order[2*i+1] = i + block_shape_.size() + 1;
81+
permuted_order[2*i+2] = i;
82+
}
83+
permuted_order[0] = block_shape_.size();
84+
permuted_shape[0] = bottom_temp_shape[permuted_order[0]];
85+
86+
int strides = 1;
87+
for(int i=2*block_shape_.size()+1; i<bottom_temp_shape.size(); i++)
88+
strides *= bottom_temp_shape[i];
89+
90+
for(int position=0; position<bottom[0]->count()/strides; position++){
91+
vector<int> coord_bottom = indices(position*strides, bottom_temp_shape);
92+
vector<int> coord_permuted(coord_bottom);
93+
for(int i=0; i<bottom_temp_shape.size(); i++)
94+
coord_permuted[i] = coord_bottom[permuted_order[i]];
95+
int position_permuted = offset(coord_permuted, permuted_shape);
96+
copy_n(bottom_data+position*strides, strides, bottom_temp.begin()+position_permuted);
97+
}
98+
// 3. Reshape permuted to produce reshaped_permuted of shape [batch / prod(block_shape),
99+
// input_shape[1] * block_shape[0], ..., input_shape[M] * block_shape[M-1],
100+
// input_shape[M+1], ..., input_shape[N-1]]
101+
for(int i=0; i<block_shape_.size(); i++){
102+
permuted_shape[1+i] *= permuted_shape[2+i];
103+
permuted_shape.erase(permuted_shape.begin()+2+i, permuted_shape.begin()+3+i);
104+
}
105+
// input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1], ..., input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1],
106+
// input_shape[M+1], ..., input_shape[N-1]]
107+
for(int i=0; i<top[0]->count(); i++){
108+
vector<int> coord_top = indices(i, top[0]->shape());
109+
vector<int> coord_cropped = coord_top;
110+
for(int i=0; i<crops_.size()/2; i++){
111+
coord_cropped[i+1] += crops_[2*i];
112+
}
113+
int position_cropped = offset(coord_cropped, permuted_shape);
114+
top_data[i] = bottom_temp[position_cropped];
115+
// copy_n(bottom_temp.begin()+position_cropped, 1, top_data+i);
116+
}
117+
118+
}
119+
120+
// Instantiate the float and double versions of the layer and register it so
// it can be created by type name "BatchToSpaceND" from a prototxt.
INSTANTIATE_CLASS(BatchToSpaceNDLayer);
REGISTER_LAYER_CLASS(BatchToSpaceND);
122+
123+
} // namespace caffe

0 commit comments

Comments
 (0)