@@ -17,7 +17,8 @@ void DepthToSpaceLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
1717 this ->output_top_shape .push_back (bottom_shape[1 ] * this ->block_size );
1818 this ->output_top_shape .push_back (bottom_shape[2 ] * this ->block_size );
1919 this ->output_top_shape .push_back (bottom_shape[3 ] / (this ->block_size *this ->block_size ));
20- } else if (this ->data_format == " NCHW" ){
20+ } else if (this ->data_format == " NCHW" || this ->data_format == " CRD" || this ->data_format == " DCR" ){
21+ std::cout << " the second output shape" << std::endl;
2122 this ->output_top_shape .push_back (bottom_shape[1 ] / (this ->block_size *this ->block_size ));
2223 this ->output_top_shape .push_back (bottom_shape[2 ] * this ->block_size );
2324 this ->output_top_shape .push_back (bottom_shape[3 ] * this ->block_size );
@@ -67,7 +68,7 @@ void DepthToSpaceLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
6768 }
6869 }
6970 }
70- } else {
71+ } else if ( this -> data_format == " NCHW " || this -> data_format == " DCR " ) {
7172 const int batch_size = this ->output_top_shape [0 ];
7273 const int output_depth = this ->output_top_shape [1 ];
7374 const int output_height = this ->output_top_shape [2 ];
@@ -91,15 +92,46 @@ void DepthToSpaceLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
9192 const int offset_d =
9293 (offset_h * this ->block_size + offset_w) * output_depth;
9394 for (int d = 0 ; d < output_depth; ++d) {
94- const int in_d = d + offset_d;
95- const int out_index = ((b*output_depth + d)*output_height + h)*output_width + w;
96- const int in_index = ((b*input_depth + in_d)*input_height + in_h)*input_width + in_w;
97- top_data[out_index] = bottom_data[in_index];
95+ const int in_d = d + offset_d;
96+ const int out_index = ((b*output_depth + d)*output_height + h)*output_width + w;
97+ const int in_index = ((b*input_depth + in_d)*input_height + in_h)*input_width + in_w;
98+ top_data[out_index] = bottom_data[in_index];
9899 }
99100 }
100101 }
101102 }
102103 }
104+ else if (this ->data_format == " CRD" ) {
105+ const int batch_size = this ->output_top_shape [0 ];
106+ const int output_depth = this ->output_top_shape [1 ];
107+ const int output_height = this ->output_top_shape [2 ];
108+ const int output_width = this ->output_top_shape [3 ];
109+
110+ vector<int > bottom_shape = bottom[0 ]->shape ();
111+ const int input_depth = bottom_shape[1 ];
112+ const int input_height = bottom_shape[2 ];
113+ const int input_width = bottom_shape[3 ];
114+
115+ const Dtype* bottom_data = bottom[0 ]->cpu_data ();
116+ Dtype* top_data = top[0 ]->mutable_cpu_data ();
117+
118+ for (int b = 0 ; b < batch_size; ++b) {
119+ for (int h = 0 ; h < output_height; ++h) {
120+ const int in_h = h / this ->block_size ;
121+ const int offset_h = (h % this ->block_size );
122+ for (int w = 0 ; w < output_width; ++w) {
123+ const int in_w = w / this ->block_size ;
124+ const int offset_w = (w % this ->block_size );
125+ for (int d = 0 ; d < output_depth; ++d) {
126+ const int in_d = (d * this ->block_size + offset_h) * this ->block_size + offset_w;
127+ const int out_index = ((b*output_depth + d)*output_height + h)*output_width + w;
128+ const int in_index = ((b*input_depth + in_d)*input_height + in_h)*input_width + in_w;
129+ top_data[out_index] = bottom_data[in_index];
130+ }
131+ }
132+ }
133+ }
134+ }
103135}
104136
105137
0 commit comments