Skip to content

Commit f89f350

Browse files
committed
Formatting changes
Signed-off-by: [email protected] <[email protected]>
1 parent ee880b8 commit f89f350

File tree

1 file changed

+11
-13
lines changed

1 file changed

+11
-13
lines changed

tests/estimators/object_detection/test_pytorch_yolo_loss_wrapper.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -157,14 +157,15 @@ class DummyModel(torch.nn.Module):
157157
def __init__(self):
158158
super().__init__()
159159
self.criterion = None # Will be set by wrapper
160+
160161
def loss(self, items):
161162
# Return different loss components based on model version
162163
return (torch.tensor([1.0, 2.0, 3.0]), [torch.tensor(1.0), torch.tensor(2.0), torch.tensor(3.0)])
163164

164165
# Mock ultralytics imports
165166
import sys
166167
import types
167-
168+
168169
def create_mock_imports():
169170
return types.SimpleNamespace(
170171
models=types.SimpleNamespace(
@@ -173,10 +174,7 @@ def create_mock_imports():
173174
)
174175
),
175176
utils=types.SimpleNamespace(
176-
loss=types.SimpleNamespace(
177-
v8DetectionLoss=lambda m: "v8_loss",
178-
E2EDetectLoss=lambda m: "v10_loss"
179-
)
177+
loss=types.SimpleNamespace(v8DetectionLoss=lambda m: "v8_loss", E2EDetectLoss=lambda m: "v10_loss")
180178
),
181179
)
182180

@@ -208,13 +206,15 @@ def test_yolov8_inference_mode():
208206
class DummyYoloV8Model(torch.nn.Module):
209207
def __init__(self):
210208
super().__init__()
209+
211210
def forward(self, x):
212211
# Return format matching YOLO v8+ output structure
213212
return [{"boxes": torch.ones(1, 4), "scores": torch.ones(1), "labels": torch.zeros(1)}]
214213

215214
# Mock ultralytics imports
216215
import sys
217216
import types
217+
218218
ultralytics_mock = types.SimpleNamespace(
219219
models=types.SimpleNamespace(
220220
yolo=types.SimpleNamespace(
@@ -228,10 +228,10 @@ def forward(self, x):
228228
boxes=types.SimpleNamespace(
229229
xyxy=torch.tensor([[1.0, 2.0, 3.0, 4.0]]),
230230
conf=torch.tensor([0.95]),
231-
cls=torch.tensor([1])
231+
cls=torch.tensor([1]),
232232
)
233233
)
234-
]
234+
],
235235
)
236236
)
237237
)
@@ -273,6 +273,7 @@ def test_yolov8_training_data_format():
273273
class DummyModel(torch.nn.Module):
274274
def __init__(self):
275275
super().__init__()
276+
276277
def loss(self, items):
277278
# Validate input format matches expected YOLO v8+ training format
278279
assert "bboxes" in items
@@ -284,6 +285,7 @@ def loss(self, items):
284285
# Setup mock imports
285286
import sys
286287
import types
288+
287289
ultralytics_mock = types.SimpleNamespace(
288290
models=types.SimpleNamespace(
289291
yolo=types.SimpleNamespace(
@@ -313,14 +315,10 @@ def loss(self, items):
313315
for box_count in box_counts:
314316
x = torch.zeros((batch_size, 3, 416, 416))
315317
targets = [
316-
{
317-
"boxes": torch.zeros((box_count, 4)),
318-
"labels": torch.zeros(box_count)
319-
}
320-
for _ in range(batch_size)
318+
{"boxes": torch.zeros((box_count, 4)), "labels": torch.zeros(box_count)} for _ in range(batch_size)
321319
]
322320
losses = wrapper(x, targets)
323-
321+
324322
# Verify loss structure
325323
assert set(losses.keys()) == {"loss_total", "loss_box", "loss_cls", "loss_dfl"}
326324
assert all(isinstance(v, torch.Tensor) for v in losses.values())

0 commit comments

Comments
 (0)