Skip to content

Commit deee78a

Browse files
FDInSkyjerrywgz
authored andcommitted
fix roi_align_op cpu backward's bug (#18825)
[cherry pick]fix roi_align_op cpu backward's bug
1 parent 1b22dd2 commit deee78a

File tree

1 file changed

+16
-7
lines changed

1 file changed

+16
-7
lines changed

paddle/fluid/operators/roi_align_op.h

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -256,13 +256,15 @@ class CPUROIAlignGradOpKernel : public framework::OpKernel<T> {
256256
auto spatial_scale = ctx.Attr<float>("spatial_scale");
257257
auto sampling_ratio = ctx.Attr<int>("sampling_ratio");
258258
auto in_dims = in->dims();
259-
if (!in_grad) {
260-
return;
261-
}
259+
262260
int channels = in_dims[1];
263261
int height = in_dims[2];
264262
int width = in_dims[3];
265263
int rois_num = rois->dims()[0];
264+
265+
if (!in_grad) {
266+
return;
267+
}
266268
Tensor roi_batch_id_list;
267269
roi_batch_id_list.Resize({rois_num});
268270
int* roi_batch_id_data =
@@ -276,14 +278,21 @@ class CPUROIAlignGradOpKernel : public framework::OpKernel<T> {
276278
}
277279
}
278280

279-
const T* rois_data = rois->data<T>();
280-
const T* out_grad_data = out_grad->data<T>();
281-
T* in_grad_data = in_grad->mutable_data<T>(ctx.GetPlace());
282-
281+
in_grad->mutable_data<T>(ctx.GetPlace());
283282
auto& dev_ctx = ctx.template device_context<DeviceContext>();
284283
math::SetConstant<DeviceContext, T> set_zero;
285284
set_zero(dev_ctx, in_grad, static_cast<T>(0));
286285

286+
int output_grad_size = out_grad->numel();
287+
288+
if ((!out_grad->IsInitialized()) || (output_grad_size <= 0)) {
289+
return;
290+
}
291+
292+
const T* rois_data = rois->data<T>();
293+
const T* out_grad_data = out_grad->data<T>();
294+
T* in_grad_data = in_grad->mutable_data<T>(ctx.GetPlace());
295+
287296
auto in_stride = framework::stride(in->dims());
288297
auto roi_stride = framework::stride(rois->dims());
289298
auto out_stride = framework::stride(out_grad->dims());

0 commit comments

Comments
 (0)