Skip to content

Commit edc4ec4

Browse files
committed
compute loss inside model
1 parent b6f4baa commit edc4ec4

File tree

3 files changed

+16
-14
lines changed

3 files changed

+16
-14
lines changed

ssd/engine/trainer.py

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,6 @@ def do_train(cfg, model,
4747
data_loader,
4848
optimizer,
4949
scheduler,
50-
criterion,
5150
device,
5251
args):
5352
logger = logging.getLogger("SSD.trainer")
@@ -74,14 +73,13 @@ def do_train(cfg, model,
7473
labels = labels.to(device)
7574

7675
optimizer.zero_grad()
77-
confidence, locations = model(images)
78-
regression_loss, classification_loss = criterion(confidence, locations, labels, boxes)
76+
loss_dict = model(images, targets=(boxes, labels))
7977

8078
# reduce losses over all GPUs for logging purposes
81-
loss_dict_reduced = reduce_loss_dict({'regression_loss': regression_loss, 'classification_loss': classification_loss})
79+
loss_dict_reduced = reduce_loss_dict(loss_dict)
8280
losses_reduced = sum(loss for loss in loss_dict_reduced.values())
8381

84-
loss = regression_loss + classification_loss
82+
loss = sum(loss for loss in loss_dict.values())
8583
loss.backward()
8684
optimizer.step()
8785
trained_time += time.time() - end

ssd/modeling/ssd.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44
import torch.nn.functional as F
55

6+
from ssd.modeling.multibox_loss import MultiBoxLoss
67
from ssd.module import L2Norm
78
from ssd.module.prior_box import PriorBox
89
from ssd.utils import box_utils
@@ -24,6 +25,7 @@ def __init__(self, cfg,
2425
self.classification_headers = classification_headers
2526
self.regression_headers = regression_headers
2627
self.l2_norm = L2Norm(512, scale=20)
28+
self.criterion = MultiBoxLoss(neg_pos_ratio=cfg.MODEL.NEG_POS_RATIO)
2729
self.priors = None
2830
self.reset_parameters()
2931

@@ -38,7 +40,7 @@ def weights_init(m):
3840
self.classification_headers.apply(weights_init)
3941
self.regression_headers.apply(weights_init)
4042

41-
def forward(self, x):
43+
def forward(self, x, targets=None):
4244
sources = []
4345
confidences = []
4446
locations = []
@@ -68,17 +70,24 @@ def forward(self, x):
6870
locations = locations.view(locations.size(0), -1, 4)
6971

7072
if not self.training:
73+
# when evaluating, decode predictions
7174
if self.priors is None:
7275
self.priors = PriorBox(self.cfg)().to(locations.device)
7376
confidences = F.softmax(confidences, dim=2)
7477
boxes = box_utils.convert_locations_to_boxes(
7578
locations, self.priors, self.cfg.MODEL.CENTER_VARIANCE, self.cfg.MODEL.SIZE_VARIANCE
7679
)
7780
boxes = box_utils.center_form_to_corner_form(boxes)
78-
7981
return confidences, boxes
8082
else:
81-
return confidences, locations
83+
# when training, compute losses
84+
gt_boxes, gt_labels = targets
85+
regression_loss, classification_loss = self.criterion(confidences, locations, gt_labels, gt_boxes)
86+
loss_dict = dict(
87+
regression_loss=regression_loss,
88+
classification_loss=classification_loss,
89+
)
90+
return loss_dict
8291

8392
def init_from_base_net(self, model):
8493
vgg_weights = torch.load(model, map_location=lambda storage, loc: storage)

train_ssd.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from ssd.engine.inference import do_evaluation
1313
from ssd.engine.trainer import do_train
1414
from ssd.modeling.data_preprocessing import TrainAugmentation
15-
from ssd.modeling.multibox_loss import MultiBoxLoss
1615
from ssd.modeling.ssd import MatchPrior
1716
from ssd.modeling.vgg_ssd import build_ssd_model
1817
from ssd.module.prior_box import PriorBox
@@ -43,10 +42,6 @@ def train(cfg, args):
4342
# -----------------------------------------------------------------------------
4443
lr = cfg.SOLVER.LR * args.num_gpus # scale by num gpus
4544
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=cfg.SOLVER.MOMENTUM, weight_decay=cfg.SOLVER.WEIGHT_DECAY)
46-
# -----------------------------------------------------------------------------
47-
# Criterion
48-
# -----------------------------------------------------------------------------
49-
criterion = MultiBoxLoss(neg_pos_ratio=cfg.MODEL.NEG_POS_RATIO)
5045

5146
# -----------------------------------------------------------------------------
5247
# Scheduler
@@ -73,7 +68,7 @@ def train(cfg, args):
7368
batch_sampler = samplers.IterationBasedBatchSampler(batch_sampler, num_iterations=cfg.SOLVER.MAX_ITER // args.num_gpus)
7469
train_loader = DataLoader(train_dataset, num_workers=4, batch_sampler=batch_sampler, pin_memory=True)
7570

76-
return do_train(cfg, model, train_loader, optimizer, scheduler, criterion, device, args)
71+
return do_train(cfg, model, train_loader, optimizer, scheduler, device, args)
7772

7873

7974
def main():

0 commit comments

Comments
 (0)