1616# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
1717# SOFTWARE.
1818"""
19- This module implements the task specific estimator for PyTorch YOLO v3 and v5 object detectors.
19+ This module implements the task specific estimator for PyTorch YOLO v3, v5, v8+ object detectors.
2020
2121| Paper link: https://arxiv.org/abs/1804.02767
2222"""
4242
4343class PyTorchYolo (PyTorchObjectDetector ):
4444 """
45- This module implements the model- and task specific estimator for YOLO v3, v5 object detector models in PyTorch.
45+ This module implements the model- and task specific estimator for YOLO object detector models in PyTorch.
4646
4747 | Paper link: https://arxiv.org/abs/1804.02767
4848 """
@@ -65,11 +65,12 @@ def __init__(
6565 ),
6666 device_type : str = "gpu" ,
6767 is_yolov8 : bool = False ,
68+ model_name : str | None = None ,
6869 ):
6970 """
7071 Initialization.
7172
72- :param model: YOLO v3 or v5 model wrapped as demonstrated in examples/get_started_yolo.py.
73+ :param model: YOLO v3, v5, or v8+ model wrapped as demonstrated in examples/get_started_yolo.py.
7374 The output of the model is `list[dict[str, torch.Tensor]]`, one for each input image.
7475 The fields of the dict are as follows:
7576
@@ -93,8 +94,15 @@ def __init__(
9394 'loss_objectness', and 'loss_rpn_box_reg'.
9495 :param device_type: Type of device to be used for model and tensors, if `cpu` run on CPU, if `gpu` run on GPU
9596 if available otherwise run on CPU.
96- :param is_yolov8: The flag to be used for marking the YOLOv8 model.
97+ :param is_yolov8: The flag to be used for marking the YOLOv8+ model.
98+ :param model_name: The name of the model (e.g., 'yolov8n', 'yolov10n') for determining loss function.
9799 """
100+ # Wrap the model with YoloWrapper if it's a YOLO v8+ model
101+ if is_yolov8 :
102+ from art .estimators .object_detection .pytorch_yolo_loss_wrapper import PyTorchYoloLossWrapper
103+
104+ model = PyTorchYoloLossWrapper (model , model_name )
105+
98106 super ().__init__ (
99107 model = model ,
100108 input_shape = input_shape ,
@@ -154,20 +162,31 @@ def _translate_predictions(self, predictions: "torch.Tensor") -> list[dict[str,
154162 Translate object detection predictions from the model format (YOLO) to ART format (torchvision) and
155163 convert tensors to numpy arrays.
156164
157- :param predictions: Object detection labels in format xcycwh (YOLO).
165+ :param predictions: Object detection labels in format xcycwh (YOLO) or list of dicts (YOLO v8+) .
158166 :return: Object detection labels in format x1y1x2y2 (torchvision).
159167 """
160168 import torch
161169
170+ predictions_x1y1x2y2 : list [dict [str , np .ndarray ]] = []
171+
172+ # Handle YOLO v8+ predictions (list of dicts)
173+ if isinstance (predictions , list ) and len (predictions ) > 0 and isinstance (predictions [0 ], dict ):
174+ for pred in predictions :
175+ prediction = {}
176+ prediction ["boxes" ] = pred ["boxes" ].detach ().cpu ().numpy ()
177+ prediction ["labels" ] = pred ["labels" ].detach ().cpu ().numpy ()
178+ prediction ["scores" ] = pred ["scores" ].detach ().cpu ().numpy ()
179+ predictions_x1y1x2y2 .append (prediction )
180+ return predictions_x1y1x2y2
181+
182+ # Handle traditional YOLO predictions (tensor format)
162183 if self .channels_first :
163184 height = self .input_shape [1 ]
164185 width = self .input_shape [2 ]
165186 else :
166187 height = self .input_shape [0 ]
167188 width = self .input_shape [1 ]
168189
169- predictions_x1y1x2y2 : list [dict [str , np .ndarray ]] = []
170-
171190 for pred in predictions :
172191 boxes = torch .vstack (
173192 [
0 commit comments