Skip to content

Commit 6519a9e

Browse files
committed
[evconvert] support NCHW data format and optimize NHWC data format
1 parent c4d5600 commit 6519a9e

File tree

1 file changed

+37
-6
lines changed

1 file changed

+37
-6
lines changed

src/caffe/layers/depthtospace_layer.cpp

Lines changed: 37 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -18,10 +18,9 @@ void DepthToSpaceLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
1818
this->output_top_shape.push_back(bottom_shape[2] * this->block_size);
1919
this->output_top_shape.push_back(bottom_shape[3] / (this->block_size*this->block_size));
2020
} else if(this->data_format == "NCHW"){
21-
NOT_IMPLEMENTED;
22-
// this->output_top_shape.push_back(bottom_shape[1] / (this->block_size*this->block_size));
23-
// this->output_top_shape.push_back(bottom_shape[2] * this->block_size);
24-
// this->output_top_shape.push_back(bottom_shape[3] * this->block_size);
21+
this->output_top_shape.push_back(bottom_shape[1] / (this->block_size*this->block_size));
22+
this->output_top_shape.push_back(bottom_shape[2] * this->block_size);
23+
this->output_top_shape.push_back(bottom_shape[3] * this->block_size);
2524
}
2625
}
2726

@@ -61,8 +60,40 @@ void DepthToSpaceLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
6160
(offset_h * this->block_size + offset_w) * output_depth;
6261
for (int d = 0; d < output_depth; ++d) {
6362
const int in_d = d + offset_d;
64-
const int out_index = b*output_height*output_width*output_depth + h*output_width*output_depth + w*output_depth + d;
65-
const int in_index = b*input_height*input_width*input_depth + in_h*input_width*input_depth + in_w*input_depth + in_d;
63+
const int out_index = ((b*output_height + h)*output_width + w)*output_depth + d;
64+
const int in_index = ((b*input_height + in_h)*input_width + in_w)*input_depth + in_d;
65+
top_data[out_index] = bottom_data[in_index];
66+
}
67+
}
68+
}
69+
}
70+
} else {
71+
const int batch_size = this->output_top_shape[0];
72+
const int output_depth = this->output_top_shape[1];
73+
const int output_height = this->output_top_shape[2];
74+
const int output_width = this->output_top_shape[3];
75+
76+
vector<int> bottom_shape = bottom[0]->shape();
77+
const int input_depth = bottom_shape[1];
78+
const int input_height = bottom_shape[2];
79+
const int input_width = bottom_shape[3];
80+
81+
const Dtype* bottom_data = bottom[0]->cpu_data();
82+
Dtype* top_data = top[0]->mutable_cpu_data();
83+
84+
for (int b = 0; b < batch_size; ++b) {
85+
for (int h = 0; h < output_height; ++h) {
86+
const int in_h = h / this->block_size;
87+
const int offset_h = (h % this->block_size);
88+
for (int w = 0; w < output_width; ++w) {
89+
const int in_w = w / this->block_size;
90+
const int offset_w = (w % this->block_size);
91+
const int offset_d =
92+
(offset_h * this->block_size + offset_w) * output_depth;
93+
for (int d = 0; d < output_depth; ++d) {
94+
const int in_d = d + offset_d;
95+
const int out_index = ((b*output_depth + d)*output_height + h)*output_width + w;
96+
const int in_index = ((b*input_depth + in_d)*input_height + in_h)*input_width + in_w;
6697
top_data[out_index] = bottom_data[in_index];
6798
}
6899
}

0 commit comments

Comments
 (0)