Commit da3611b

remove Parameter layer of MatMul; support case: (A, B, ..., K) * (K, N) -> (A, B, ..., N)
1 parent 8abe829

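In short: the second operand of MatMul no longer has to be fed in by a separate Parameter layer. With a single bottom, the layer now owns the (K, N) operand as this->blobs_[0], and input a may carry extra leading axes. Worked example of the new case: a of shape (2, 3, 4) times a (4, 5) weight yields a (2, 3, 5) top; the leading axes fold into M = 2 × 3 = 6 with K = 4 and N = 5, so a single (6, 4) × (4, 5) GEMM covers the whole input.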
File tree (3 files changed, +52 -26 lines):

include/caffe/layers/matmul_layer.hpp
src/caffe/layers/matmul_layer.cpp
src/caffe/proto/caffe.proto

include/caffe/layers/matmul_layer.hpp (3 additions, 2 deletions)
@@ -10,7 +10,7 @@
 
 namespace caffe {
 /*
- * @brief Resize images to size using nearest neighbor interpolation. ////
+ * @brief MatMul. ////
  * Note: implementation of tf.linalg.matmul()
  * https://www.tensorflow.org/versions/r1.14/api_docs/python/tf/linalg/matmul
  */
@@ -24,7 +24,7 @@ template <typename Dtype> class MatMulLayer : public Layer<Dtype> {
                     const vector<Blob<Dtype> *> &top);
 
   virtual inline const char *type() const { return "MatMul"; }
-  virtual inline int ExactNumBottomBlobs() const { return 2; }
+  virtual inline int MinNumBottomBlobs() const { return 1; }
   virtual inline int ExactNumTopBlobs() const { return 1; }
 
 protected:
@@ -48,6 +48,7 @@ template <typename Dtype> class MatMulLayer : public Layer<Dtype> {
   int K;
   bool transpose_a;
   bool transpose_b;
+  vector<int> blob_shape_;
 };
 
 } // namespace caffe
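Note the relaxed blob count: ExactNumBottomBlobs() == 2 becomes MinNumBottomBlobs() == 1, so the second operand is optional. When it is absent, LayerSetUp (next file) builds this->blobs_[0] from the new blob_shape_ member instead. In a prototxt this would presumably be a one-bottom "MatMul" layer with, e.g., matmul_param { blob_shape: 4 blob_shape: 5 } for a (4, 5) operand; the field itself is added to caffe.proto at the end of this commit.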

src/caffe/layers/matmul_layer.cpp (48 additions, 24 deletions)
@@ -13,51 +13,75 @@ void MatMulLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype> *> &bottom,
   const MatMulParameter &matmul_param = this->layer_param_.matmul_param();
   transpose_a = matmul_param.transpose_a();
   transpose_b = matmul_param.transpose_b();
+  blob_shape_.clear();
+  std::copy(matmul_param.blob_shape().begin(), matmul_param.blob_shape().end(),
+            std::back_inserter(blob_shape_));
 
-  CHECK_EQ(bottom[0]->num_axes(), bottom[1]->num_axes())
-      << "input a and input b should have same dimension!!";
+  if (bottom.size() == 1 && this->blobs_.size() != 1 &&
+      blob_shape_.size() != 0) {
+    this->blobs_.resize(1);
+    this->blobs_[0].reset(new Blob<Dtype>(blob_shape_));
+    // initialize blobs_ value with 0.
+    caffe_set(this->blobs_[0]->count(), Dtype(0),
+              this->blobs_[0]->mutable_cpu_data());
+  }
+  Blob<Dtype> *inputs1 =
+      (bottom.size() > 1) ? bottom[1] : this->blobs_[0].get();
   num_axes = bottom[0]->num_axes();
-  M = transpose_a ? bottom[0]->shape(num_axes - 1)
-                  : bottom[0]->shape(num_axes - 2);
-  N = transpose_b ? bottom[1]->shape(num_axes - 2)
-                  : bottom[1]->shape(num_axes - 1);
-  K = transpose_a ? bottom[0]->shape(num_axes - 2)
-                  : bottom[0]->shape(num_axes - 1);
-  if (transpose_b) {
-    CHECK_EQ(K, bottom[1]->shape(num_axes - 1))
-        << "input a and input b have incompatible shapes! ";
+
+  CHECK_GE(bottom[0]->num_axes(), inputs1->num_axes())
+      << "input a and input b should have same dimension or dim(a) > dim(b)!!";
+
+  if (bottom[0]->num_axes() == inputs1->num_axes()) {
+    M = transpose_a ? bottom[0]->shape(num_axes - 1)
+                    : bottom[0]->shape(num_axes - 2);
+    N = transpose_b ? inputs1->shape(num_axes - 2)
+                    : inputs1->shape(num_axes - 1);
+    K = transpose_a ? bottom[0]->shape(num_axes - 2)
+                    : bottom[0]->shape(num_axes - 1);
+    if (transpose_b) {
+      CHECK_EQ(K, inputs1->shape(num_axes - 1))
+          << "input a and input b have incompatible shapes! ";
+    } else {
+      CHECK_EQ(K, inputs1->shape(num_axes - 2))
+          << "input a and input b have incompatible shapes! ";
+    }
+    for (int i = 0; i < num_axes - 2; i++) {
+      CHECK_EQ(bottom[0]->shape(i), inputs1->shape(i))
+          << "inputs should have same shape except in last two dimensions, but "
+             "in dimension "
+          << i << ", the two inputs have different shape!";
+    }
   } else {
-    CHECK_EQ(K, bottom[1]->shape(num_axes - 2))
+    int axes1 = bottom[0]->num_axes();
+    int axes2 = inputs1->num_axes();
+    K = bottom[0]->shape(axes1 - 1);
+    M = bottom[0]->count() / K;
+    N = inputs1->shape(axes2 - 1);
+    CHECK_GE(axes2, 2) << "If dim(a) > dim(b), dim(b) should be 2!!";
+    CHECK_EQ(K, inputs1->shape(axes2 - 2))
         << "input a and input b have incompatible shapes! ";
   }
-  for (int i = 0; i < num_axes - 2; i++) {
-    CHECK_EQ(bottom[0]->shape(i), bottom[1]->shape(i))
-        << "inputs should have same shape except in last two dimensions, but "
-           "in dimension "
-        << i << ", the two inputs have different shape!";
-  }
 }
 
 template <typename Dtype>
 void MatMulLayer<Dtype>::Reshape(const vector<Blob<Dtype> *> &bottom,
                                  const vector<Blob<Dtype> *> &top) {
-
   vector<int> top_shape = bottom[0]->shape();
-  top_shape[num_axes - 2] = M;
   top_shape[num_axes - 1] = N;
   top[0]->Reshape(top_shape);
 }
 
 template <typename Dtype>
 void MatMulLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype> *> &bottom,
                                      const vector<Blob<Dtype> *> &top) {
-
+  Blob<Dtype> *inputs1 =
+      (bottom.size() > 1) ? bottom[1] : this->blobs_[0].get();
   const Dtype *bottom_data0 = bottom[0]->cpu_data();
-  const Dtype *bottom_data1 = bottom[1]->cpu_data();
+  const Dtype *bottom_data1 = inputs1->cpu_data();
   Dtype *top_data = top[0]->mutable_cpu_data();
 
-  const int batch_size = bottom[0]->count(0, num_axes - 2);
-
+  const int batch_size = bottom[0]->count() / (M * K);
   for (int i = 0; i < batch_size; ++i) {
     int b_idx0 = i * M * K;
     int b_idx1 = i * K * N;

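The hunk ends just after the batch offsets, so the actual GEMM dispatch is not shown. What is visible still tells the story: batch_size is now bottom[0]->count() / (M * K) rather than count(0, num_axes - 2), and in the (A, B, ..., K) * (K, N) case M absorbs every leading axis, so batch_size == 1 and b_idx1 stays 0 for the shared operand. Below is a minimal sketch of how such a loop typically completes, using Caffe's caffe_cpu_gemm BLAS wrapper; the call is an assumption about the elided lines, not the commit's verbatim code:

```cpp
// Assumed continuation of Forward_cpu (sketch, not the commit's code):
// one GEMM per batch slice; in the broadcast case batch_size == 1.
for (int i = 0; i < batch_size; ++i) {
  int b_idx0 = i * M * K;  // offset into input a
  int b_idx1 = i * K * N;  // stays 0 when batch_size == 1 (shared (K, N) operand)
  // top = 1.0 * op(a) * op(b) + 0.0 * top, honoring the transpose flags
  caffe_cpu_gemm<Dtype>(transpose_a ? CblasTrans : CblasNoTrans,
                        transpose_b ? CblasTrans : CblasNoTrans,
                        M, N, K, Dtype(1), bottom_data0 + b_idx0,
                        bottom_data1 + b_idx1, Dtype(0),
                        top_data + i * M * N);
}
```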
src/caffe/proto/caffe.proto (1 addition, 0 deletions)
@@ -3515,6 +3515,7 @@ message LpNormalizationParameter {
 message MatMulParameter {
   optional bool transpose_a = 1[default = false];
   optional bool transpose_b = 2[default = false];
+  repeated uint32 blob_shape = 3;
 }
 
 message GatherV2Parameter {
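Being a repeated field, blob_shape contributes one dimension per occurrence (two entries, 4 and 5, declare a (4, 5) blob). Leaving it empty keeps the old two-bottom behavior, since LayerSetUp only allocates blobs_[0] when blob_shape_.size() != 0.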
