1616# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
1717# SOFTWARE.
1818"""
19- This module implements the task specific estimator for PyTorch YOLO v3 and v5 object detectors.
19+ This module implements the task specific estimator for PyTorch YOLO v3, v5, v8+ object detectors.
2020
2121| Paper link: https://arxiv.org/abs/1804.02767
2222"""
2323from __future__ import annotations
2424
2525import logging
26- from typing import TYPE_CHECKING
26+ from typing import TYPE_CHECKING , Optional , Union
2727
2828import numpy as np
29+ import torch
2930
3031from art .estimators .object_detection .pytorch_object_detector import PyTorchObjectDetector
3132
4041logger = logging .getLogger (__name__ )
4142
4243
44+ class PyTorchYoloLossWrapper (torch .nn .Module ):
45+ """Wrapper for YOLO v8+ models to handle loss dict format."""
46+
47+ def __init__ (self , model , name ):
48+ super ().__init__ ()
49+ self .model = model
50+ try :
51+ from ultralytics .models .yolo .detect import DetectionPredictor
52+ from ultralytics .utils .loss import v8DetectionLoss , E2EDetectLoss
53+
54+ self .detection_predictor = DetectionPredictor ()
55+ self .model .args = self .detection_predictor .args
56+ if 'v10' in name :
57+ self .model .criterion = E2EDetectLoss (model )
58+ else :
59+ self .model .criterion = v8DetectionLoss (model )
60+ except ImportError as e :
61+ raise ImportError ("The 'ultralytics' package is required for YOLO v8+ models but not installed." ) from e
62+
63+ def forward (self , x , targets = None ):
64+ if self .training :
65+ # batch_idx is used to identify which predictions/boxes relate to which image
66+ boxes = []
67+ labels = []
68+ indices = []
69+ for i , item in enumerate (targets ):
70+ boxes .append (item ['boxes' ])
71+ labels .append (item ['labels' ])
72+ indices = indices + ([i ]* len (item ['labels' ]))
73+ items = {'boxes' : torch .concatenate (boxes ) / x .shape [2 ],
74+ 'labels' : torch .concatenate (labels ).type (torch .float32 ),
75+ 'batch_idx' : torch .tensor (indices )}
76+ items ['bboxes' ] = items .pop ('boxes' )
77+ items ['cls' ] = items .pop ('labels' )
78+ items ['img' ] = x
79+
80+ loss , loss_components = self .model .loss (items )
81+ loss_components_dict = {"loss_total" : loss .sum ()}
82+ loss_components_dict ['loss_box' ] = loss_components [0 ]
83+ loss_components_dict ['loss_cls' ] = loss_components [1 ]
84+ loss_components_dict ['loss_dfl' ] = loss_components [2 ]
85+ return loss_components_dict
86+ else :
87+ preds = self .model (x )
88+ self .detection_predictor .model = self .model
89+ self .detection_predictor .batch = [x ]
90+ preds = self .detection_predictor .postprocess (preds , x , x )
91+ # translate the preds to ART supported format
92+ items = []
93+ for pred in preds :
94+ items .append ({'boxes' : pred .boxes .xyxy ,
95+ 'scores' : pred .boxes .conf ,
96+ 'labels' : pred .boxes .cls .type (torch .int )})
97+ return items
98+
99+
43100class PyTorchYolo (PyTorchObjectDetector ):
44101 """
45- This module implements the model- and task specific estimator for YOLO v3, v5 object detector models in PyTorch.
102+ This module implements the model- and task specific estimator for YOLO v3, v5, v8+ object detector models in PyTorch.
46103
47104 | Paper link: https://arxiv.org/abs/1804.02767
48105 """
@@ -65,11 +122,12 @@ def __init__(
65122 ),
66123 device_type : str = "gpu" ,
67124 is_yolov8 : bool = False ,
125+ model_name : str = "" ,
68126 ):
69127 """
70128 Initialization.
71129
72- :param model: YOLO v3 or v5 model wrapped as demonstrated in examples/get_started_yolo.py.
130+ :param model: YOLO v3, v5, or v8+ model wrapped as demonstrated in examples/get_started_yolo.py.
73131 The output of the model is `list[dict[str, torch.Tensor]]`, one for each input image.
74132 The fields of the dict are as follows:
75133
@@ -93,8 +151,13 @@ def __init__(
93151 'loss_objectness', and 'loss_rpn_box_reg'.
94152 :param device_type: Type of device to be used for model and tensors, if `cpu` run on CPU, if `gpu` run on GPU
95153 if available otherwise run on CPU.
96- :param is_yolov8: The flag to be used for marking the YOLOv8 model.
154+ :param is_yolov8: The flag to be used for marking the YOLOv8+ model.
155+ :param model_name: The name of the model (e.g., 'yolov8n', 'yolov10n') for determining loss function.
97156 """
157+ # Wrap the model with YoloWrapper if it's a YOLO v8+ model
158+ if is_yolov8 :
159+ model = YoloWrapper (model , model_name )
160+
98161 super ().__init__ (
99162 model = model ,
100163 input_shape = input_shape ,
@@ -154,11 +217,23 @@ def _translate_predictions(self, predictions: "torch.Tensor") -> list[dict[str,
154217 Translate object detection predictions from the model format (YOLO) to ART format (torchvision) and
155218 convert tensors to numpy arrays.
156219
157- :param predictions: Object detection labels in format xcycwh (YOLO).
220+ :param predictions: Object detection labels in format xcycwh (YOLO) or list of dicts (YOLO v8+) .
158221 :return: Object detection labels in format x1y1x2y2 (torchvision).
159222 """
160223 import torch
161224
225+ # Handle YOLO v8+ predictions (list of dicts)
226+ if isinstance (predictions , list ) and len (predictions ) > 0 and isinstance (predictions [0 ], dict ):
227+ predictions_x1y1x2y2 : list [dict [str , np .ndarray ]] = []
228+ for pred in predictions :
229+ prediction = {}
230+ prediction ["boxes" ] = pred ["boxes" ].detach ().cpu ().numpy ()
231+ prediction ["labels" ] = pred ["labels" ].detach ().cpu ().numpy ()
232+ prediction ["scores" ] = pred ["scores" ].detach ().cpu ().numpy ()
233+ predictions_x1y1x2y2 .append (prediction )
234+ return predictions_x1y1x2y2
235+
236+ # Handle traditional YOLO predictions (tensor format)
162237 if self .channels_first :
163238 height = self .input_shape [1 ]
164239 width = self .input_shape [2 ]
0 commit comments