@@ -27,10 +27,6 @@ template <typename T, int MajorType = Eigen::RowMajor,
27
27
typename IndexType = Eigen::DenseIndex>
28
28
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
29
29
30
- template <typename T, int MajorType = Eigen::RowMajor,
31
- typename IndexType = Eigen::DenseIndex>
32
- using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
33
-
34
30
template <typename Place, typename T>
35
31
class BilinearTensorProductKernel : public framework ::OpKernel<T> {
36
32
public:
@@ -49,7 +45,9 @@ class BilinearTensorProductKernel : public framework::OpKernel<T> {
49
45
auto weight_dims = weight->dims ();
50
46
auto place = ctx.GetEigenDevice <Place>();
51
47
52
- // Create the intermediate variables.
48
+ // Create the intermediate variable to caculate the result of
49
+ // Input(X) multiplied by Input(Weight_i), the formula is:
50
+ // left_mul = X Weight_i.
53
51
Tensor left_mul;
54
52
left_mul.mutable_data <T>(framework::make_ddim ({batch_size, weight_dims[2 ]}),
55
53
ctx.GetPlace ());
@@ -95,31 +93,33 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
95
93
auto d_out_mat = EigenMatrix<T>::From (*d_out);
96
94
auto place = ctx.GetEigenDevice <Place>();
97
95
98
- // Create the intermediate variables for gradient .
96
+ // Create the intermediate variable to caculate the Output(Y@Grad) .
99
97
Tensor x_scale;
100
98
x_scale.mutable_data <T>(framework::make_ddim ({batch_size, weight_dims[1 ]}),
101
99
ctx.GetPlace ());
102
100
auto x_scale_mat = EigenMatrix<T>::From (x_scale);
101
+
102
+ // Create the intermediate variable to caculate the Output(X@Grad).
103
103
Tensor y_scale;
104
104
y_scale.mutable_data <T>(framework::make_ddim ({batch_size, weight_dims[2 ]}),
105
105
ctx.GetPlace ());
106
106
auto y_scale_mat = EigenMatrix<T>::From (y_scale);
107
107
108
108
math::SetConstant<Place, T> set_zero;
109
109
110
- // Set X@Grad be zero at first .
110
+ // Set Output( X@Grad) be zero.
111
111
if (d_x) {
112
112
d_x->mutable_data <T>(ctx.GetPlace ());
113
113
set_zero (ctx.device_context (), d_x, static_cast <T>(0 ));
114
114
}
115
115
116
- // Set Y@Grad be zero at first .
116
+ // Set Output( Y@Grad) be zero.
117
117
if (d_y) {
118
118
d_y->mutable_data <T>(ctx.GetPlace ());
119
119
set_zero (ctx.device_context (), d_y, static_cast <T>(0 ));
120
120
}
121
121
122
- // Caculate the X@Grad and Y@Grad.
122
+ // Caculate the Output( X@Grad) and Output( Y@Grad) .
123
123
if (d_x || d_y) {
124
124
Eigen::DSizes<int , 2 > bcast_for_x (1 , weight_dims[2 ]);
125
125
Eigen::DSizes<int , 2 > bcast_for_y (1 , weight_dims[1 ]);
@@ -150,7 +150,7 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
150
150
}
151
151
}
152
152
153
- // Caculate the gradient of Weight.
153
+ // Caculate the gradient of Input( Weight) .
154
154
if (d_weight) {
155
155
d_weight->mutable_data <T>(ctx.GetPlace ());
156
156
Eigen::DSizes<int , 2 > bcast_for_weight (1 , weight_dims[1 ]);
@@ -169,7 +169,7 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
169
169
}
170
170
}
171
171
172
- // Caculate the gradient of Bias.
172
+ // Caculate the gradient of Input( Bias) .
173
173
if (d_bias) {
174
174
d_bias->mutable_data <T>(ctx.GetPlace ());
175
175
auto d_bias_mat = EigenMatrix<T>::From (*d_bias);
0 commit comments