@@ -25,7 +25,7 @@ namespace operators {
25
25
* Get row matrix shape from a vector shape. If the rank of x_dim > 1, the
26
26
* original x_dim is returned.
27
27
*/
28
- static framework::DDim RowMatrixFromVector (const framework::DDim& x_dim) {
28
+ static framework::DDim RowMatrixFromVector (const framework::DDim & x_dim) {
29
29
if (x_dim.size () > 1 ) {
30
30
return x_dim;
31
31
}
@@ -36,7 +36,7 @@ static framework::DDim RowMatrixFromVector(const framework::DDim& x_dim) {
36
36
* Get column matrix shape from a vector shape. If the ran of y_dim > 1, the
37
37
* original y_dim is returned.
38
38
*/
39
- static framework::DDim ColumnMatrixFromVector (const framework::DDim& y_dim) {
39
+ static framework::DDim ColumnMatrixFromVector (const framework::DDim & y_dim) {
40
40
if (y_dim.size () > 1 ) {
41
41
return y_dim;
42
42
}
@@ -46,12 +46,12 @@ static framework::DDim ColumnMatrixFromVector(const framework::DDim& y_dim) {
46
46
template <typename DeviceContext, typename T>
47
47
class MatMulKernel : public framework ::OpKernel<T> {
48
48
public:
49
- void Compute (const framework::ExecutionContext& context) const override {
50
- auto & x =
49
+ void Compute (const framework::ExecutionContext & context) const override {
50
+ auto & x =
51
51
detail::Ref (context.Input <framework::Tensor>(" X" ), " Cannot find X" );
52
- auto & y =
52
+ auto & y =
53
53
detail::Ref (context.Input <framework::Tensor>(" Y" ), " Cannot find Y" );
54
- auto * out = context.Output <framework::Tensor>(" Out" );
54
+ auto * out = context.Output <framework::Tensor>(" Out" );
55
55
out->mutable_data <T>(context.GetPlace ());
56
56
57
57
auto blas = math::GetBlas<DeviceContext, T>(context);
@@ -65,7 +65,7 @@ class MatMulKernel : public framework::OpKernel<T> {
65
65
66
66
// Reshape a rank-3 tensor from P x M x N to (P * M) x N.
67
67
// Identity op if the tensor is not of rank 3.
68
- static framework::Tensor FoldInitDims (const framework::Tensor& input) {
68
+ static framework::Tensor FoldInitDims (const framework::Tensor & input) {
69
69
auto output = input;
70
70
auto in_dims = input.dims ();
71
71
if (in_dims.size () == 3 ) {
@@ -78,8 +78,8 @@ static framework::Tensor FoldInitDims(const framework::Tensor& input) {
78
78
// (Warning: This requires transposing data and writes into new memory.)
79
79
// Identity op if the tensor is not of rank 3.
80
80
template <typename DeviceContext, typename T>
81
- static framework::Tensor FoldHeadAndLastDims (const DeviceContext& context,
82
- const framework::Tensor& input) {
81
+ static framework::Tensor FoldHeadAndLastDims (const DeviceContext & context,
82
+ const framework::Tensor & input) {
83
83
auto in_dims = input.dims ();
84
84
if (in_dims.size () != 3 ) {
85
85
return input;
@@ -102,7 +102,7 @@ static framework::Tensor FoldHeadAndLastDims(const DeviceContext& context,
102
102
* If transposed, `H,W` will be swapped.
103
103
*/
104
104
static void ReshapeTensorIntoMatrixSequence (
105
- framework::Tensor* x, const math::MatDescriptor& descriptor) {
105
+ framework::Tensor * x, const math::MatDescriptor & descriptor) {
106
106
int64_t h, w;
107
107
h = descriptor.height_ ;
108
108
w = descriptor.width_ ;
@@ -130,9 +130,9 @@ static void ReshapeTensorIntoMatrixSequence(
130
130
* If any of `X` and `Y` has batch size BatchSize, the out will have the
131
131
* BatchSize.
132
132
*/
133
- static void ReshapeXYOutIntoMatrixSequence (framework::Tensor* x,
134
- framework::Tensor* y,
135
- framework::Tensor* out, bool trans_x,
133
+ static void ReshapeXYOutIntoMatrixSequence (framework::Tensor * x,
134
+ framework::Tensor * y,
135
+ framework::Tensor * out, bool trans_x,
136
136
bool trans_y) {
137
137
auto x_dim = RowMatrixFromVector (x->dims ());
138
138
auto y_dim = ColumnMatrixFromVector (y->dims ());
@@ -177,29 +177,29 @@ static void ReshapeXYOutIntoMatrixSequence(framework::Tensor* x,
177
177
template <typename DeviceContext, typename T>
178
178
class MatMulGradKernel : public framework ::OpKernel<T> {
179
179
public:
180
- void MatMul (const framework::ExecutionContext& context,
181
- const framework::Tensor& a, bool trans_a,
182
- const framework::Tensor& b, bool trans_b,
183
- framework::Tensor* out) const {
180
+ void MatMul (const framework::ExecutionContext & context,
181
+ const framework::Tensor & a, bool trans_a,
182
+ const framework::Tensor & b, bool trans_b,
183
+ framework::Tensor * out) const {
184
184
out->mutable_data <T>(context.GetPlace ());
185
185
auto blas = math::GetBlas<DeviceContext, T>(context);
186
186
auto mat_dim_a = math::CreateMatrixDescriptor (a.dims (), 0 , trans_a);
187
187
auto mat_dim_b = math::CreateMatrixDescriptor (b.dims (), 0 , trans_b);
188
188
blas.MatMul (a, mat_dim_a, b, mat_dim_b, T (1 ), out, T (0 ));
189
189
}
190
190
191
- void CalcInputGrad (const framework::ExecutionContext& context,
192
- const framework::Tensor& a, bool trans_a,
193
- bool is_fold_init_dims_a, const framework::Tensor& b,
191
+ void CalcInputGrad (const framework::ExecutionContext & context,
192
+ const framework::Tensor & a, bool trans_a,
193
+ bool is_fold_init_dims_a, const framework::Tensor & b,
194
194
bool trans_b, bool is_fold_init_dims_b,
195
- framework::Tensor* out) const {
195
+ framework::Tensor * out) const {
196
196
if (out == nullptr ) return ;
197
197
bool need_combine = (a.dims ().size () == 3 || b.dims ().size () == 3 ) &&
198
198
out->dims ().size () == 2 ;
199
199
if (!need_combine) {
200
200
MatMul (context, a, trans_a, b, trans_b, out);
201
201
} else {
202
- auto & ctx = context.template device_context <DeviceContext>();
202
+ auto & ctx = context.template device_context <DeviceContext>();
203
203
MatMul (context, is_fold_init_dims_a
204
204
? FoldInitDims (a)
205
205
: FoldHeadAndLastDims<DeviceContext, T>(ctx, a),
@@ -210,13 +210,13 @@ class MatMulGradKernel : public framework::OpKernel<T> {
210
210
}
211
211
}
212
212
213
- void Compute (const framework::ExecutionContext& context) const override {
213
+ void Compute (const framework::ExecutionContext & context) const override {
214
214
auto x = *context.Input <framework::Tensor>(" X" );
215
215
auto y = *context.Input <framework::Tensor>(" Y" );
216
216
auto dout =
217
217
*context.Input <framework::Tensor>(framework::GradVarName (" Out" ));
218
- auto * dx = context.Output <framework::Tensor>(framework::GradVarName (" X" ));
219
- auto * dy = context.Output <framework::Tensor>(framework::GradVarName (" Y" ));
218
+ auto * dx = context.Output <framework::Tensor>(framework::GradVarName (" X" ));
219
+ auto * dy = context.Output <framework::Tensor>(framework::GradVarName (" Y" ));
220
220
bool transpose_x = context.Attr <bool >(" transpose_X" );
221
221
bool transpose_y = context.Attr <bool >(" transpose_Y" );
222
222
@@ -269,7 +269,7 @@ class MatMulOp : public framework::OperatorWithKernel {
269
269
using framework::OperatorWithKernel::OperatorWithKernel;
270
270
271
271
protected:
272
- void InferShape (framework::InferShapeContext* context) const override {
272
+ void InferShape (framework::InferShapeContext * context) const override {
273
273
PADDLE_ENFORCE (context->HasInput (" X" ),
274
274
" Input(X) of MatMulOp should not be null." );
275
275
PADDLE_ENFORCE (context->HasInput (" Y" ),
@@ -375,7 +375,7 @@ class MatMulOpGrad : public framework::OperatorWithKernel {
375
375
using framework::OperatorWithKernel::OperatorWithKernel;
376
376
377
377
protected:
378
- void InferShape (framework::InferShapeContext* context) const override {
378
+ void InferShape (framework::InferShapeContext * context) const override {
379
379
PADDLE_ENFORCE (context->HasInput (" X" ), " Input(X) should not be null" );
380
380
PADDLE_ENFORCE (context->HasInput (" Y" ), " Input(Y) should not be null" );
381
381
PADDLE_ENFORCE (context->HasInput (framework::GradVarName (" Out" )),
@@ -401,7 +401,7 @@ class MatMulOpGradMaker : public framework::SingleGradOpDescMaker {
401
401
402
402
protected:
403
403
std::unique_ptr<framework::OpDesc> Apply () const override {
404
- auto * retv = new framework::OpDesc ();
404
+ auto * retv = new framework::OpDesc ();
405
405
retv->SetType (" matmul_grad" );
406
406
retv->SetInput (" X" , Input (" X" ));
407
407
retv->SetInput (" Y" , Input (" Y" ));
@@ -420,15 +420,27 @@ REGISTER_OPERATOR(matmul, ops::MatMulOp, ops::MatMulOpMaker,
420
420
ops::MatMulOpGradMaker);
421
421
REGISTER_OPERATOR (matmul_grad, ops::MatMulOpGrad);
422
422
REGISTER_OP_CPU_KERNEL (
423
- matmul, ops::MatMulKernel<paddle::platform::CPUDeviceContext, float >);
423
+ matmul, ops::MatMulKernel<paddle::platform::CPUDeviceContext, float >,
424
+ ops::MatMulKernel<paddle::platform::CPUDeviceContext, double >,
425
+ ops::MatMulKernel<paddle::platform::CPUDeviceContext,
426
+ paddle::platform::float16>);
424
427
REGISTER_OP_CPU_KERNEL (
425
428
matmul_grad,
426
- ops::MatMulGradKernel<paddle::platform::CPUDeviceContext, float >);
429
+ ops::MatMulGradKernel<paddle::platform::CPUDeviceContext, float >,
430
+ ops::MatMulGradKernel<paddle::platform::CPUDeviceContext, double >,
431
+ ops::MatMulGradKernel<paddle::platform::CPUDeviceContext,
432
+ paddle::platform::float16>);
427
433
428
434
#ifdef PADDLE_WITH_CUDA
429
435
REGISTER_OP_CUDA_KERNEL (
430
- matmul, ops::MatMulKernel<paddle::platform::CUDADeviceContext, float >);
436
+ matmul, ops::MatMulKernel<paddle::platform::CUDADeviceContext, float >,
437
+ ops::MatMulKernel<paddle::platform::CUDADeviceContext, double >,
438
+ ops::MatMulKernel<paddle::platform::CUDADeviceContext,
439
+ paddle::platform::float16>);
431
440
REGISTER_OP_CUDA_KERNEL (
432
441
matmul_grad,
433
- ops::MatMulGradKernel<paddle::platform::CUDADeviceContext, float >);
442
+ ops::MatMulGradKernel<paddle::platform::CUDADeviceContext, float >,
443
+ ops::MatMulGradKernel<paddle::platform::CUDADeviceContext, double >,
444
+ ops::MatMulGradKernel<paddle::platform::CUDADeviceContext,
445
+ paddle::platform::float16>);
434
446
#endif
0 commit comments