@@ -157,14 +157,15 @@ class DummyModel(torch.nn.Module):
157157 def __init__ (self ):
158158 super ().__init__ ()
159159 self .criterion = None # Will be set by wrapper
160+
160161 def loss (self , items ):
161162 # Return different loss components based on model version
162163 return (torch .tensor ([1.0 , 2.0 , 3.0 ]), [torch .tensor (1.0 ), torch .tensor (2.0 ), torch .tensor (3.0 )])
163164
164165 # Mock ultralytics imports
165166 import sys
166167 import types
167-
168+
168169 def create_mock_imports ():
169170 return types .SimpleNamespace (
170171 models = types .SimpleNamespace (
@@ -173,10 +174,7 @@ def create_mock_imports():
173174 )
174175 ),
175176 utils = types .SimpleNamespace (
176- loss = types .SimpleNamespace (
177- v8DetectionLoss = lambda m : "v8_loss" ,
178- E2EDetectLoss = lambda m : "v10_loss"
179- )
177+ loss = types .SimpleNamespace (v8DetectionLoss = lambda m : "v8_loss" , E2EDetectLoss = lambda m : "v10_loss" )
180178 ),
181179 )
182180
@@ -208,13 +206,15 @@ def test_yolov8_inference_mode():
208206 class DummyYoloV8Model (torch .nn .Module ):
209207 def __init__ (self ):
210208 super ().__init__ ()
209+
211210 def forward (self , x ):
212211 # Return format matching YOLO v8+ output structure
213212 return [{"boxes" : torch .ones (1 , 4 ), "scores" : torch .ones (1 ), "labels" : torch .zeros (1 )}]
214213
215214 # Mock ultralytics imports
216215 import sys
217216 import types
217+
218218 ultralytics_mock = types .SimpleNamespace (
219219 models = types .SimpleNamespace (
220220 yolo = types .SimpleNamespace (
@@ -228,10 +228,10 @@ def forward(self, x):
228228 boxes = types .SimpleNamespace (
229229 xyxy = torch .tensor ([[1.0 , 2.0 , 3.0 , 4.0 ]]),
230230 conf = torch .tensor ([0.95 ]),
231- cls = torch .tensor ([1 ])
231+ cls = torch .tensor ([1 ]),
232232 )
233233 )
234- ]
234+ ],
235235 )
236236 )
237237 )
@@ -273,6 +273,7 @@ def test_yolov8_training_data_format():
273273 class DummyModel (torch .nn .Module ):
274274 def __init__ (self ):
275275 super ().__init__ ()
276+
276277 def loss (self , items ):
277278 # Validate input format matches expected YOLO v8+ training format
278279 assert "bboxes" in items
@@ -284,6 +285,7 @@ def loss(self, items):
284285 # Setup mock imports
285286 import sys
286287 import types
288+
287289 ultralytics_mock = types .SimpleNamespace (
288290 models = types .SimpleNamespace (
289291 yolo = types .SimpleNamespace (
@@ -313,14 +315,10 @@ def loss(self, items):
313315 for box_count in box_counts :
314316 x = torch .zeros ((batch_size , 3 , 416 , 416 ))
315317 targets = [
316- {
317- "boxes" : torch .zeros ((box_count , 4 )),
318- "labels" : torch .zeros (box_count )
319- }
320- for _ in range (batch_size )
318+ {"boxes" : torch .zeros ((box_count , 4 )), "labels" : torch .zeros (box_count )} for _ in range (batch_size )
321319 ]
322320 losses = wrapper (x , targets )
323-
321+
324322 # Verify loss structure
325323 assert set (losses .keys ()) == {"loss_total" , "loss_box" , "loss_cls" , "loss_dfl" }
326324 assert all (isinstance (v , torch .Tensor ) for v in losses .values ())
0 commit comments