Skip to content

Commit f1c7843

Browse files
committed
Yolo 26 with nms required, argument for cli version passing in tests
1 parent 486b83d commit f1c7843

File tree

7 files changed

+33
-10
lines changed

7 files changed

+33
-10
lines changed

tests/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
{"name": "yolo26x", "version": "v26"},
7777
{"name": "yolo26n-seg", "version": "v26"},
7878
{"name": "yolo26n-pose", "version": "v26"},
79+
{"name": "yolo26n", "version": "v26_nms", "cli_version": "yolov26_nms"},
7980
{"name": "yolov8n-cls", "version": "v8"},
8081
{"name": "yolov8n-seg", "version": "v8"},
8182
{"name": "yolov8n-pose", "version": "v8"},

tests/test_end2end.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
@pytest.mark.parametrize(
1616
"model",
1717
TEST_MODELS,
18-
ids=[model["name"] for model in TEST_MODELS],
18+
ids=[model.get("cli_version", model["name"]) if model.get("cli_version") else model["name"] for model in TEST_MODELS],
1919
)
2020
def test_cli_conversion(model: dict, test_config: dict, subtests):
2121
"""Tests the whole CLI conversion flow with no extra params specified."""
@@ -50,6 +50,8 @@ def test_cli_conversion(model: dict, test_config: dict, subtests):
5050
pytest.skip("Weights not present and `download_weights` not set")
5151

5252
command = ["tools", model_path]
53+
if model.get("cli_version"):
54+
command += ["--version", model.get("cli_version")]
5355
if model.get("size"): # edge case when stride=64 is needed
5456
command += ["--imgsz", model.get("size")]
5557

tools/main.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
YOLOV11_CONVERSION,
2929
YOLOV12_CONVERSION,
3030
YOLOV26_CONVERSION,
31+
YOLOV26_NMS_CONVERSION,
3132
detect_version,
3233
)
3334

@@ -50,6 +51,7 @@
5051
YOLOV11_CONVERSION,
5152
YOLOV12_CONVERSION,
5253
YOLOV26_CONVERSION,
54+
YOLOV26_NMS_CONVERSION,
5355
]
5456

5557

@@ -176,6 +178,7 @@ def convert(
176178
YOLOV9_CONVERSION,
177179
YOLOV11_CONVERSION,
178180
YOLOV12_CONVERSION,
181+
YOLOV26_NMS_CONVERSION,
179182
]:
180183
from tools.yolo.yolov8_exporter import YoloV8Exporter
181184

tools/modules/heads.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -366,11 +366,17 @@ def __init__(self, old_detect, use_rvc2: bool):
366366

367367
self.use_rvc2 = use_rvc2
368368

369-
self.proj_conv = nn.Conv2d(old_detect.dfl.c1, 1, 1, bias=False).requires_grad_(
370-
False
371-
)
372-
x = torch.arange(old_detect.dfl.c1, dtype=torch.float)
373-
self.proj_conv.weight.data[:] = nn.Parameter(x.view(1, old_detect.dfl.c1, 1, 1))
369+
# yolo26: dfl will be nn.Identity(), we set proj_conv = None and skip the DFL block in forward
370+
if hasattr(old_detect.dfl, "c1"):
371+
self.proj_conv = nn.Conv2d(
372+
old_detect.dfl.c1, 1, 1, bias=False
373+
).requires_grad_(False)
374+
x = torch.arange(old_detect.dfl.c1, dtype=torch.float)
375+
self.proj_conv.weight.data[:] = nn.Parameter(
376+
x.view(1, old_detect.dfl.c1, 1, 1)
377+
)
378+
else:
379+
self.proj_conv = None
374380

375381
def forward(self, x):
376382
bs = x[0].shape[0] # batch size
@@ -382,9 +388,10 @@ def forward(self, x):
382388

383389
# ------------------------------
384390
# DFL PART
385-
box = box.view(bs, 4, self.reg_max, h * w).permute(0, 2, 1, 3)
386-
box = self.proj_conv(F.softmax(box, dim=1))[:, 0]
387-
box = box.reshape([bs, 4, h, w])
391+
if self.proj_conv is not None:
392+
box = box.view(bs, 4, self.reg_max, h * w).permute(0, 2, 1, 3)
393+
box = self.proj_conv(F.softmax(box, dim=1))[:, 0]
394+
box = box.reshape([bs, 4, h, w])
388395
# ------------------------------
389396

390397
cls = self.cv3[i](x[i])

tools/version_detection/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
YOLOV11_CONVERSION,
1414
YOLOV12_CONVERSION,
1515
YOLOV26_CONVERSION,
16+
YOLOV26_NMS_CONVERSION,
1617
detect_version,
1718
)
1819

@@ -30,6 +31,7 @@
3031
"YOLOV11_CONVERSION",
3132
"YOLOV12_CONVERSION",
3233
"YOLOV26_CONVERSION",
34+
"YOLOV26_NMS_CONVERSION",
3335
"GOLD_YOLO_CONVERSION",
3436
"UNRECOGNIZED",
3537
]

tools/version_detection/version_detection.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
YOLOV11_CONVERSION = "yolov11"
1919
YOLOV12_CONVERSION = "yolov12"
2020
YOLOV26_CONVERSION = "yolov26"
21+
YOLOV26_NMS_CONVERSION = "yolov26_nms"
2122
GOLD_YOLO_CONVERSION = "goldyolo"
2223
UNRECOGNIZED = "none"
2324

tools/yolo/yolov8_exporter.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,16 @@ def __init__(
119119
def load_model(self):
120120
# load the model
121121
model, _ = load_checkpoint(
122-
self.model_path, device="cpu", inplace=True, fuse=True
122+
self.model_path, device="cpu", inplace=True, fuse=False
123123
)
124124

125+
# for yolo26 end2end has to be disabled before fusing
126+
# otherwise cv2/cv3 are removed in the fuse process
127+
head = model.model[-1]
128+
if getattr(head, "end2end", False):
129+
head.end2end = False
130+
model.fuse()
131+
125132
self.mode = -1
126133
if isinstance(model.model[-1], (Segment)) or isinstance(
127134
model.model[-1], (YOLOESegment)

0 commit comments

Comments
 (0)