@@ -34,35 +34,28 @@ class BilinearTensorProductOp : public framework::OperatorWithKernel {
34
34
auto y_dims = ctx->GetInputDim (" Y" );
35
35
auto weight_dims = ctx->GetInputDim (" Weight" );
36
36
37
- PADDLE_ENFORCE_EQ (x_dims.size (), 2UL , " The input X must be a 2D Tensor." );
38
- PADDLE_ENFORCE_EQ (y_dims.size (), 2UL , " The input Y must be a 2D Tensor." );
37
+ PADDLE_ENFORCE_EQ (x_dims.size (), 2UL , " The input(X) must be a 2D Tensor." );
38
+ PADDLE_ENFORCE_EQ (y_dims.size (), 2UL , " The input(Y) must be a 2D Tensor." );
39
39
PADDLE_ENFORCE_EQ (weight_dims.size (), 3UL ,
40
- " The input Weight must be a 3D tensor." );
41
- PADDLE_ENFORCE (weight_dims[0 ],
42
- " The first dimension of Weight must be larger than 0." );
43
- PADDLE_ENFORCE (weight_dims[1 ],
44
- " The second dimension of Weight must be larger than 0." );
45
- PADDLE_ENFORCE (weight_dims[2 ],
46
- " The third dimension of Weight must be larger than 0." );
40
+ " The input(Weight) must be a 3D tensor." );
47
41
PADDLE_ENFORCE_EQ (x_dims[0 ], y_dims[0 ],
48
- " The first dimension(batch_size) of X must be "
49
- " equal to the first dimension of the Y ." );
42
+ " The first dimension(batch_size) of input(X) must be "
43
+ " equal to the first dimension of the input(Y) ." );
50
44
PADDLE_ENFORCE_EQ (x_dims[1 ], weight_dims[1 ],
51
- " The second dimension of X must be equal to the second "
52
- " dimension of the Weight." );
45
+ " The second dimension of input(X) must be equal to "
46
+ " the second dimension of the input( Weight) ." );
53
47
PADDLE_ENFORCE_EQ (y_dims[1 ], weight_dims[2 ],
54
- " The second dimension of Y must be equal to the third "
55
- " dimension of the Weight." );
48
+ " The second dimension of input(Y) must be equal to "
49
+ " the third dimension of the input( Weight) ." );
56
50
57
51
if (ctx->HasInput (" Bias" )) {
58
52
auto bias_dims = ctx->GetInputDim (" Bias" );
59
- PADDLE_ENFORCE_EQ (bias_dims.size (), 2UL ,
60
- " The input Bias must have 2 dimensions." );
61
- PADDLE_ENFORCE_EQ (bias_dims[0 ], 1UL ,
62
- " The first dimention of input Bias must be 1." );
53
+ PADDLE_ENFORCE (bias_dims.size () == 2UL && bias_dims[0 ] == 1UL ,
54
+ " The Input(Bias) must be a 2-D tensor with "
55
+ " the 2nd dimension fixed to 1 (a row vector)." );
63
56
PADDLE_ENFORCE_EQ (bias_dims[1 ], weight_dims[0 ],
64
- " The second dimension of Bias must be equal to the "
65
- " first dimension of the Weight." );
57
+ " The second dimension of input( Bias) must be equal "
58
+ " to the first dimension of the input( Weight) ." );
66
59
}
67
60
68
61
ctx->SetOutputDim (" Out" , {x_dims[0 ], weight_dims[0 ]});
@@ -75,12 +68,13 @@ class BilinearTensorProductOpMaker : public framework::OpProtoAndCheckerMaker {
75
68
BilinearTensorProductOpMaker (framework::OpProto* proto,
76
69
framework::OpAttrChecker* op_checker)
77
70
: OpProtoAndCheckerMaker(proto, op_checker) {
78
- AddInput (" X" , " The first input of BilinearTensorProduct op." );
79
- AddInput (" Y" , " The second input of BilinearTensorProduct op." );
80
- AddInput (" Weight" , " The input weight of BilinearTensorProduct op." );
81
- AddInput (" Bias" , " The input bias of BilinearTensorProduct op." )
71
+ AddInput (" X" , " The first input of bilinear_tensor_product operator." );
72
+ AddInput (" Y" , " The second input of bilinear_tensor_product operator." );
73
+ AddInput (" Weight" ,
74
+ " The learnable parameters of bilinear_tensor_product operator." );
75
+ AddInput (" Bias" , " The learnable bias of bilinear_tensor_product operator." )
82
76
.AsDispensable ();
83
- AddOutput (" Out" , " The output of BilinearTensorProduct op ." );
77
+ AddOutput (" Out" , " The output of bilinear_tensor_product operator ." );
84
78
AddComment (R"DOC(
85
79
Bilinear Tensor Product operator.
86
80
Given input X and Y, a 3D tensor weight, and bias. Each column of the
@@ -104,27 +98,29 @@ class BilinearTensorProductOpGrad : public framework::OperatorWithKernel {
104
98
PADDLE_ENFORCE (ctx->HasInput (" Weight" ),
105
99
" Input(Weight) should not be null." );
106
100
PADDLE_ENFORCE (ctx->HasInput (framework::GradVarName (" Out" )),
107
- " Input (Out@GRAD) should not be null." );
101
+ " Input(Out@GRAD) should not be null." );
108
102
auto x_dims = ctx->GetInputDim (" X" );
109
103
auto y_dims = ctx->GetInputDim (" Y" );
110
104
auto weight_dims = ctx->GetInputDim (" Weight" );
111
105
auto out_dims = ctx->GetInputDim (framework::GradVarName (" Out" ));
112
106
113
107
PADDLE_ENFORCE_EQ (out_dims.size (), 2UL ,
114
- " The Out@GRAD must be a 2D Tensor." );
108
+ " The input( Out@GRAD) must be a 2D Tensor." );
115
109
PADDLE_ENFORCE_EQ (
116
110
x_dims[0 ], out_dims[0 ],
117
- " The first dimension(batch_size) of Out@GRAD must be equal to "
118
- " the first dimension of the Input(X)." );
119
- PADDLE_ENFORCE_EQ (weight_dims[0 ], out_dims[1 ],
120
- " The second dimension of Out@GRAD must be equal to "
121
- " the third dimension of the Input(Weight)." );
111
+ " The first dimension(batch_size) of input(Out@GRAD) must be "
112
+ " equal to the first dimension of the Input(X)." );
113
+ PADDLE_ENFORCE_EQ (
114
+ weight_dims[0 ], out_dims[1 ],
115
+ " The second dimension of input(Out@GRAD) must be equal to "
116
+ " the third dimension of the Input(Weight)." );
122
117
123
118
if (ctx->HasInput (" Bias" )) {
124
119
auto bias_dims = ctx->GetInputDim (" Bias" );
125
- PADDLE_ENFORCE_EQ (bias_dims[1 ], out_dims[1 ],
126
- " The second dimension of Out@GRAD must be equal to "
127
- " the second dimension of the Input(Bias)." );
120
+ PADDLE_ENFORCE_EQ (
121
+ bias_dims[1 ], out_dims[1 ],
122
+ " The second dimension of input(Out@GRAD) must be equal to "
123
+ " the second dimension of the Input(Bias)." );
128
124
auto bias_grad_name = framework::GradVarName (" Bias" );
129
125
if (ctx->HasOutput (bias_grad_name))
130
126
ctx->SetOutputDim (bias_grad_name, bias_dims);
@@ -155,7 +151,9 @@ REGISTER_OP(bilinear_tensor_product, ops::BilinearTensorProductOp,
155
151
ops::BilinearTensorProductOpGrad);
156
152
REGISTER_OP_CPU_KERNEL (
157
153
bilinear_tensor_product,
158
- ops::BilinearTensorProductKernel<paddle::platform::CPUPlace, float >);
154
+ ops::BilinearTensorProductKernel<paddle::platform::CPUPlace, float >,
155
+ ops::BilinearTensorProductKernel<paddle::platform::CPUPlace, double >);
159
156
REGISTER_OP_CPU_KERNEL (
160
157
bilinear_tensor_product_grad,
161
- ops::BilinearTensorProductGradKernel<paddle::platform::CPUPlace, float >);
158
+ ops::BilinearTensorProductGradKernel<paddle::platform::CPUPlace, float >,
159
+ ops::BilinearTensorProductGradKernel<paddle::platform::CPUPlace, double >);
0 commit comments