Skip to content

Commit 0ce9bf7

Browse files
authored
Merge pull request #5931 from guoshengCS/fix-ROIPoolOP-warn
Refine roi_pool_op to avoid warning
2 parents 95cdbfe + 19a37ec commit 0ce9bf7

File tree

1 file changed

+21
-27
lines changed

1 file changed

+21
-27
lines changed

paddle/operators/roi_pool_op.h

Lines changed: 21 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -133,53 +133,47 @@ class CPUROIPoolGradOpKernel : public framework::OpKernel<T> {
133133
auto* in = ctx.Input<framework::Tensor>("X");
134134
auto* rois = ctx.Input<framework::Tensor>("ROIs");
135135
auto* argmax = ctx.Input<framework::Tensor>("Argmax");
136-
137136
auto* out_grad =
138137
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
139-
auto* x_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
138+
auto* in_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
140139

141140
auto pooled_height = ctx.Attr<int>("pooled_height");
142141
auto pooled_width = ctx.Attr<int>("pooled_width");
143142

144-
if (x_grad) {
145-
int channels = in->dims()[1];
146-
auto in_stride = framework::stride(in->dims());
147-
auto roi_stride = framework::stride(rois->dims());
148-
143+
if (in_grad) {
149144
const int64_t* rois_data = rois->data<int64_t>();
150-
int rois_num = rois->dims()[0];
151-
152-
T* x_grad_data = x_grad->mutable_data<T>(ctx.GetPlace());
145+
const T* out_grad_data = out_grad->data<T>();
146+
const int64_t* argmax_data = argmax->data<int64_t>();
147+
T* in_grad_data = in_grad->mutable_data<T>(ctx.GetPlace());
153148
math::SetConstant<Place, T> set_zero;
154-
set_zero(ctx.device_context(), x_grad, static_cast<T>(0));
149+
set_zero(ctx.device_context(), in_grad, static_cast<T>(0));
155150

156-
size_t roi_offset = roi_stride[0];
157-
size_t batch_offset = in_stride[0];
158-
size_t channel_offset = in_stride[1];
151+
auto in_stride = framework::stride(in->dims());
152+
auto argmax_stride = framework::stride(argmax->dims());
153+
auto roi_stride = framework::stride(rois->dims());
154+
auto out_stride = framework::stride(out_grad->dims());
159155

160-
const T* out_grad_data = out_grad->data<T>();
161-
size_t pool_channel_offset = pooled_height * pooled_width;
162-
const int64_t* argmax_data = argmax->data<int64_t>();
156+
int rois_num = rois->dims()[0];
157+
int channels = in->dims()[1];
163158

164-
for (size_t n = 0; n < rois_num; ++n) {
165-
size_t roi_batch_idx = rois_data[0];
166-
T* batch_grad_data = x_grad_data + batch_offset * roi_batch_idx;
159+
for (int n = 0; n < rois_num; ++n) {
160+
int roi_batch_idx = rois_data[0];
161+
T* batch_grad_data = in_grad_data + roi_batch_idx * in_stride[0];
167162
for (int c = 0; c < channels; ++c) {
168163
for (int ph = 0; ph < pooled_height; ++ph) {
169164
for (int pw = 0; pw < pooled_width; ++pw) {
170-
size_t pool_index = ph * pooled_width + pw;
171-
165+
int pool_index = ph * pooled_width + pw;
172166
if (argmax_data[pool_index] >= 0) {
173-
size_t index = static_cast<size_t>(argmax_data[pool_index]);
167+
auto index = argmax_data[pool_index];
174168
batch_grad_data[index] += out_grad_data[pool_index];
175169
}
176170
}
177171
}
178-
batch_grad_data += channel_offset;
179-
out_grad_data += pool_channel_offset;
180-
argmax_data += pool_channel_offset;
172+
batch_grad_data += in_stride[1];
173+
out_grad_data += out_stride[1];
174+
argmax_data += argmax_stride[1];
181175
}
182-
rois_data += roi_offset;
176+
rois_data += roi_stride[0];
183177
}
184178
}
185179
}

0 commit comments

Comments
 (0)