@@ -256,13 +256,15 @@ class CPUROIAlignGradOpKernel : public framework::OpKernel<T> {
256
256
auto spatial_scale = ctx.Attr <float >(" spatial_scale" );
257
257
auto sampling_ratio = ctx.Attr <int >(" sampling_ratio" );
258
258
auto in_dims = in->dims ();
259
- if (!in_grad) {
260
- return ;
261
- }
259
+
262
260
int channels = in_dims[1 ];
263
261
int height = in_dims[2 ];
264
262
int width = in_dims[3 ];
265
263
int rois_num = rois->dims ()[0 ];
264
+
265
+ if (!in_grad) {
266
+ return ;
267
+ }
266
268
Tensor roi_batch_id_list;
267
269
roi_batch_id_list.Resize ({rois_num});
268
270
int * roi_batch_id_data =
@@ -276,14 +278,21 @@ class CPUROIAlignGradOpKernel : public framework::OpKernel<T> {
276
278
}
277
279
}
278
280
279
- const T* rois_data = rois->data <T>();
280
- const T* out_grad_data = out_grad->data <T>();
281
- T* in_grad_data = in_grad->mutable_data <T>(ctx.GetPlace ());
282
-
281
+ in_grad->mutable_data <T>(ctx.GetPlace ());
283
282
auto & dev_ctx = ctx.template device_context <DeviceContext>();
284
283
math::SetConstant<DeviceContext, T> set_zero;
285
284
set_zero (dev_ctx, in_grad, static_cast <T>(0 ));
286
285
286
+ int output_grad_size = out_grad->numel ();
287
+
288
+ if ((!out_grad->IsInitialized ()) || (output_grad_size <= 0 )) {
289
+ return ;
290
+ }
291
+
292
+ const T* rois_data = rois->data <T>();
293
+ const T* out_grad_data = out_grad->data <T>();
294
+ T* in_grad_data = in_grad->mutable_data <T>(ctx.GetPlace ());
295
+
287
296
auto in_stride = framework::stride (in->dims ());
288
297
auto roi_stride = framework::stride (rois->dims ());
289
298
auto out_stride = framework::stride (out_grad->dims ());
0 commit comments