@@ -179,8 +179,8 @@ class RmspropOpKernel : public framework::OpKernel<T> {
179
179
auto &mg_tensor = *ctx.Input <LoDTensor>(" MeanGrad" );
180
180
auto mg = EigenVector<T>::Flatten (mg_tensor);
181
181
auto *mean_grad_out = ctx.Output <LoDTensor>(" MeanGradOut" );
182
- PADDLE_ENFORCE (&mg_tensor, mean_grad_out,
183
- " MeanGrad and MeanGradOut must be the same Tensor" );
182
+ PADDLE_ENFORCE_EQ (&mg_tensor, mean_grad_out,
183
+ " MeanGrad and MeanGradOut must be the same Tensor" );
184
184
auto mg_out = EigenVector<T>::Flatten (*mean_grad_out);
185
185
186
186
mg_out.device (place) = rho * mg + (1 - rho) * g;
@@ -198,8 +198,8 @@ class RmspropOpKernel : public framework::OpKernel<T> {
198
198
if (centered) {
199
199
auto &mg_tensor = *ctx.Input <LoDTensor>(" MeanGrad" );
200
200
auto *mean_grad_out = ctx.Output <LoDTensor>(" MeanGradOut" );
201
- PADDLE_ENFORCE (&mg_tensor, mean_grad_out,
202
- " MeanGrad and MeanGradOut must be the same Tensor" );
201
+ PADDLE_ENFORCE_EQ (&mg_tensor, mean_grad_out,
202
+ " MeanGrad and MeanGradOut must be the same Tensor" );
203
203
for_range (CenteredRmspropFunctor<T, DenseRmspropGradFunctor<T>>(
204
204
param_out->mutable_data <T>(ctx.GetPlace ()),
205
205
mean_square_out->mutable_data <T>(ctx.GetPlace ()),
@@ -243,8 +243,8 @@ class RmspropOpKernel : public framework::OpKernel<T> {
243
243
if (centered) {
244
244
auto &mg_tensor = *ctx.Input <LoDTensor>(" MeanGrad" );
245
245
auto *mean_grad_out = ctx.Output <LoDTensor>(" MeanGradOut" );
246
- PADDLE_ENFORCE (&mg_tensor, mean_grad_out,
247
- " MeanGrad and MeanGradOut must be the same Tensor" );
246
+ PADDLE_ENFORCE_EQ (&mg_tensor, mean_grad_out,
247
+ " MeanGrad and MeanGradOut must be the same Tensor" );
248
248
for_range (CenteredRmspropFunctor<T, SparseRmspropGradFunctor<T>>(
249
249
param_out->mutable_data <T>(ctx.GetPlace ()),
250
250
mean_square_out->mutable_data <T>(ctx.GetPlace ()),
0 commit comments