@@ -236,28 +236,11 @@ void ConvolutionLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
236236
237237 if (this ->submanifold_sparse_ )
238238 {
239- if (bottom[0 ]->num_axes ()==4 )
240- {
241- CHECK_EQ (bottom[0 ]->height (), top[0 ]->height ())<<
242- " Input and output blob height not equal! Submanifold sparse computation is invalid!" ;
243- CHECK_EQ (bottom[0 ]->width (), top[0 ]->width ())<<
244- " Input and output blob width not equal! Submanifold sparse computation is invalid!" ;
245- }
246- else if (bottom[0 ]->num_axes ()==5 )
247- {
248- CHECK_EQ (bottom[0 ]->shape (2 ), top[0 ]->shape (2 ))<<
249- " Input and output blob depth not equal! Submanifold sparse computation is invalid!" ;
250- CHECK_EQ (bottom[0 ]->shape (3 ), top[0 ]->shape (3 ))<<
251- " Input and output blob height not equal! Submanifold sparse computation is invalid!" ;
252- CHECK_EQ (bottom[0 ]->shape (4 ), top[0 ]->shape (4 ))<<
253- " Input and output blob width not equal! Submanifold sparse computation is invalid!" ;
254- }
255- else
256- {
257- CHECK_EQ (bottom[0 ]->num_axes (), 3 )<<" Not support Submanifold sparse computation for such blob dimension yet!" ;
258- CHECK_EQ (bottom[0 ]->shape (2 ), top[0 ]->shape (2 ))<<
259- " Input and output blob length not equal! Submanifold sparse computation is invalid!" ;
260- }
239+ CHECK_GE (bottom[0 ]->num_axes (), 3 )<<" Input blob dimension must >=3!" ;
240+ for (int i=2 ; i<bottom[0 ]->num_axes ();i++)
241+ CHECK_EQ (bottom[0 ]->shape (i), top[0 ]->shape (i))<<
242+ " Input and output blob shape does not match! Submanifold sparse computation is invalid!" ;
243+
261244 LOG (INFO)<<" Starts submanifold sparse computation." ;
262245
263246 for (int index=0 ; index<bottom[0 ]->count (2 ); index++)
0 commit comments