Skip to content

Commit fcbe84c

Browse files
authored
Merge pull request #14270 from sneaxiy/fix_rmsprop_enforce_bug
Fix rmsprop_op enforce bug
2 parents 45bad76 + 11f032a commit fcbe84c

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

paddle/fluid/operators/rmsprop_op.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -179,8 +179,8 @@ class RmspropOpKernel : public framework::OpKernel<T> {
179179
auto &mg_tensor = *ctx.Input<LoDTensor>("MeanGrad");
180180
auto mg = EigenVector<T>::Flatten(mg_tensor);
181181
auto *mean_grad_out = ctx.Output<LoDTensor>("MeanGradOut");
182-
PADDLE_ENFORCE(&mg_tensor, mean_grad_out,
183-
"MeanGrad and MeanGradOut must be the same Tensor");
182+
PADDLE_ENFORCE_EQ(&mg_tensor, mean_grad_out,
183+
"MeanGrad and MeanGradOut must be the same Tensor");
184184
auto mg_out = EigenVector<T>::Flatten(*mean_grad_out);
185185

186186
mg_out.device(place) = rho * mg + (1 - rho) * g;
@@ -198,8 +198,8 @@ class RmspropOpKernel : public framework::OpKernel<T> {
198198
if (centered) {
199199
auto &mg_tensor = *ctx.Input<LoDTensor>("MeanGrad");
200200
auto *mean_grad_out = ctx.Output<LoDTensor>("MeanGradOut");
201-
PADDLE_ENFORCE(&mg_tensor, mean_grad_out,
202-
"MeanGrad and MeanGradOut must be the same Tensor");
201+
PADDLE_ENFORCE_EQ(&mg_tensor, mean_grad_out,
202+
"MeanGrad and MeanGradOut must be the same Tensor");
203203
for_range(CenteredRmspropFunctor<T, DenseRmspropGradFunctor<T>>(
204204
param_out->mutable_data<T>(ctx.GetPlace()),
205205
mean_square_out->mutable_data<T>(ctx.GetPlace()),
@@ -243,8 +243,8 @@ class RmspropOpKernel : public framework::OpKernel<T> {
243243
if (centered) {
244244
auto &mg_tensor = *ctx.Input<LoDTensor>("MeanGrad");
245245
auto *mean_grad_out = ctx.Output<LoDTensor>("MeanGradOut");
246-
PADDLE_ENFORCE(&mg_tensor, mean_grad_out,
247-
"MeanGrad and MeanGradOut must be the same Tensor");
246+
PADDLE_ENFORCE_EQ(&mg_tensor, mean_grad_out,
247+
"MeanGrad and MeanGradOut must be the same Tensor");
248248
for_range(CenteredRmspropFunctor<T, SparseRmspropGradFunctor<T>>(
249249
param_out->mutable_data<T>(ctx.GetPlace()),
250250
mean_square_out->mutable_data<T>(ctx.GetPlace()),

0 commit comments

Comments
 (0)