Skip to content

Commit cc636c5

Browse files
committed
Fixing format for pytorch yolo wrapper test.
Signed-off-by: Kieran Fraser <[email protected]>
1 parent ffa93bf commit cc636c5

File tree

1 file changed

+27
-20
lines changed

1 file changed

+27
-20
lines changed

tests/estimators/object_detection/test_pytorch_yolo_loss_wrapper.py

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,18 @@ def test_yolov8_loss_wrapper():
4646
boxes = torch.tensor([[0.1, 0.1, 0.3, 0.3], [0.5, 0.5, 0.8, 0.8]]) # [x1, y1, x2, y2]
4747
labels = torch.zeros(2, dtype=torch.long) # Use class 0 for testing
4848
targets.append({"boxes": boxes, "labels": labels})"""
49-
targets = torch.tensor([[ 0.0000, 20.0000, 0.7738, 0.3919, 0.4525, 0.7582],
50-
[ 0.0000, 20.0000, 0.2487, 0.4062, 0.4966, 0.5787],
51-
[ 0.0000, 20.0000, 0.5667, 0.2772, 0.0791, 0.2313],
52-
[ 0.0000, 20.0000, 0.1009, 0.1955, 0.2002, 0.0835],
53-
[ 1.0000, 20.0000, 0.7738, 0.3919, 0.4525, 0.7582],
54-
[ 1.0000, 20.0000, 0.2487, 0.4062, 0.4966, 0.5787],
55-
[ 1.0000, 20.0000, 0.5667, 0.2772, 0.0791, 0.2313],
56-
[ 1.0000, 20.0000, 0.1009, 0.1955, 0.2002, 0.0835]])
57-
49+
targets = torch.tensor(
50+
[
51+
[0.0000, 20.0000, 0.7738, 0.3919, 0.4525, 0.7582],
52+
[0.0000, 20.0000, 0.2487, 0.4062, 0.4966, 0.5787],
53+
[0.0000, 20.0000, 0.5667, 0.2772, 0.0791, 0.2313],
54+
[0.0000, 20.0000, 0.1009, 0.1955, 0.2002, 0.0835],
55+
[1.0000, 20.0000, 0.7738, 0.3919, 0.4525, 0.7582],
56+
[1.0000, 20.0000, 0.2487, 0.4062, 0.4966, 0.5787],
57+
[1.0000, 20.0000, 0.5667, 0.2772, 0.0791, 0.2313],
58+
[1.0000, 20.0000, 0.1009, 0.1955, 0.2002, 0.0835],
59+
]
60+
)
5861

5962
# Test training mode
6063
losses = wrapper(x, targets)
@@ -108,14 +111,18 @@ def test_yolov10_loss_wrapper():
108111
boxes = torch.tensor([[0.1, 0.1, 0.3, 0.3], [0.5, 0.5, 0.8, 0.8]]) # [x1, y1, x2, y2]
109112
labels = torch.zeros(2, dtype=torch.long) # Use class 0 for testing
110113
targets.append({"boxes": boxes, "labels": labels})"""
111-
targets = torch.tensor([[ 0.0000, 20.0000, 0.7738, 0.3919, 0.4525, 0.7582],
112-
[ 0.0000, 20.0000, 0.2487, 0.4062, 0.4966, 0.5787],
113-
[ 0.0000, 20.0000, 0.5667, 0.2772, 0.0791, 0.2313],
114-
[ 0.0000, 20.0000, 0.1009, 0.1955, 0.2002, 0.0835],
115-
[ 1.0000, 20.0000, 0.7738, 0.3919, 0.4525, 0.7582],
116-
[ 1.0000, 20.0000, 0.2487, 0.4062, 0.4966, 0.5787],
117-
[ 1.0000, 20.0000, 0.5667, 0.2772, 0.0791, 0.2313],
118-
[ 1.0000, 20.0000, 0.1009, 0.1955, 0.2002, 0.0835]])
114+
targets = torch.tensor(
115+
[
116+
[0.0000, 20.0000, 0.7738, 0.3919, 0.4525, 0.7582],
117+
[0.0000, 20.0000, 0.2487, 0.4062, 0.4966, 0.5787],
118+
[0.0000, 20.0000, 0.5667, 0.2772, 0.0791, 0.2313],
119+
[0.0000, 20.0000, 0.1009, 0.1955, 0.2002, 0.0835],
120+
[1.0000, 20.0000, 0.7738, 0.3919, 0.4525, 0.7582],
121+
[1.0000, 20.0000, 0.2487, 0.4062, 0.4966, 0.5787],
122+
[1.0000, 20.0000, 0.5667, 0.2772, 0.0791, 0.2313],
123+
[1.0000, 20.0000, 0.1009, 0.1955, 0.2002, 0.0835],
124+
]
125+
)
119126

120127
# Test training mode
121128
losses = wrapper(x, targets)
@@ -236,7 +243,7 @@ def loss(self, items):
236243
wrapper.train()
237244
# Dummy input and targets
238245
x = torch.zeros((1, 3, 416, 416))
239-
targets = torch.tensor([[ 0.0000, 20.0000, 0.7738, 0.3919, 0.4525, 0.7582]])
246+
targets = torch.tensor([[0.0000, 20.0000, 0.7738, 0.3919, 0.4525, 0.7582]])
240247
losses = wrapper(x, targets)
241248
assert set(losses.keys()) == {"loss_total", "loss_box", "loss_cls", "loss_dfl"}
242249
assert losses["loss_total"].item() == 6.0 # sum([1.0, 2.0, 3.0])
@@ -281,7 +288,7 @@ def loss(self, items):
281288
wrapper = PyTorchYoloLossWrapper(test_model, name="yolov8n")
282289
wrapper.train()
283290
x = torch.zeros((1, 3, 416, 416))
284-
targets = torch.tensor([[ 0.0000, 20.0000, 0.7738, 0.3919, 0.4525, 0.7582]])
291+
targets = torch.tensor([[0.0000, 20.0000, 0.7738, 0.3919, 0.4525, 0.7582]])
285292
losses = wrapper(x, targets)
286293
assert set(losses.keys()) == {"loss_total", "loss_box", "loss_cls", "loss_dfl"}
287294
assert losses["loss_total"].item() == 6.0
@@ -456,7 +463,7 @@ def loss(self, items):
456463
for batch_size in batch_sizes:
457464
for box_count in box_counts:
458465
x = torch.zeros((batch_size, 3, 416, 416))
459-
targets = torch.tensor([[ 0.0000, 20.0000, 0.7738, 0.3919, 0.4525, 0.7582]]*batch_size)
466+
targets = torch.tensor([[0.0000, 20.0000, 0.7738, 0.3919, 0.4525, 0.7582]] * batch_size)
460467
losses = wrapper(x, targets)
461468

462469
# Verify loss structure

0 commit comments

Comments
 (0)