@@ -157,14 +157,15 @@ class DummyModel(torch.nn.Module):
157
157
def __init__ (self ):
158
158
super ().__init__ ()
159
159
self .criterion = None # Will be set by wrapper
160
+
160
161
def loss (self , items ):
161
162
# Return different loss components based on model version
162
163
return (torch .tensor ([1.0 , 2.0 , 3.0 ]), [torch .tensor (1.0 ), torch .tensor (2.0 ), torch .tensor (3.0 )])
163
164
164
165
# Mock ultralytics imports
165
166
import sys
166
167
import types
167
-
168
+
168
169
def create_mock_imports ():
169
170
return types .SimpleNamespace (
170
171
models = types .SimpleNamespace (
@@ -173,10 +174,7 @@ def create_mock_imports():
173
174
)
174
175
),
175
176
utils = types .SimpleNamespace (
176
- loss = types .SimpleNamespace (
177
- v8DetectionLoss = lambda m : "v8_loss" ,
178
- E2EDetectLoss = lambda m : "v10_loss"
179
- )
177
+ loss = types .SimpleNamespace (v8DetectionLoss = lambda m : "v8_loss" , E2EDetectLoss = lambda m : "v10_loss" )
180
178
),
181
179
)
182
180
@@ -208,13 +206,15 @@ def test_yolov8_inference_mode():
208
206
class DummyYoloV8Model (torch .nn .Module ):
209
207
def __init__ (self ):
210
208
super ().__init__ ()
209
+
211
210
def forward (self , x ):
212
211
# Return format matching YOLO v8+ output structure
213
212
return [{"boxes" : torch .ones (1 , 4 ), "scores" : torch .ones (1 ), "labels" : torch .zeros (1 )}]
214
213
215
214
# Mock ultralytics imports
216
215
import sys
217
216
import types
217
+
218
218
ultralytics_mock = types .SimpleNamespace (
219
219
models = types .SimpleNamespace (
220
220
yolo = types .SimpleNamespace (
@@ -228,10 +228,10 @@ def forward(self, x):
228
228
boxes = types .SimpleNamespace (
229
229
xyxy = torch .tensor ([[1.0 , 2.0 , 3.0 , 4.0 ]]),
230
230
conf = torch .tensor ([0.95 ]),
231
- cls = torch .tensor ([1 ])
231
+ cls = torch .tensor ([1 ]),
232
232
)
233
233
)
234
- ]
234
+ ],
235
235
)
236
236
)
237
237
)
@@ -273,6 +273,7 @@ def test_yolov8_training_data_format():
273
273
class DummyModel (torch .nn .Module ):
274
274
def __init__ (self ):
275
275
super ().__init__ ()
276
+
276
277
def loss (self , items ):
277
278
# Validate input format matches expected YOLO v8+ training format
278
279
assert "bboxes" in items
@@ -284,6 +285,7 @@ def loss(self, items):
284
285
# Setup mock imports
285
286
import sys
286
287
import types
288
+
287
289
ultralytics_mock = types .SimpleNamespace (
288
290
models = types .SimpleNamespace (
289
291
yolo = types .SimpleNamespace (
@@ -313,14 +315,10 @@ def loss(self, items):
313
315
for box_count in box_counts :
314
316
x = torch .zeros ((batch_size , 3 , 416 , 416 ))
315
317
targets = [
316
- {
317
- "boxes" : torch .zeros ((box_count , 4 )),
318
- "labels" : torch .zeros (box_count )
319
- }
320
- for _ in range (batch_size )
318
+ {"boxes" : torch .zeros ((box_count , 4 )), "labels" : torch .zeros (box_count )} for _ in range (batch_size )
321
319
]
322
320
losses = wrapper (x , targets )
323
-
321
+
324
322
# Verify loss structure
325
323
assert set (losses .keys ()) == {"loss_total" , "loss_box" , "loss_cls" , "loss_dfl" }
326
324
assert all (isinstance (v , torch .Tensor ) for v in losses .values ())
0 commit comments