Skip to content

Commit 00c9717

Browse files
committed
refine depth_to_space layer with onnx implementation
1 parent e73a136 commit 00c9717

File tree

1 file changed

+38
-6
lines changed

1 file changed

+38
-6
lines changed

src/caffe/layers/depth_to_space_layer.cpp

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ void DepthToSpaceLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
1717
this->output_top_shape.push_back(bottom_shape[1] * this->block_size);
1818
this->output_top_shape.push_back(bottom_shape[2] * this->block_size);
1919
this->output_top_shape.push_back(bottom_shape[3] / (this->block_size*this->block_size));
20-
} else if(this->data_format == "NCHW"){
20+
} else if(this->data_format == "NCHW" || this->data_format == "CRD" || this->data_format == "DCR"){
21+
std::cout << "the second output shape" << std::endl;
2122
this->output_top_shape.push_back(bottom_shape[1] / (this->block_size*this->block_size));
2223
this->output_top_shape.push_back(bottom_shape[2] * this->block_size);
2324
this->output_top_shape.push_back(bottom_shape[3] * this->block_size);
@@ -67,7 +68,7 @@ void DepthToSpaceLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
6768
}
6869
}
6970
}
70-
} else {
71+
} else if (this->data_format == "NCHW" || this->data_format == "DCR") {
7172
const int batch_size = this->output_top_shape[0];
7273
const int output_depth = this->output_top_shape[1];
7374
const int output_height = this->output_top_shape[2];
@@ -91,15 +92,46 @@ void DepthToSpaceLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
9192
const int offset_d =
9293
(offset_h * this->block_size + offset_w) * output_depth;
9394
for (int d = 0; d < output_depth; ++d) {
94-
const int in_d = d + offset_d;
95-
const int out_index = ((b*output_depth + d)*output_height + h)*output_width + w;
96-
const int in_index = ((b*input_depth + in_d)*input_height + in_h)*input_width + in_w;
97-
top_data[out_index] = bottom_data[in_index];
95+
const int in_d = d + offset_d;
96+
const int out_index = ((b*output_depth + d)*output_height + h)*output_width + w;
97+
const int in_index = ((b*input_depth + in_d)*input_height + in_h)*input_width + in_w;
98+
top_data[out_index] = bottom_data[in_index];
9899
}
99100
}
100101
}
101102
}
102103
}
104+
else if (this->data_format == "CRD") {
105+
const int batch_size = this->output_top_shape[0];
106+
const int output_depth = this->output_top_shape[1];
107+
const int output_height = this->output_top_shape[2];
108+
const int output_width = this->output_top_shape[3];
109+
110+
vector<int> bottom_shape = bottom[0]->shape();
111+
const int input_depth = bottom_shape[1];
112+
const int input_height = bottom_shape[2];
113+
const int input_width = bottom_shape[3];
114+
115+
const Dtype* bottom_data = bottom[0]->cpu_data();
116+
Dtype* top_data = top[0]->mutable_cpu_data();
117+
118+
for (int b = 0; b < batch_size; ++b) {
119+
for (int h = 0; h < output_height; ++h) {
120+
const int in_h = h / this->block_size;
121+
const int offset_h = (h % this->block_size);
122+
for (int w = 0; w < output_width; ++w) {
123+
const int in_w = w / this->block_size;
124+
const int offset_w = (w % this->block_size);
125+
for (int d = 0; d < output_depth; ++d) {
126+
const int in_d = (d * this->block_size + offset_h) * this->block_size + offset_w;
127+
const int out_index = ((b*output_depth + d)*output_height + h)*output_width + w;
128+
const int in_index = ((b*input_depth + in_d)*input_height + in_h)*input_width + in_w;
129+
top_data[out_index] = bottom_data[in_index];
130+
}
131+
}
132+
}
133+
}
134+
}
103135
}
104136

105137

0 commit comments

Comments
 (0)