33import numpy as np
44import torch .nn .functional as F
55
6+ from ssd .modeling .multibox_loss import MultiBoxLoss
67from ssd .module import L2Norm
78from ssd .module .prior_box import PriorBox
89from ssd .utils import box_utils
@@ -24,6 +25,7 @@ def __init__(self, cfg,
2425 self .classification_headers = classification_headers
2526 self .regression_headers = regression_headers
2627 self .l2_norm = L2Norm (512 , scale = 20 )
28+ self .criterion = MultiBoxLoss (neg_pos_ratio = cfg .MODEL .NEG_POS_RATIO )
2729 self .priors = None
2830 self .reset_parameters ()
2931
@@ -38,7 +40,7 @@ def weights_init(m):
3840 self .classification_headers .apply (weights_init )
3941 self .regression_headers .apply (weights_init )
4042
41- def forward (self , x ):
43+ def forward (self , x , targets = None ):
4244 sources = []
4345 confidences = []
4446 locations = []
@@ -68,17 +70,24 @@ def forward(self, x):
6870 locations = locations .view (locations .size (0 ), - 1 , 4 )
6971
7072 if not self .training :
73+ # when evaluating, decode predictions
7174 if self .priors is None :
7275 self .priors = PriorBox (self .cfg )().to (locations .device )
7376 confidences = F .softmax (confidences , dim = 2 )
7477 boxes = box_utils .convert_locations_to_boxes (
7578 locations , self .priors , self .cfg .MODEL .CENTER_VARIANCE , self .cfg .MODEL .SIZE_VARIANCE
7679 )
7780 boxes = box_utils .center_form_to_corner_form (boxes )
78-
7981 return confidences , boxes
8082 else :
81- return confidences , locations
83+ # when training, compute losses
84+ gt_boxes , gt_labels = targets
85+ regression_loss , classification_loss = self .criterion (confidences , locations , gt_labels , gt_boxes )
86+ loss_dict = dict (
87+ regression_loss = regression_loss ,
88+ classification_loss = classification_loss ,
89+ )
90+ return loss_dict
8291
8392 def init_from_base_net (self , model ):
8493 vgg_weights = torch .load (model , map_location = lambda storage , loc : storage )
0 commit comments