@@ -13,51 +13,75 @@ void MatMulLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype> *> &bottom,
1313 const MatMulParameter &matmul_param = this ->layer_param_ .matmul_param ();
1414 transpose_a = matmul_param.transpose_a ();
1515 transpose_b = matmul_param.transpose_b ();
16+ blob_shape_.clear ();
17+ std::copy (matmul_param.blob_shape ().begin (), matmul_param.blob_shape ().end (),
18+ std::back_inserter (blob_shape_));
1619
17- CHECK_EQ (bottom[0 ]->num_axes (), bottom[1 ]->num_axes ())
18- << " input a and input b should have same dimension!!" ;
20+ if (bottom.size () == 1 && this ->blobs_ .size () != 1 &&
21+ blob_shape_.size () != 0 ) {
22+ this ->blobs_ .resize (1 );
23+ this ->blobs_ [0 ].reset (new Blob<Dtype>(blob_shape_));
24+ // initialize blobs_ value with 0.
25+ caffe_set (this ->blobs_ [0 ]->count (), Dtype (0 ),
26+ this ->blobs_ [0 ]->mutable_cpu_data ());
27+ }
28+ Blob<Dtype> *inputs1 =
29+ (bottom.size () > 1 ) ? bottom[1 ] : this ->blobs_ [0 ].get ();
1930 num_axes = bottom[0 ]->num_axes ();
20- M = transpose_a ? bottom[0 ]->shape (num_axes - 1 )
21- : bottom[0 ]->shape (num_axes - 2 );
22- N = transpose_b ? bottom[1 ]->shape (num_axes - 2 )
23- : bottom[1 ]->shape (num_axes - 1 );
24- K = transpose_a ? bottom[0 ]->shape (num_axes - 2 )
25- : bottom[0 ]->shape (num_axes - 1 );
26- if (transpose_b) {
27- CHECK_EQ (K, bottom[1 ]->shape (num_axes - 1 ))
28- << " input a and input b have incompatible shapes! " ;
31+
32+ CHECK_GE (bottom[0 ]->num_axes (), inputs1->num_axes ())
33+ << " input a and input b should have same dimension or dim(a) > dim(b)!!" ;
34+
35+ if (bottom[0 ]->num_axes () == inputs1->num_axes ()) {
36+ M = transpose_a ? bottom[0 ]->shape (num_axes - 1 )
37+ : bottom[0 ]->shape (num_axes - 2 );
38+ N = transpose_b ? inputs1->shape (num_axes - 2 )
39+ : inputs1->shape (num_axes - 1 );
40+ K = transpose_a ? bottom[0 ]->shape (num_axes - 2 )
41+ : bottom[0 ]->shape (num_axes - 1 );
42+ if (transpose_b) {
43+ CHECK_EQ (K, inputs1->shape (num_axes - 1 ))
44+ << " input a and input b have incompatible shapes! " ;
45+ } else {
46+ CHECK_EQ (K, inputs1->shape (num_axes - 2 ))
47+ << " input a and input b have incompatible shapes! " ;
48+ }
49+ for (int i = 0 ; i < num_axes - 2 ; i++) {
50+ CHECK_EQ (bottom[0 ]->shape (i), inputs1->shape (i))
51+ << " inputs should have same shape except in last two dimensions, but "
52+ " in dimension "
53+ << i << " , the two inputs have different shape!" ;
54+ }
2955 } else {
30- CHECK_EQ (K, bottom[1 ]->shape (num_axes - 2 ))
56+ int axes1 = bottom[0 ]->num_axes ();
57+ int axes2 = inputs1->num_axes ();
58+ K = bottom[0 ]->shape (axes1 - 1 );
59+ M = bottom[0 ]->count () / K;
60+ N = inputs1->shape (axes2 - 1 );
61+ CHECK_GE (axes2, 2 ) << " If dim(a) > dim(b), dim(b) should be 2!!" ;
62+ CHECK_EQ (K, inputs1->shape (axes2 - 2 ))
3163 << " input a and input b have incompatible shapes! " ;
3264 }
33- for (int i = 0 ; i < num_axes - 2 ; i++) {
34- CHECK_EQ (bottom[0 ]->shape (i), bottom[1 ]->shape (i))
35- << " inputs should have same shape except in last two dimensions, but "
36- " in dimension "
37- << i << " , the two inputs have different shape!" ;
38- }
3965}
4066
4167template <typename Dtype>
4268void MatMulLayer<Dtype>::Reshape(const vector<Blob<Dtype> *> &bottom,
4369 const vector<Blob<Dtype> *> &top) {
44-
4570 vector<int > top_shape = bottom[0 ]->shape ();
46- top_shape[num_axes - 2 ] = M;
4771 top_shape[num_axes - 1 ] = N;
4872 top[0 ]->Reshape (top_shape);
4973}
5074
5175template <typename Dtype>
5276void MatMulLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype> *> &bottom,
5377 const vector<Blob<Dtype> *> &top) {
54-
78+ Blob<Dtype> *inputs1 =
79+ (bottom.size () > 1 ) ? bottom[1 ] : this ->blobs_ [0 ].get ();
5580 const Dtype *bottom_data0 = bottom[0 ]->cpu_data ();
56- const Dtype *bottom_data1 = bottom[ 1 ] ->cpu_data ();
81+ const Dtype *bottom_data1 = inputs1 ->cpu_data ();
5782 Dtype *top_data = top[0 ]->mutable_cpu_data ();
5883
59- const int batch_size = bottom[0 ]->count (0 , num_axes - 2 );
60-
84+ const int batch_size = bottom[0 ]->count () / (M * K);
6185 for (int i = 0 ; i < batch_size; ++i) {
6286 int b_idx0 = i * M * K;
6387 int b_idx1 = i * K * N;
0 commit comments