Skip to content

Commit c437e6d

Browse files
committed
Signed Commit: Signed-off-by: Arjun Sachar <[email protected]>
Signed-off-by: [email protected] <[email protected]>
1 parent f89f350 commit c437e6d

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

tests/estimators/object_detection/test_pytorch_yolo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ def __init__(self):
335335
def loss(self, items):
336336
return (torch.tensor([1.0]), [torch.tensor(1.0), torch.tensor(2.0), torch.tensor(3.0)])
337337

338-
dummy_model = DummyModel()
338+
test_model = DummyModel()
339339
# Patch ultralytics import in the wrapper
340340
import sys
341341
import types
@@ -356,5 +356,5 @@ def loss(self, items):
356356
sys.modules["ultralytics.models.yolo.detect"] = ultralytics_mock.models.yolo.detect
357357
sys.modules["ultralytics.utils"] = ultralytics_mock.utils
358358
sys.modules["ultralytics.utils.loss"] = ultralytics_mock.utils.loss
359-
wrapper = PyTorchYoloLossWrapper(dummy_model, name="yolov8n")
359+
wrapper = PyTorchYoloLossWrapper(test_model, name="yolov8n")
360360
assert isinstance(wrapper, PyTorchYoloLossWrapper)

tests/estimators/object_detection/test_pytorch_yolo_loss_wrapper.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ class DummyModel(torch.nn.Module):
1919
def forward(self, x):
2020
return x
2121

22-
dummy_model = DummyModel()
22+
test_model = DummyModel()
2323
yolo = PyTorchYolo(
24-
model=dummy_model,
24+
model=test_model,
2525
input_shape=(3, 416, 416),
2626
optimizer=None,
2727
clip_values=(0, 1),
@@ -68,7 +68,7 @@ def loss(self, items):
6868
# Return (loss, [loss_box, loss_cls, loss_dfl])
6969
return (torch.tensor([1.0, 2.0, 3.0]), [torch.tensor(1.0), torch.tensor(2.0), torch.tensor(3.0)])
7070

71-
dummy_model = DummyModel()
71+
test_model = DummyModel()
7272
# Patch ultralytics import in the wrapper
7373
import sys
7474
import types
@@ -90,7 +90,7 @@ def loss(self, items):
9090
sys.modules["ultralytics.utils"] = ultralytics_mock.utils
9191
sys.modules["ultralytics.utils.loss"] = ultralytics_mock.utils.loss
9292

93-
wrapper = PyTorchYoloLossWrapper(dummy_model, name="yolov8n")
93+
wrapper = PyTorchYoloLossWrapper(test_model, name="yolov8n")
9494
wrapper.train()
9595
# Dummy input and targets
9696
x = torch.zeros((1, 3, 416, 416))
@@ -114,7 +114,7 @@ def loss(self, items):
114114
# Return (loss, [loss_box, loss_cls, loss_dfl])
115115
return (torch.tensor([1.0, 2.0, 3.0]), [torch.tensor(1.0), torch.tensor(2.0), torch.tensor(3.0)])
116116

117-
dummy_model = DummyModel()
117+
test_model = DummyModel()
118118
# Patch ultralytics import in the wrapper
119119
import sys
120120
import types
@@ -136,7 +136,7 @@ def loss(self, items):
136136
sys.modules["ultralytics.utils"] = ultralytics_mock.utils
137137
sys.modules["ultralytics.utils.loss"] = ultralytics_mock.utils.loss
138138

139-
wrapper = PyTorchYoloLossWrapper(dummy_model, name="yolov8n")
139+
wrapper = PyTorchYoloLossWrapper(test_model, name="yolov8n")
140140
wrapper.train()
141141
x = torch.zeros((1, 3, 416, 416))
142142
targets = [{"boxes": torch.zeros((1, 4)), "labels": torch.zeros((1,))}]

0 commit comments

Comments
 (0)