16
16
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
17
17
# SOFTWARE.
18
18
"""
19
- This module implements the task specific estimator for PyTorch YOLO v3 and v5 object detectors.
19
+ This module implements the task specific estimator for PyTorch YOLO v3, v5, v8+ object detectors.
20
20
21
21
| Paper link: https://arxiv.org/abs/1804.02767
22
22
"""
23
23
from __future__ import annotations
24
24
25
25
import logging
26
- from typing import TYPE_CHECKING
26
+ from typing import TYPE_CHECKING , Optional , Union
27
27
28
28
import numpy as np
29
+ import torch
29
30
30
31
from art .estimators .object_detection .pytorch_object_detector import PyTorchObjectDetector
31
32
40
41
logger = logging .getLogger (__name__ )
41
42
42
43
44
+ class PyTorchYoloLossWrapper (torch .nn .Module ):
45
+ """Wrapper for YOLO v8+ models to handle loss dict format."""
46
+
47
+ def __init__ (self , model , name ):
48
+ super ().__init__ ()
49
+ self .model = model
50
+ try :
51
+ from ultralytics .models .yolo .detect import DetectionPredictor
52
+ from ultralytics .utils .loss import v8DetectionLoss , E2EDetectLoss
53
+
54
+ self .detection_predictor = DetectionPredictor ()
55
+ self .model .args = self .detection_predictor .args
56
+ if 'v10' in name :
57
+ self .model .criterion = E2EDetectLoss (model )
58
+ else :
59
+ self .model .criterion = v8DetectionLoss (model )
60
+ except ImportError as e :
61
+ raise ImportError ("The 'ultralytics' package is required for YOLO v8+ models but not installed." ) from e
62
+
63
+ def forward (self , x , targets = None ):
64
+ if self .training :
65
+ # batch_idx is used to identify which predictions/boxes relate to which image
66
+ boxes = []
67
+ labels = []
68
+ indices = []
69
+ for i , item in enumerate (targets ):
70
+ boxes .append (item ['boxes' ])
71
+ labels .append (item ['labels' ])
72
+ indices = indices + ([i ]* len (item ['labels' ]))
73
+ items = {'boxes' : torch .concatenate (boxes ) / x .shape [2 ],
74
+ 'labels' : torch .concatenate (labels ).type (torch .float32 ),
75
+ 'batch_idx' : torch .tensor (indices )}
76
+ items ['bboxes' ] = items .pop ('boxes' )
77
+ items ['cls' ] = items .pop ('labels' )
78
+ items ['img' ] = x
79
+
80
+ loss , loss_components = self .model .loss (items )
81
+ loss_components_dict = {"loss_total" : loss .sum ()}
82
+ loss_components_dict ['loss_box' ] = loss_components [0 ]
83
+ loss_components_dict ['loss_cls' ] = loss_components [1 ]
84
+ loss_components_dict ['loss_dfl' ] = loss_components [2 ]
85
+ return loss_components_dict
86
+ else :
87
+ preds = self .model (x )
88
+ self .detection_predictor .model = self .model
89
+ self .detection_predictor .batch = [x ]
90
+ preds = self .detection_predictor .postprocess (preds , x , x )
91
+ # translate the preds to ART supported format
92
+ items = []
93
+ for pred in preds :
94
+ items .append ({'boxes' : pred .boxes .xyxy ,
95
+ 'scores' : pred .boxes .conf ,
96
+ 'labels' : pred .boxes .cls .type (torch .int )})
97
+ return items
98
+
99
+
43
100
class PyTorchYolo (PyTorchObjectDetector ):
44
101
"""
45
- This module implements the model- and task specific estimator for YOLO v3, v5 object detector models in PyTorch.
102
+ This module implements the model- and task specific estimator for YOLO v3, v5, v8+ object detector models in PyTorch.
46
103
47
104
| Paper link: https://arxiv.org/abs/1804.02767
48
105
"""
@@ -65,11 +122,12 @@ def __init__(
65
122
),
66
123
device_type : str = "gpu" ,
67
124
is_yolov8 : bool = False ,
125
+ model_name : str = "" ,
68
126
):
69
127
"""
70
128
Initialization.
71
129
72
- :param model: YOLO v3 or v5 model wrapped as demonstrated in examples/get_started_yolo.py.
130
+ :param model: YOLO v3, v5, or v8+ model wrapped as demonstrated in examples/get_started_yolo.py.
73
131
The output of the model is `list[dict[str, torch.Tensor]]`, one for each input image.
74
132
The fields of the dict are as follows:
75
133
@@ -93,8 +151,13 @@ def __init__(
93
151
'loss_objectness', and 'loss_rpn_box_reg'.
94
152
:param device_type: Type of device to be used for model and tensors, if `cpu` run on CPU, if `gpu` run on GPU
95
153
if available otherwise run on CPU.
96
- :param is_yolov8: The flag to be used for marking the YOLOv8 model.
154
+ :param is_yolov8: The flag to be used for marking the YOLOv8+ model.
155
+ :param model_name: The name of the model (e.g., 'yolov8n', 'yolov10n') for determining loss function.
97
156
"""
157
+ # Wrap the model with YoloWrapper if it's a YOLO v8+ model
158
+ if is_yolov8 :
159
+ model = YoloWrapper (model , model_name )
160
+
98
161
super ().__init__ (
99
162
model = model ,
100
163
input_shape = input_shape ,
@@ -154,11 +217,23 @@ def _translate_predictions(self, predictions: "torch.Tensor") -> list[dict[str,
154
217
Translate object detection predictions from the model format (YOLO) to ART format (torchvision) and
155
218
convert tensors to numpy arrays.
156
219
157
- :param predictions: Object detection labels in format xcycwh (YOLO).
220
+ :param predictions: Object detection labels in format xcycwh (YOLO) or list of dicts (YOLO v8+) .
158
221
:return: Object detection labels in format x1y1x2y2 (torchvision).
159
222
"""
160
223
import torch
161
224
225
+ # Handle YOLO v8+ predictions (list of dicts)
226
+ if isinstance (predictions , list ) and len (predictions ) > 0 and isinstance (predictions [0 ], dict ):
227
+ predictions_x1y1x2y2 : list [dict [str , np .ndarray ]] = []
228
+ for pred in predictions :
229
+ prediction = {}
230
+ prediction ["boxes" ] = pred ["boxes" ].detach ().cpu ().numpy ()
231
+ prediction ["labels" ] = pred ["labels" ].detach ().cpu ().numpy ()
232
+ prediction ["scores" ] = pred ["scores" ].detach ().cpu ().numpy ()
233
+ predictions_x1y1x2y2 .append (prediction )
234
+ return predictions_x1y1x2y2
235
+
236
+ # Handle traditional YOLO predictions (tensor format)
162
237
if self .channels_first :
163
238
height = self .input_shape [1 ]
164
239
width = self .input_shape [2 ]
0 commit comments