Skip to content

Commit c8bb663

Browse files
committed
Refine roi_pool_op to avoid warning
1 parent e6546ba commit c8bb663

File tree

1 file changed

+21
-28
lines changed

1 file changed

+21
-28
lines changed

paddle/operators/roi_pool_op.h

100755100644
Lines changed: 21 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -133,54 +133,47 @@ class CPUROIPoolGradOpKernel : public framework::OpKernel<T> {
133133
auto* in = ctx.Input<framework::Tensor>("X");
134134
auto* rois = ctx.Input<framework::Tensor>("ROIs");
135135
auto* argmax = ctx.Input<framework::Tensor>("Argmax");
136-
137136
auto* out_grad =
138137
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
139-
auto* x_grad =
140-
ctx.Output<framework::Tensor>(framework::GradVarName("X"));
138+
auto* in_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
141139

142140
auto pooled_height = ctx.Attr<int>("pooled_height");
143141
auto pooled_width = ctx.Attr<int>("pooled_width");
144142

145-
if (x_grad) {
146-
int channels = in->dims()[1];
147-
auto in_stride = framework::stride(in->dims());
148-
auto roi_stride = framework::stride(rois->dims());
149-
143+
if (in_grad) {
150144
const int64_t* rois_data = rois->data<int64_t>();
151-
int rois_num = rois->dims()[0];
152-
153-
T* x_grad_data = x_grad->mutable_data<T>(ctx.GetPlace());
145+
const T* out_grad_data = out_grad->data<T>();
146+
const int64_t* argmax_data = argmax->data<int64_t>();
147+
T* in_grad_data = in_grad->mutable_data<T>(ctx.GetPlace());
154148
math::SetConstant<Place, T> set_zero;
155-
set_zero(ctx.device_context(), x_grad, static_cast<T>(0));
149+
set_zero(ctx.device_context(), in_grad, static_cast<T>(0));
156150

157-
size_t roi_offset = roi_stride[0];
158-
size_t batch_offset = in_stride[0];
159-
size_t channel_offset = in_stride[1];
151+
auto in_stride = framework::stride(in->dims());
152+
auto argmax_stride = framework::stride(argmax->dims());
153+
auto roi_stride = framework::stride(rois->dims());
154+
auto out_stride = framework::stride(out_grad->dims());
160155

161-
const T* out_grad_data = out_grad->data<T>();
162-
size_t pool_channel_offset = pooled_height * pooled_width;
163-
const int64_t* argmax_data = argmax->data<int64_t>();
156+
int rois_num = rois->dims()[0];
157+
int channels = in->dims()[1];
164158

165-
for (size_t n = 0; n < rois_num; ++n) {
166-
size_t roi_batch_idx = rois_data[0];
167-
T* batch_grad_data = x_grad_data + batch_offset * roi_batch_idx;
159+
for (int n = 0; n < rois_num; ++n) {
160+
int roi_batch_idx = rois_data[0];
161+
T* batch_grad_data = in_grad_data + roi_batch_idx * in_stride[0];
168162
for (int c = 0; c < channels; ++c) {
169163
for (int ph = 0; ph < pooled_height; ++ph) {
170164
for (int pw = 0; pw < pooled_width; ++pw) {
171-
size_t pool_index = ph * pooled_width + pw;
172-
165+
int pool_index = ph * pooled_width + pw;
173166
if (argmax_data[pool_index] >= 0) {
174-
size_t index = static_cast<size_t>(argmax_data[pool_index]);
167+
auto index = argmax_data[pool_index];
175168
batch_grad_data[index] += out_grad_data[pool_index];
176169
}
177170
}
178171
}
179-
batch_grad_data += channel_offset;
180-
out_grad_data += pool_channel_offset;
181-
argmax_data += pool_channel_offset;
172+
batch_grad_data += in_stride[1];
173+
out_grad_data += out_stride[1];
174+
argmax_data += argmax_stride[1];
182175
}
183-
rois_data += roi_offset;
176+
rois_data += roi_stride[0];
184177
}
185178
}
186179
}

0 commit comments

Comments
 (0)