@@ -43,74 +43,77 @@ def translate_predictions_xcycwh_to_x1y1x2y2(
     y_pred_xcycwh: "torch.Tensor", input_height: int, input_width: int
 ) -> List[Dict[str, "torch.Tensor"]]:
4545 """
46- Convert object detection predictions from xcycwh to x1y1x2y2 format .
46+ Convert object detection predictions from xcycwh (YOLO) to x1y1x2y2 (torchvision) .
4747
48- :param y_pred_xcycwh: Labels in format xcycwh.
49- :return: Labels in format x1y1x2y2.
48+ :param y_pred_xcycwh: Object detection labels in format xcycwh (YOLO).
49+ :param height: Height of images in pixels.
50+ :param width: Width if images in pixels.
51+ :return: Object detection labels in format x1y1x2y2 (torchvision).
5052 """
     import torch

     y_pred_x1y1x2y2 = []
-
-    for i in range(y_pred_xcycwh.shape[0]):
+    device = y_pred_xcycwh.device
+
+    for y_pred in y_pred_xcycwh:
+        boxes = torch.vstack(
+            [
+                torch.maximum((y_pred[:, 0] - y_pred[:, 2] / 2), torch.tensor(0).to(device)),
+                torch.maximum((y_pred[:, 1] - y_pred[:, 3] / 2), torch.tensor(0).to(device)),
+                torch.minimum((y_pred[:, 0] + y_pred[:, 2] / 2), torch.tensor(input_width).to(device)),
+                torch.minimum((y_pred[:, 1] + y_pred[:, 3] / 2), torch.tensor(input_height).to(device)),
+            ]
+        ).permute((1, 0))
+        labels = torch.argmax(y_pred[:, 5:], dim=1, keepdim=False)
+        scores = y_pred[:, 4]

         y_i = {
-            "boxes": torch.permute(
-                torch.vstack(
-                    [
-                        torch.maximum(
-                            (y_pred_xcycwh[i, :, 0] - y_pred_xcycwh[i, :, 2] / 2),
-                            torch.tensor(0).to(y_pred_xcycwh.device),
-                        ),
-                        torch.maximum(
-                            (y_pred_xcycwh[i, :, 1] - y_pred_xcycwh[i, :, 3] / 2),
-                            torch.tensor(0).to(y_pred_xcycwh.device),
-                        ),
-                        torch.minimum(
-                            (y_pred_xcycwh[i, :, 0] + y_pred_xcycwh[i, :, 2] / 2),
-                            torch.tensor(input_height).to(y_pred_xcycwh.device),
-                        ),
-                        torch.minimum(
-                            (y_pred_xcycwh[i, :, 1] + y_pred_xcycwh[i, :, 3] / 2),
-                            torch.tensor(input_width).to(y_pred_xcycwh.device),
-                        ),
-                    ]
-                ),
-                (1, 0),
-            ),
-            "labels": torch.argmax(y_pred_xcycwh[i, :, 5:], dim=1, keepdim=False),
-            "scores": y_pred_xcycwh[i, :, 4],
+            "boxes": boxes,
+            "labels": labels,
+            "scores": scores,
         }

         y_pred_x1y1x2y2.append(y_i)

     return y_pred_x1y1x2y2


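For context, a minimal sketch of how the rewritten translator can be exercised; the dummy tensor follows the (batch, detections, 5 + num_classes) layout the function indexes, and all values are illustrative only:

import torch

# Dummy YOLO output: 1 image, 2 candidate detections, 3 classes.
# Columns: x_center, y_center, width, height, objectness, per-class scores.
y_pred_xcycwh = torch.tensor([[
    [320.0, 240.0, 100.0, 50.0, 0.9, 0.1, 0.7, 0.2],
    [100.0, 100.0, 400.0, 400.0, 0.4, 0.6, 0.3, 0.1],
]])

preds = translate_predictions_xcycwh_to_x1y1x2y2(y_pred_xcycwh, input_height=480, input_width=640)
# preds[0]["boxes"]:  [[270., 215., 370., 265.], [0., 0., 300., 300.]]  (clamped to the image)
# preds[0]["labels"]: [1, 0]
# preds[0]["scores"]: [0.9, 0.4]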
-def translate_labels_art_to_yolov3(labels_art: List[Dict[str, "torch.Tensor"]]):
+def translate_labels_x1y1x2y2_to_xcycwh(
+    labels_x1y1x2y2: List[Dict[str, "torch.Tensor"]], input_height: int, input_width: int
+) -> "torch.Tensor":
     """
-    Translate labels from ART to YOLO v3 and v5.
+    Translate object detection labels from x1y1x2y2 (torchvision) to xcycwh (YOLO).

-    :param labels_art: Object detection labels in format ART (torchvision).
-    :return: Object detection labels in format YOLO v3 and v5.
+    :param labels_x1y1x2y2: Object detection labels in format x1y1x2y2 (torchvision).
+    :param input_height: Height of images in pixels.
+    :param input_width: Width of images in pixels.
+    :return: Object detection labels in format xcycwh (YOLO).
     """
     import torch

-    yolo_targets_list = []
+    labels_xcycwh_list = []
+
+    for i, label_dict in enumerate(labels_x1y1x2y2):
+        # create 2D tensor to encode labels and bounding boxes
+        labels = torch.zeros(len(label_dict["boxes"]), 6)
+        labels[:, 0] = i
+        labels[:, 1] = label_dict["labels"]
+        labels[:, 2:6] = label_dict["boxes"]
+
+        # normalize bounding boxes to [0, 1]
+        labels[:, 2:6:2] /= input_width
+        labels[:, 3:6:2] /= input_height

-    for i_dict, label_dict in enumerate(labels_art):
-        num_detectors = label_dict["boxes"].size()[0]
-        targets = torch.zeros(num_detectors, 6)
-        targets[:, 0] = i_dict
-        targets[:, 1] = label_dict["labels"]
-        targets[:, 2:6] = label_dict["boxes"]
-        targets[:, 4] = targets[:, 4] - targets[:, 2]
-        targets[:, 5] = targets[:, 5] - targets[:, 3]
-        yolo_targets_list.append(targets)
+        # convert from x1y1x2y2 to xcycwh
+        labels[:, 4] = labels[:, 4] - labels[:, 2]
+        labels[:, 5] = labels[:, 5] - labels[:, 3]
+        labels[:, 2] = labels[:, 2] + labels[:, 4] / 2
+        labels[:, 3] = labels[:, 3] + labels[:, 5] / 2
+        labels_xcycwh_list.append(labels)

-    yolo_targets = torch.vstack(yolo_targets_list)
+    labels_xcycwh = torch.vstack(labels_xcycwh_list)

-    return yolo_targets
+    return labels_xcycwh
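And the inverse direction, assuming a torchvision-style target dict; the expected row follows the (image_index, class, x_c, y_c, w, h) layout built above, normalized by an illustrative 640x480 image size:

import torch

labels_torchvision = [
    {"boxes": torch.tensor([[270.0, 215.0, 370.0, 265.0]]),
     "labels": torch.tensor([1])}
]

labels_yolo = translate_labels_x1y1x2y2_to_xcycwh(labels_torchvision, input_height=480, input_width=640)
# One row per box: [image_index, class, x_c, y_c, w, h] in [0, 1]
# -> [[0.0000, 1.0000, 0.5000, 0.5000, 0.1562, 0.1042]]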


 class PyTorchYolo(ObjectDetectorMixin, PyTorchEstimator):
@@ -274,8 +277,6 @@ def _get_losses(
         import torch

         self._model.train()
-        self.set_batchnorm(train=False)
-        self.set_dropout(train=False)

         # Apply preprocessing
         if self.all_framework_preprocessing:
@@ -344,7 +345,16 @@ def _get_losses(
         else:
             raise NotImplementedError("Combination of inputs and preprocessing not supported.")

-        labels_t = translate_labels_art_to_yolov3(labels_art=y_preprocessed)
+        if self.channels_first:
+            height = self.input_shape[1]
+            width = self.input_shape[2]
+        else:
+            height = self.input_shape[0]
+            width = self.input_shape[1]
+
+        labels_t = translate_labels_x1y1x2y2_to_xcycwh(
+            labels_x1y1x2y2=y_preprocessed, input_height=height, input_width=width
+        )

         loss_components = self._model(inputs_t, labels_t)

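The height/width lookup added here assumes ART's shape conventions, i.e. input_shape is (C, H, W) when channels_first and (H, W, C) otherwise; a small illustration with made-up values:

input_shape = (3, 480, 640)                     # channels_first: (C, H, W)
height, width = input_shape[1], input_shape[2]  # 480, 640

input_shape = (480, 640, 3)                     # channels_last: (H, W, C)
height, width = input_shape[0], input_shape[1]  # 480, 640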
@@ -528,6 +538,13 @@ def fit(  # pylint: disable=W0221
         else:
             x_preprocessed = torch.stack([transform(x_i / norm_factor).to(self.device) for x_i in x_preprocessed])

+        if self.channels_first:
+            height = self.input_shape[1]
+            width = self.input_shape[2]
+        else:
+            height = self.input_shape[0]
+            width = self.input_shape[1]
+
         # Convert labels into tensors, if needed
         if isinstance(y_preprocessed[0]["boxes"], np.ndarray):
             y_preprocessed_tensor = []
@@ -563,7 +580,9 @@ def fit(  # pylint: disable=W0221
                 self._optimizer.zero_grad()

                 # Form the loss function
-                labels_t = translate_labels_art_to_yolov3(labels_art=o_batch)
+                labels_t = translate_labels_x1y1x2y2_to_xcycwh(
+                    labels_x1y1x2y2=o_batch, input_height=height, input_width=width
+                )
                 loss_components = self._model(i_batch, labels_t)
                 if isinstance(loss_components, dict):
                     loss = sum(loss_components.values())
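For reference, the wrapped YOLO model is expected to return its training loss as a dict of named components; a toy illustration of the summation above, with hypothetical key names:

import torch

loss_components = {"loss_box": torch.tensor(0.7),
                   "loss_obj": torch.tensor(0.3),
                   "loss_cls": torch.tensor(0.2)}  # hypothetical names
loss = sum(loss_components.values())  # tensor(1.2000)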