Skip to content

Commit 7b69489

Browse files
author
pfeatherstone
committed
Don't need DetectV10. Need to implement end2end training in Detect
1 parent 8f35f95 commit 7b69489

File tree

1 file changed

+4
-21
lines changed

1 file changed

+4
-21
lines changed

src/models.py

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1033,23 +1033,6 @@ def forward_private(self, xs, cv2, cv3, targets=None):
10331033
def forward(self, xs, targets=None):
10341034
return self.forward_private(xs, self.cv2, self.cv3, targets)
10351035

1036-
class DetectV10(Detect):
1037-
def __init__(self, nc=80, ch=()):
1038-
super().__init__(nc, ch)
1039-
self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, x, 3, g=x),
1040-
Conv(x, self.c3, 1),
1041-
Conv(self.c3, self.c3, 3, g=self.c3),
1042-
Conv(self.c3, self.c3, 1),
1043-
nn.Conv2d(self.c3, self.nc, 1)) for x in ch)
1044-
1045-
self.one2one_cv2 = deepcopy(self.cv2)
1046-
self.one2one_cv3 = deepcopy(self.cv3)
1047-
self.max_det = 100
1048-
1049-
def forward(self, x, targets=None):
1050-
# TODO: implement all the topk stuff. I think yolov10 doesn't need NMS. But you can you use it in inference mode for now.
1051-
return self.forward_private(x, self.one2one_cv2, self.one2one_cv3)
1052-
10531036
class DetectV6(nn.Module):
10541037
def __init__(self, nc=80, ch=(), use_dfl=False, distill=False):
10551038
super().__init__()
@@ -1148,31 +1131,31 @@ def __init__(self, variant, num_classes):
11481131
d, w, r = get_variant_multiplesV5(variant)
11491132
super().__init__(BackboneV5(w, r, d),
11501133
HeadV5(w, r, d),
1151-
Detect(num_classes, ch=(int(256*w), int(512*w), int(512*w*2))),
1134+
Detect(num_classes, ch=(int(256*w), int(512*w), int(512*w*2)), separable=False, dfl=True, end2end=False),
11521135
variant)
11531136

11541137
class Yolov8(YoloBase):
11551138
def __init__(self, variant, num_classes):
11561139
d, w, r = get_variant_multiplesV8(variant)
11571140
super().__init__(BackboneV8(w, r, d),
11581141
HeadV8(w, r, d),
1159-
Detect(num_classes, ch=(int(256*w), int(512*w), int(512*w*r))),
1142+
Detect(num_classes, ch=(int(256*w), int(512*w), int(512*w*r)), separable=False, dfl=True, end2end=False),
11601143
variant)
11611144

11621145
class Yolov10(YoloBase):
11631146
def __init__(self, variant, num_classes):
11641147
d, w, r = get_variant_multiplesV10(variant)
11651148
super().__init__(BackboneV10(w, r, d, variant),
11661149
HeadV10(w, r, d, variant),
1167-
DetectV10(num_classes, ch=(int(256*w), int(512*w), int(512*w*r))),
1150+
Detect(num_classes, ch=(int(256*w), int(512*w), int(512*w*r)), separable=True, dfl=True, end2end=True),
11681151
variant)
11691152

11701153
class Yolov11(YoloBase):
11711154
def __init__(self, variant, num_classes):
11721155
d, w, r = get_variant_multiplesV11(variant)
11731156
super().__init__(BackboneV11(w, r, d, variant),
11741157
HeadV11(w, r, d, variant),
1175-
Detect(num_classes, ch=(int(256*w), int(512*w), int(512*w*r)), separable=True),
1158+
Detect(num_classes, ch=(int(256*w), int(512*w), int(512*w*r)), separable=True, dfl=True, end2end=False),
11761159
variant)
11771160

11781161
class Yolov26(YoloBase):

0 commit comments

Comments
 (0)