Skip to content

Commit 5f99ae9

Browse files
refine notation in bilinear_tensor_product_op.h
1 parent 5cf8204 commit 5f99ae9

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

paddle/operators/bilinear_tensor_product_op.h

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,6 @@ template <typename T, int MajorType = Eigen::RowMajor,
2727
typename IndexType = Eigen::DenseIndex>
2828
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
2929

30-
template <typename T, int MajorType = Eigen::RowMajor,
31-
typename IndexType = Eigen::DenseIndex>
32-
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
33-
3430
template <typename Place, typename T>
3531
class BilinearTensorProductKernel : public framework::OpKernel<T> {
3632
public:
@@ -49,7 +45,9 @@ class BilinearTensorProductKernel : public framework::OpKernel<T> {
4945
auto weight_dims = weight->dims();
5046
auto place = ctx.GetEigenDevice<Place>();
5147

52-
// Create the intermediate variables.
48+
// Create the intermediate variable to caculate the result of
49+
// Input(X) multiplied by Input(Weight_i), the formula is:
50+
// left_mul = X Weight_i.
5351
Tensor left_mul;
5452
left_mul.mutable_data<T>(framework::make_ddim({batch_size, weight_dims[2]}),
5553
ctx.GetPlace());
@@ -95,31 +93,33 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
9593
auto d_out_mat = EigenMatrix<T>::From(*d_out);
9694
auto place = ctx.GetEigenDevice<Place>();
9795

98-
// Create the intermediate variables for gradient.
96+
// Create the intermediate variable to caculate the Output(Y@Grad).
9997
Tensor x_scale;
10098
x_scale.mutable_data<T>(framework::make_ddim({batch_size, weight_dims[1]}),
10199
ctx.GetPlace());
102100
auto x_scale_mat = EigenMatrix<T>::From(x_scale);
101+
102+
// Create the intermediate variable to caculate the Output(X@Grad).
103103
Tensor y_scale;
104104
y_scale.mutable_data<T>(framework::make_ddim({batch_size, weight_dims[2]}),
105105
ctx.GetPlace());
106106
auto y_scale_mat = EigenMatrix<T>::From(y_scale);
107107

108108
math::SetConstant<Place, T> set_zero;
109109

110-
// Set X@Grad be zero at first.
110+
// Set Output(X@Grad) be zero.
111111
if (d_x) {
112112
d_x->mutable_data<T>(ctx.GetPlace());
113113
set_zero(ctx.device_context(), d_x, static_cast<T>(0));
114114
}
115115

116-
// Set Y@Grad be zero at first.
116+
// Set Output(Y@Grad) be zero.
117117
if (d_y) {
118118
d_y->mutable_data<T>(ctx.GetPlace());
119119
set_zero(ctx.device_context(), d_y, static_cast<T>(0));
120120
}
121121

122-
// Caculate the X@Grad and Y@Grad.
122+
// Caculate the Output(X@Grad) and Output(Y@Grad).
123123
if (d_x || d_y) {
124124
Eigen::DSizes<int, 2> bcast_for_x(1, weight_dims[2]);
125125
Eigen::DSizes<int, 2> bcast_for_y(1, weight_dims[1]);
@@ -150,7 +150,7 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
150150
}
151151
}
152152

153-
// Caculate the gradient of Weight.
153+
// Caculate the gradient of Input(Weight).
154154
if (d_weight) {
155155
d_weight->mutable_data<T>(ctx.GetPlace());
156156
Eigen::DSizes<int, 2> bcast_for_weight(1, weight_dims[1]);
@@ -169,7 +169,7 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
169169
}
170170
}
171171

172-
// Caculate the gradient of Bias.
172+
// Caculate the gradient of Input(Bias).
173173
if (d_bias) {
174174
d_bias->mutable_data<T>(ctx.GetPlace());
175175
auto d_bias_mat = EigenMatrix<T>::From(*d_bias);

0 commit comments

Comments
 (0)