File tree Expand file tree Collapse file tree 1 file changed +10
-3
lines changed Expand file tree Collapse file tree 1 file changed +10
-3
lines changed Original file line number Diff line number Diff line change @@ -156,11 +156,18 @@ def get_cls_layer_loss(self, tb_dict=None):
156156
157157 def get_part_layer_loss (self , tb_dict = None ):
158158 pos_mask = self .forward_ret_dict ['point_cls_labels' ] > 0
159- pos_normalizer = max (1 , (pos_mask > 0 ).sum ().item ())
160159 point_part_labels = self .forward_ret_dict ['point_part_labels' ]
161160 point_part_preds = self .forward_ret_dict ['point_part_preds' ]
162- point_loss_part = F .binary_cross_entropy (torch .sigmoid (point_part_preds ), point_part_labels , reduction = 'none' )
163- point_loss_part = (point_loss_part .sum (dim = - 1 ) * pos_mask .float ()).sum () / (3 * pos_normalizer )
161+ valid_row_mask = (point_part_labels >= 0 ).all (dim = - 1 )
162+ valid_mask = pos_mask & valid_row_mask
163+ valid_point_part_labels = point_part_labels [valid_mask ]
164+ valid_point_part_preds = point_part_preds [valid_mask ]
165+ if valid_point_part_labels .numel () > 0 :
166+ point_loss_part = F .binary_cross_entropy_with_logits (
167+ valid_point_part_preds , valid_point_part_labels , reduction = 'mean'
168+ )
169+ else :
170+ point_loss_part = torch .tensor (0.0 , device = point_part_labels .device )
164171
165172 loss_weights_dict = self .model_cfg .LOSS_CONFIG .LOSS_WEIGHTS
166173 point_loss_part = point_loss_part * loss_weights_dict ['point_part_weight' ]
You can’t perform that action at this time.
0 commit comments