@@ -36,121 +36,39 @@ class MatMulOp : public framework::OperatorWithKernel {
36
36
37
37
auto dim_x = context->GetInputDim (" X" );
38
38
auto dim_y = context->GetInputDim (" Y" );
39
- bool transpose_x = context->Attrs ().Get <bool >(" transpose_X" );
40
- bool transpose_y = context->Attrs ().Get <bool >(" transpose_Y" );
41
-
42
- PADDLE_ENFORCE_GE (dim_x.size (), 1 ,
43
- " Input tensor X must be at least 1-dimensional." );
44
- PADDLE_ENFORCE_GE (dim_y.size (), 1 ,
45
- " Input tensor Y must be at least 1-dimensional." );
46
-
47
- std::vector<int64_t > out_dim;
48
- int64_t batch_count = 1 ;
49
- if (dim_x.size () > 3 ) {
50
- PADDLE_ENFORCE_EQ (
51
- dim_y.size (), dim_x.size (),
52
- " The dimensions of X and Y must be the same, and both of "
53
- " them should be %d-dimensional." ,
54
- dim_x.size ());
55
-
56
- // The first rank-2 dimensions are accumulated on the batch_count, and the
57
- // last two dimensions are used for matrix multiplication.
58
- for (int j = 0 ; j < dim_x.size () - 2 ; ++j) {
59
- PADDLE_ENFORCE_EQ (dim_y[j], dim_x[j],
60
- " The %d-th dimension of X and Y must be the same." ,
61
- j);
62
- out_dim.push_back (dim_x[j]);
63
- batch_count *= dim_x[j];
64
- }
65
- }
66
39
67
- int M = 0 , N = 0 , KX = 0 , KY = 0 , batchCountX = 0 , batchCountY = 0 ;
68
- bool remove_initial_dim = false , remove_final_dim = false ;
69
-
70
- switch (dim_x.size ()) {
71
- case 1 :
72
- if (transpose_x) {
73
- M = dim_x[0 ];
74
- KX = 1 ;
75
- } else {
76
- M = 1 ;
77
- KX = dim_x[0 ];
78
- remove_initial_dim = true ;
79
- }
80
- break ;
81
- case 2 :
82
- M = transpose_x ? dim_x[1 ] : dim_x[0 ];
83
- KX = transpose_x ? dim_x[0 ] : dim_x[1 ];
84
- break ;
85
- case 3 :
86
- batchCountX = dim_x[0 ];
87
- M = transpose_x ? dim_x[2 ] : dim_x[1 ];
88
- KX = transpose_x ? dim_x[1 ] : dim_x[2 ];
89
- break ;
90
- default :
91
- batchCountX = batch_count;
92
- size_t mat_s = dim_x.size () - 2 ;
93
- M = transpose_x ? dim_x[mat_s + 1 ] : dim_x[mat_s];
94
- KX = transpose_x ? dim_x[mat_s] : dim_x[mat_s + 1 ];
95
- break ;
96
- }
40
+ auto mat_dim_x = math::GetMatDim (GetXDim (dim_x), 0 ,
41
+ context->Attrs ().Get <bool >(" transpose_X" ));
42
+ auto mat_dim_y = math::GetMatDim (GetYDim (dim_y), 0 ,
43
+ context->Attrs ().Get <bool >(" transpose_Y" ));
97
44
98
- switch (dim_y.size ()) {
99
- case 1 :
100
- if (transpose_y) {
101
- N = dim_y[0 ];
102
- KY = 1 ;
103
- } else {
104
- N = 1 ;
105
- KY = dim_y[0 ];
106
- remove_final_dim = true ;
107
- }
108
- break ;
109
- case 2 :
110
- KY = transpose_y ? dim_y[1 ] : dim_y[0 ];
111
- N = transpose_y ? dim_y[0 ] : dim_y[1 ];
112
- break ;
113
- case 3 :
114
- batchCountY = dim_y[0 ];
115
- KY = transpose_y ? dim_y[2 ] : dim_y[1 ];
116
- N = transpose_y ? dim_y[1 ] : dim_y[2 ];
117
- break ;
118
- default :
119
- batchCountY = batch_count;
120
- size_t mat_s = dim_y.size () - 2 ;
121
- KY = transpose_y ? dim_y[mat_s + 1 ] : dim_y[mat_s];
122
- N = transpose_y ? dim_y[mat_s] : dim_y[mat_s + 1 ];
45
+ PADDLE_ENFORCE_EQ (mat_dim_x.width_ , mat_dim_y.height_ );
46
+ PADDLE_ENFORCE (mat_dim_x.batch_size_ == mat_dim_y.batch_size_ ||
47
+ mat_dim_x.batch_size_ == 0 || mat_dim_y.batch_size_ == 0 );
48
+ std::vector<int64_t > dim_out;
49
+ if (mat_dim_x.batch_size_ != 0 ) {
50
+ dim_out = framework::vectorize (dim_x);
51
+ dim_out[dim_out.size () - 2 ] = mat_dim_x.height_ ;
52
+ dim_out[dim_out.size () - 1 ] = mat_dim_y.width_ ;
53
+ } else if (mat_dim_y.batch_size_ != 0 ) {
54
+ dim_out = framework::vectorize (dim_y);
55
+ dim_out[dim_out.size () - 2 ] = mat_dim_x.height_ ;
56
+ dim_out[dim_out.size () - 1 ] = mat_dim_y.width_ ;
57
+ } else {
58
+ dim_out = {mat_dim_x.height_ , mat_dim_y.width_ };
123
59
}
124
60
125
- PADDLE_ENFORCE_EQ (
126
- KX, KY,
127
- " First matrix's width must be equal with second matrix's height." );
128
- if (batchCountX && batchCountY) {
129
- PADDLE_ENFORCE_EQ (
130
- batchCountX, batchCountY,
131
- " When Input(X) and Input(Y) are both three dimensional, they "
132
- " must have the same batch dimension." );
61
+ if (dim_x.size () == 1 && dim_out[dim_out.size () - 2 ] == 1 ) {
62
+ std::swap (dim_out[dim_out.size () - 2 ], dim_out[dim_out.size () - 1 ]);
63
+ dim_out.resize (dim_out.size () - 1 );
133
64
}
134
- int batchCount = std::max (batchCountX, batchCountY);
135
65
136
- std::vector<int64_t > dim_out;
137
- if (batchCount) {
138
- if (dim_x.size () > 3 ) {
139
- dim_out.insert (dim_out.begin (), out_dim.begin (), out_dim.end ());
140
- } else {
141
- dim_out.push_back (batchCount);
142
- }
66
+ if (dim_y.size () == 1 && dim_out[dim_out.size () - 1 ] == 1 ) {
67
+ dim_out.resize (dim_out.size () - 1 );
143
68
}
144
- if (!remove_initial_dim) {
145
- dim_out.push_back (M);
146
- }
147
- if (!remove_final_dim) {
148
- dim_out.push_back (N);
149
- }
150
- if (dim_out.size () == 0 ) {
151
- // We don't support 0-dimensional Tensors (scalars), so instead
152
- // treat the output as a Tensor of shape (1, ) in this case.
153
- dim_out.push_back (1 );
69
+
70
+ if (dim_out.empty ()) {
71
+ dim_out = {1 };
154
72
}
155
73
context->SetOutputDim (" Out" , framework::make_ddim (dim_out));
156
74
context->ShareLoD (" X" , /* ->*/ " Out" );
0 commit comments