Skip to content

Commit 09fa9bc

Browse files
committed
Fix PartA2 train error
Negative values in `point_part_labels` should be masked out
1 parent eba9854 commit 09fa9bc

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

pcdet/models/dense_heads/point_head_template.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,11 +156,18 @@ def get_cls_layer_loss(self, tb_dict=None):
156156

157157
def get_part_layer_loss(self, tb_dict=None):
158158
pos_mask = self.forward_ret_dict['point_cls_labels'] > 0
159-
pos_normalizer = max(1, (pos_mask > 0).sum().item())
160159
point_part_labels = self.forward_ret_dict['point_part_labels']
161160
point_part_preds = self.forward_ret_dict['point_part_preds']
162-
point_loss_part = F.binary_cross_entropy(torch.sigmoid(point_part_preds), point_part_labels, reduction='none')
163-
point_loss_part = (point_loss_part.sum(dim=-1) * pos_mask.float()).sum() / (3 * pos_normalizer)
161+
valid_row_mask = (point_part_labels >= 0).all(dim=-1)
162+
valid_mask = pos_mask & valid_row_mask
163+
valid_point_part_labels = point_part_labels[valid_mask]
164+
valid_point_part_preds = point_part_preds[valid_mask]
165+
if valid_point_part_labels.numel() > 0:
166+
point_loss_part = F.binary_cross_entropy_with_logits(
167+
valid_point_part_preds, valid_point_part_labels, reduction='mean'
168+
)
169+
else:
170+
point_loss_part = torch.tensor(0.0, device=point_part_labels.device)
164171

165172
loss_weights_dict = self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS
166173
point_loss_part = point_loss_part * loss_weights_dict['point_part_weight']

0 commit comments

Comments
 (0)