@@ -28,39 +28,61 @@ class BilinearTensorProductOp : public framework::OperatorWithKernel {
28
28
29
29
protected:
30
30
void InferShape (framework::InferShapeContext* ctx) const override {
31
- PADDLE_ENFORCE (ctx->HasInput (" X" ), " Input(X) should not be null." );
32
- PADDLE_ENFORCE (ctx->HasInput (" Y" ), " Input(Y) should not be null." );
33
- PADDLE_ENFORCE (ctx->HasInput (" Weight" ),
34
- " Input(Weight) should not be null." );
35
- PADDLE_ENFORCE (ctx->HasOutput (" Out" ), " Output(Out) should not be null." );
31
+ PADDLE_ENFORCE_EQ (
32
+ ctx->HasInput (" X" ), true ,
33
+ platform::errors::InvalidArgument (" Input(X) should not be null." ));
34
+ PADDLE_ENFORCE_EQ (
35
+ ctx->HasInput (" Y" ), true ,
36
+ platform::errors::InvalidArgument (" Input(Y) should not be null." ));
37
+ PADDLE_ENFORCE_EQ (
38
+ ctx->HasInput (" Weight" ), true ,
39
+ platform::errors::InvalidArgument (" Input(Weight) should not be null." ));
40
+ PADDLE_ENFORCE_EQ (
41
+ ctx->HasOutput (" Out" ), true ,
42
+ platform::errors::InvalidArgument (" Output(Out) should not be null." ));
36
43
auto x_dims = ctx->GetInputDim (" X" );
37
44
auto y_dims = ctx->GetInputDim (" Y" );
38
45
auto weight_dims = ctx->GetInputDim (" Weight" );
39
46
40
- PADDLE_ENFORCE_EQ (x_dims.size (), 2UL , " The input(X) must be a 2D Tensor." );
41
- PADDLE_ENFORCE_EQ (y_dims.size (), 2UL , " The input(Y) must be a 2D Tensor." );
47
+ PADDLE_ENFORCE_EQ (
48
+ x_dims.size (), 2UL ,
49
+ platform::errors::InvalidArgument (" The input(X) must be a 2D Tensor." ));
50
+ PADDLE_ENFORCE_EQ (
51
+ y_dims.size (), 2UL ,
52
+ platform::errors::InvalidArgument (" The input(Y) must be a 2D Tensor." ));
42
53
PADDLE_ENFORCE_EQ (weight_dims.size (), 3UL ,
43
- " The input(Weight) must be a 3D tensor." );
54
+ platform::errors::InvalidArgument (
55
+ " The input(Weight) must be a 3D tensor." ));
44
56
if (ctx->IsRuntime () || (x_dims[0 ] > 0 && y_dims[0 ] > 0 )) {
45
- PADDLE_ENFORCE_EQ (x_dims[0 ], y_dims[0 ],
46
- " The first dimension(batch_size) of input(X) must be "
47
- " equal to the first dimension of the input(Y)." );
57
+ PADDLE_ENFORCE_EQ (
58
+ x_dims[0 ], y_dims[0 ],
59
+ platform::errors::InvalidArgument (
60
+ " The first dimension(batch_size) of input(X) must be "
61
+ " equal to the first dimension of the input(Y)." ));
48
62
}
49
63
PADDLE_ENFORCE_EQ (x_dims[1 ], weight_dims[1 ],
50
- " The second dimension of input(X) must be equal to "
51
- " the second dimension of the input(Weight)." );
64
+ platform::errors::InvalidArgument (
65
+ " The second dimension of input(X) must be equal to "
66
+ " the second dimension of the input(Weight)." ));
52
67
PADDLE_ENFORCE_EQ (y_dims[1 ], weight_dims[2 ],
53
- " The second dimension of input(Y) must be equal to "
54
- " the third dimension of the input(Weight)." );
68
+ platform::errors::InvalidArgument (
69
+ " The second dimension of input(Y) must be equal to "
70
+ " the third dimension of the input(Weight)." ));
55
71
56
72
if (ctx->HasInput (" Bias" )) {
57
73
auto bias_dims = ctx->GetInputDim (" Bias" );
58
- PADDLE_ENFORCE (bias_dims.size () == 2UL && bias_dims[0 ] == 1UL ,
59
- " The Input(Bias) must be a 2-D tensor with "
60
- " the 2nd dimension fixed to 1 (a row vector)." );
74
+ PADDLE_ENFORCE_EQ (bias_dims.size (), 2UL ,
75
+ platform::errors::InvalidArgument (
76
+ " The Input(Bias) must be a 2-D tensor with "
77
+ " the 2nd dimension fixed to 1 (a row vector)." ));
78
+ PADDLE_ENFORCE_EQ (bias_dims[0 ], 1UL ,
79
+ platform::errors::InvalidArgument (
80
+ " The Input(Bias) must be a 2-D tensor with "
81
+ " the 2nd dimension fixed to 1 (a row vector)." ));
61
82
PADDLE_ENFORCE_EQ (bias_dims[1 ], weight_dims[0 ],
62
- " The second dimension of input(Bias) must be equal "
63
- " to the first dimension of the input(Weight)." );
83
+ platform::errors::InvalidArgument (
84
+ " The second dimension of input(Bias) must be equal "
85
+ " to the first dimension of the input(Weight)." ));
64
86
}
65
87
66
88
ctx->SetOutputDim (" Out" , {x_dims[0 ], weight_dims[0 ]});
@@ -104,27 +126,36 @@ class BilinearTensorProductOpGrad : public framework::OperatorWithKernel {
104
126
105
127
protected:
106
128
void InferShape (framework::InferShapeContext* ctx) const override {
107
- PADDLE_ENFORCE (ctx->HasInput (" X" ), " Input(X) should not be null." );
108
- PADDLE_ENFORCE (ctx->HasInput (" Y" ), " Input(Y) should not be null." );
109
- PADDLE_ENFORCE (ctx->HasInput (" Weight" ),
110
- " Input(Weight) should not be null." );
111
- PADDLE_ENFORCE (ctx->HasInput (framework::GradVarName (" Out" )),
112
- " Input(Out@GRAD) should not be null." );
129
+ PADDLE_ENFORCE_EQ (
130
+ ctx->HasInput (" X" ), true ,
131
+ platform::errors::InvalidArgument (" Input(X) should not be null." ));
132
+ PADDLE_ENFORCE_EQ (
133
+ ctx->HasInput (" Y" ), true ,
134
+ platform::errors::InvalidArgument (" Input(Y) should not be null." ));
135
+ PADDLE_ENFORCE_EQ (
136
+ ctx->HasInput (" Weight" ), true ,
137
+ platform::errors::InvalidArgument (" Input(Weight) should not be null." ));
138
+ PADDLE_ENFORCE_EQ (ctx->HasInput (framework::GradVarName (" Out" )), true ,
139
+ platform::errors::InvalidArgument (
140
+ " Input(Out@GRAD) should not be null." ));
113
141
auto x_dims = ctx->GetInputDim (" X" );
114
142
auto y_dims = ctx->GetInputDim (" Y" );
115
143
auto weight_dims = ctx->GetInputDim (" Weight" );
116
144
auto out_dims = ctx->GetInputDim (framework::GradVarName (" Out" ));
117
145
118
146
PADDLE_ENFORCE_EQ (out_dims.size (), 2UL ,
119
- " The input(Out@GRAD) must be a 2D Tensor." );
147
+ platform::errors::InvalidArgument (
148
+ " The input(Out@GRAD) must be a 2D Tensor." ));
120
149
PADDLE_ENFORCE_EQ (
121
150
x_dims[0 ], out_dims[0 ],
122
- " The first dimension(batch_size) of input(Out@GRAD) must be "
123
- " equal to the first dimension of the Input(X)." );
151
+ platform::errors::InvalidArgument (
152
+ " The first dimension(batch_size) of input(Out@GRAD) must be "
153
+ " equal to the first dimension of the Input(X)." ));
124
154
PADDLE_ENFORCE_EQ (
125
155
weight_dims[0 ], out_dims[1 ],
126
- " The second dimension of input(Out@GRAD) must be equal to "
127
- " the third dimension of the Input(Weight)." );
156
+ platform::errors::InvalidArgument (
157
+ " The second dimension of input(Out@GRAD) must be equal to "
158
+ " the third dimension of the Input(Weight)." ));
128
159
129
160
auto bias_grad_name = framework::GradVarName (" Bias" );
130
161
if (ctx->HasOutput (bias_grad_name)) {
0 commit comments