16
16
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
17
17
# SOFTWARE.
18
18
"""
19
- This module implements the task specific estimator for PyTorch YOLO v3 and v5 object detectors.
19
+ This module implements the task-specific estimator for PyTorch YOLO v3, v5, and v8+ object detectors.
20
20
21
21
| Paper link: https://arxiv.org/abs/1804.02767
22
22
"""
42
42
43
43
class PyTorchYolo (PyTorchObjectDetector ):
44
44
"""
45
- This module implements the model- and task specific estimator for YOLO v3, v5 object detector models in PyTorch.
45
+ This module implements the model- and task specific estimator for YOLO object detector models in PyTorch.
46
46
47
47
| Paper link: https://arxiv.org/abs/1804.02767
48
48
"""
@@ -65,11 +65,12 @@ def __init__(
65
65
),
66
66
device_type : str = "gpu" ,
67
67
is_yolov8 : bool = False ,
68
+ model_name : str | None = None ,
68
69
):
69
70
"""
70
71
Initialization.
71
72
72
- :param model: YOLO v3 or v5 model wrapped as demonstrated in examples/get_started_yolo.py.
73
+ :param model: YOLO v3, v5, or v8+ model wrapped as demonstrated in examples/get_started_yolo.py.
73
74
The output of the model is `list[dict[str, torch.Tensor]]`, one for each input image.
74
75
The fields of the dict are as follows:
75
76
@@ -93,8 +94,15 @@ def __init__(
93
94
'loss_objectness', and 'loss_rpn_box_reg'.
94
95
:param device_type: Type of device to be used for model and tensors, if `cpu` run on CPU, if `gpu` run on GPU
95
96
if available otherwise run on CPU.
96
- :param is_yolov8: The flag to be used for marking the YOLOv8 model.
97
+ :param is_yolov8: The flag to be used for marking the YOLOv8+ model.
98
+ :param model_name: The name of the model (e.g., 'yolov8n', 'yolov10n') for determining the loss function.
97
99
"""
100
+ # Wrap the model with PyTorchYoloLossWrapper if it's a YOLO v8+ model
101
+ if is_yolov8 :
102
+ from art .estimators .object_detection .pytorch_yolo_loss_wrapper import PyTorchYoloLossWrapper
103
+
104
+ model = PyTorchYoloLossWrapper (model , model_name )
105
+
98
106
super ().__init__ (
99
107
model = model ,
100
108
input_shape = input_shape ,
@@ -154,20 +162,31 @@ def _translate_predictions(self, predictions: "torch.Tensor") -> list[dict[str,
154
162
Translate object detection predictions from the model format (YOLO) to ART format (torchvision) and
155
163
convert tensors to numpy arrays.
156
164
157
- :param predictions: Object detection labels in format xcycwh (YOLO).
165
+ :param predictions: Object detection labels in format xcycwh (YOLO) or list of dicts (YOLO v8+).
158
166
:return: Object detection labels in format x1y1x2y2 (torchvision).
159
167
"""
160
168
import torch
161
169
170
+ predictions_x1y1x2y2 : list [dict [str , np .ndarray ]] = []
171
+
172
+ # Handle YOLO v8+ predictions (list of dicts)
173
+ if isinstance (predictions , list ) and len (predictions ) > 0 and isinstance (predictions [0 ], dict ):
174
+ for pred in predictions :
175
+ prediction = {}
176
+ prediction ["boxes" ] = pred ["boxes" ].detach ().cpu ().numpy ()
177
+ prediction ["labels" ] = pred ["labels" ].detach ().cpu ().numpy ()
178
+ prediction ["scores" ] = pred ["scores" ].detach ().cpu ().numpy ()
179
+ predictions_x1y1x2y2 .append (prediction )
180
+ return predictions_x1y1x2y2
181
+
182
+ # Handle traditional YOLO predictions (tensor format)
162
183
if self .channels_first :
163
184
height = self .input_shape [1 ]
164
185
width = self .input_shape [2 ]
165
186
else :
166
187
height = self .input_shape [0 ]
167
188
width = self .input_shape [1 ]
168
189
169
- predictions_x1y1x2y2 : list [dict [str , np .ndarray ]] = []
170
-
171
190
for pred in predictions :
172
191
boxes = torch .vstack (
173
192
[
0 commit comments