We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents 55c2c73 + 24819df commit 9cf6036Copy full SHA for 9cf6036
paddle/operators/cos_sim_op.h
@@ -132,7 +132,7 @@ class CosSimGradKernel : public framework::OpKernel<T> {
132
// compute dy
133
if (out_grad_y) {
134
out_grad_y->mutable_data<T>(context.GetPlace());
135
- auto dy = EigenMatrix<T>::Reshape(*out_grad_y, 1);
+ auto dy = EigenVector<T>::Flatten(*out_grad_y);
136
auto grad = x / norm_prod_bcast - z_bcast * y_bcast / y_snorm_bcast;
137
dy.device(place) = (dz_bcast * grad).sum(Eigen::array<int, 1>({{0}}));
138
}
0 commit comments