Skip to content

Commit eba9854

Browse files
committed
Fix models train error
* PointRCNN: PyTorch v2.1.0 and above add check in BCE loss * Fix open-mmlab#1682 and open-mmlab#1691 * See pytorch/pytorch#97814 * PV-RCNN: Fix open-mmlab#762
1 parent 3711628 commit eba9854

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

pcdet/models/backbones_3d/pfe/voxel_set_abstraction.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ def aggregate_keypoint_features_from_one_source(
327327
xyz_batch_cnt=xyz_batch_cnt,
328328
new_xyz=new_xyz,
329329
new_xyz_batch_cnt=new_xyz_batch_cnt,
330-
features=xyz_features.contiguous(),
330+
features=xyz_features.contiguous() if xyz_features is not None else None,
331331
)
332332
return pooled_features
333333

pcdet/models/roi_heads/roi_head_template.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -202,10 +202,18 @@ def get_box_cls_layer_loss(self, forward_ret_dict):
202202
rcnn_cls = forward_ret_dict['rcnn_cls']
203203
rcnn_cls_labels = forward_ret_dict['rcnn_cls_labels'].view(-1)
204204
if loss_cfgs.CLS_LOSS == 'BinaryCrossEntropy':
205-
rcnn_cls_flat = rcnn_cls.view(-1)
206-
batch_loss_cls = F.binary_cross_entropy(torch.sigmoid(rcnn_cls_flat), rcnn_cls_labels.float(), reduction='none')
207-
cls_valid_mask = (rcnn_cls_labels >= 0).float()
208-
rcnn_loss_cls = (batch_loss_cls * cls_valid_mask).sum() / torch.clamp(cls_valid_mask.sum(), min=1.0)
205+
valid_mask = (rcnn_cls_labels >= 0)
206+
valid_rcnn_cls_preds = rcnn_cls.view(-1)[valid_mask]
207+
valid_rcnn_cls_labels = rcnn_cls_labels[valid_mask].float()
208+
209+
if valid_rcnn_cls_labels.numel() > 0:
210+
rcnn_loss_cls = F.binary_cross_entropy_with_logits(
211+
valid_rcnn_cls_preds,
212+
valid_rcnn_cls_labels,
213+
reduction='mean',
214+
)
215+
else:
216+
rcnn_loss_cls = torch.tensor(0.0, device=rcnn_cls.device)
209217
elif loss_cfgs.CLS_LOSS == 'CrossEntropy':
210218
batch_loss_cls = F.cross_entropy(rcnn_cls, rcnn_cls_labels, reduction='none', ignore_index=-1)
211219
cls_valid_mask = (rcnn_cls_labels >= 0).float()

0 commit comments

Comments
 (0)