Skip to content

Commit 3794027

Browse files
author
Abhinav Arora
committed
Fix warnings in sgd_op.h
1 parent 59234b7 commit 3794027

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

paddle/fluid/operators/sgd_op.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,15 +65,16 @@ class SGDOpKernel : public framework::OpKernel<T> {
6565
auto &grad_rows = grad->rows();
6666

6767
size_t grad_row_numel = grad_value.numel() / grad_rows.size();
68-
PADDLE_ENFORCE_EQ(grad_row_numel, param_out->numel() / grad_height);
68+
PADDLE_ENFORCE_EQ(static_cast<int64_t>(grad_row_numel),
69+
param_out->numel() / grad_height);
6970

7071
auto *grad_data = grad_value.data<T>();
7172
auto *out_data = param_out->data<T>();
7273
auto *lr = learning_rate->data<T>();
7374
for (size_t i = 0; i < grad_rows.size(); i++) {
7475
PADDLE_ENFORCE(grad_rows[i] < grad_height,
7576
"Input rows index should less than height");
76-
for (int64_t j = 0; j < grad_row_numel; j++) {
77+
for (size_t j = 0; j < grad_row_numel; j++) {
7778
out_data[grad_rows[i] * grad_row_numel + j] -=
7879
lr[0] * grad_data[i * grad_row_numel + j];
7980
}
@@ -107,7 +108,7 @@ class SGDOpKernel : public framework::OpKernel<T> {
107108
PADDLE_ENFORCE(grad.rows()[i] < grad.height(),
108109
"Input rows index should less than height");
109110
int64_t id_index = param.index(grad.rows()[i]);
110-
for (int64_t j = 0; j < grad_row_width; j++) {
111+
for (size_t j = 0; j < grad_row_width; j++) {
111112
out_data[id_index * grad_row_width + j] -=
112113
lr[0] * grad_data[i * grad_row_width + j];
113114
}

0 commit comments

Comments
 (0)