Skip to content

Commit 231adc7

Browse files
author
pfeatherstone
committed
Yolo26 - inference done
1 parent 0fc9c0d commit 231adc7

File tree

2 files changed

+17
-12
lines changed

2 files changed

+17
-12
lines changed

src/models.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ def forward(self, x):
279279
return x
280280

281281
class SPPF(nn.Module):
282-
def __init__(self, c1, c2, acts=[nn.Identity(), nn.SiLU(True)], shortcut=False): # equivalent to SPP(k=(5, 9, 13))
282+
def __init__(self, c1, c2, acts=[nn.SiLU(True), nn.SiLU(True)], shortcut=False): # equivalent to SPP(k=(5, 9, 13))
283283
super().__init__()
284284
c_ = c1 // 2 # hidden channels
285285
self.add = shortcut and c1==c2
@@ -567,7 +567,7 @@ def forward(self, x):
567567
return x4, x6, x10
568568

569569
class BackboneV11(nn.Module):
570-
def __init__(self, w, r, d, variant, sppf_shortcut=False):
570+
def __init__(self, w, r, d, variant, sppf_shortcut=False, sppf_acts=[nn.SiLU(True), nn.SiLU(True)]):
571571
super().__init__()
572572
c3k = variant in "mlx"
573573
self.b0 = Conv(c1=3, c2=int(64*w), k=3, s=2)
@@ -579,7 +579,7 @@ def __init__(self, w, r, d, variant, sppf_shortcut=False):
579579
self.b6 = C3k2(c1=int(512*w), c2=int(512*w), n=round(2*d), e=0.50, c3k=True)
580580
self.b7 = Conv(c1=int(512*w), c2=int(512*w*r), k=3, s=2)
581581
self.b8 = C3k2(c1=int(512*w*r), c2=int(512*w*r), n=round(2*d), e=0.50, c3k=True)
582-
self.b9 = SPPF(c1=int(512*w*r), c2=int(512*w*r), shortcut=sppf_shortcut)
582+
self.b9 = SPPF(c1=int(512*w*r), c2=int(512*w*r), shortcut=sppf_shortcut, acts=sppf_acts)
583583
self.b10 = PSA(int(512*w*r), n=round(2*d))
584584

585585
def forward(self, x):
@@ -988,6 +988,7 @@ def __init__(self, nc=80, ch=(), dfl=True, separable=False, end2end=False):
988988
def spconv(c1, c2, k): return nn.Sequential(Conv(c1,c1,k,g=c1),Conv(c1,c2,1))
989989
conv = spconv if separable else Conv
990990
self.nc = nc # number of classes
991+
self.dfl = dfl
991992
self.reg_max = 16 if dfl else 1 # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
992993
self.no = nc + self.reg_max * 4 # number of outputs per anchor
993994
self.strides = [8, 16, 32] # strides computed during build
@@ -1004,7 +1005,7 @@ def forward_private(self, xs, cv2, cv3, targets=None):
10041005
feats = [rearrange(torch.cat((c1(x), c2(x)), 1), 'b f h w -> b (h w) f') for x,c1,c2 in zip(xs, cv2, cv3)]
10051006
dist, cls = torch.cat(feats, 1).split((4 * self.reg_max, self.nc), -1)
10061007
dist = rearrange(dist, 'b n (k r) -> b n k r', k=4)
1007-
ltrb = torch.einsum('bnkr, r -> bnk', dist.softmax(-1), self.r)
1008+
ltrb = torch.einsum('bnkr, r -> bnk', dist.softmax(-1), self.r) if self.dfl else dist.squeeze(-1)
10081009
box = dist2box(ltrb, sxy, strides)
10091010
pred = torch.cat((box, cls.sigmoid()), -1)
10101011

@@ -1177,7 +1178,7 @@ def __init__(self, variant, num_classes):
11771178
class Yolov26(YoloBase):
11781179
def __init__(self, variant, num_classes):
11791180
d, w, r = get_variant_multiplesV26(variant)
1180-
super().__init__(BackboneV11(w, r, d, variant, sppf_shortcut=True),
1181+
super().__init__(BackboneV11(w, r, d, variant, sppf_shortcut=True, sppf_acts=[nn.Identity(), nn.SiLU(True)]),
11811182
HeadV11(w, r, d, variant, is26=True),
11821183
Detect(num_classes, ch=(int(256*w), int(512*w), int(512*w*r)), separable=True, dfl=False, end2end=True),
11831184
variant)

src/test.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,9 @@ def load_from_ultralytics(net: Union[Yolov5, Yolov8, Yolov10, Yolov11]):
141141
for module in net2.modules():
142142
if isinstance(module, AAttn):
143143
fuse_bias_v12(module.pe.conv, module.pe.bn)
144+
elif isinstance(net, Yolov26):
145+
net2 = YOLO('yolo26{}.pt'.format(net.v)).model.eval()
146+
l0,l1 = 11,23
144147

145148
assert (nP1 := count_parameters(net)) == (nP2 := count_parameters(net2)), f'wrong number of parameters net {nP1} vs ultralytics {nP2}'
146149
copy_params(net.net, net2.model[0:l0])
@@ -187,8 +190,6 @@ def params(n):
187190
assert p1.shape == p2.shape, f"bad shape: {k} {p2.shape} {p1.shape}"
188191
p1.data.copy_(p2.data)
189192

190-
init_batchnorms(net)
191-
192193

193194
def load_from_yolov7_official(net: Yolov7, weights_pt: str):
194195
def params1():
@@ -206,8 +207,6 @@ def params2():
206207
for p1, p2 in zip(params1(), params2(), strict=True):
207208
p1.data.copy_(p2.data)
208209

209-
init_batchnorms(net)
210-
211210
# Handle special case in SPPCSPC where Yolov6 and Yolov7 disagree on the order of the final torch.cat()
212211
for module in net.modules():
213212
if isinstance(module, SPPCSPC):
@@ -230,6 +229,7 @@ def get_model(model: str, variant: str = ''):
230229
case 'yolov10': net = Yolov10(variant, 80).eval()
231230
case 'yolov11': net = Yolov11(variant, 80).eval()
232231
case 'yolov12': net = Yolov12(variant, 80).eval()
232+
case 'yolov26': net = Yolov26(variant, 80).eval()
233233

234234
print(f"{model}{variant} has {count_parameters(net)} parameters")
235235

@@ -241,7 +241,7 @@ def get_model(model: str, variant: str = ''):
241241
download_if_not_exist(model, filepath)
242242
load_from_darknet(net, filepath)
243243

244-
if model in ['yolov5', 'yolov8', 'yolov10', 'yolov11', 'yolov12']:
244+
if model in ['yolov5', 'yolov8', 'yolov10', 'yolov11', 'yolov12', 'yolov26']:
245245
load_from_ultralytics(net)
246246
has_obj = False
247247

@@ -293,7 +293,6 @@ def export(model: str, variant: str = '', onnx_path:str = '/tmp/model.onnx'):
293293
torch.testing.assert_close(preds1[...,4:], torch.from_numpy(preds2[...,4:])) # scores
294294
print(bcolors.OKGREEN, "Checking with onnxruntime... Done", bcolors.ENDC)
295295

296-
297296
test('yolov3')
298297
test('yolov3-spp')
299298
test('yolov3-tiny')
@@ -329,6 +328,11 @@ def export(model: str, variant: str = '', onnx_path:str = '/tmp/model.onnx'):
329328
test('yolov12', 'm')
330329
test('yolov12', 'l')
331330
test('yolov12', 'x')
331+
test('yolov26', 'n')
332+
test('yolov26', 's')
333+
test('yolov26', 'm')
334+
test('yolov26', 'l')
335+
test('yolov26', 'x')
332336

333337
# export('yolov3')
334338
# export('yolov3-spp')
@@ -364,4 +368,4 @@ def export(model: str, variant: str = '', onnx_path:str = '/tmp/model.onnx'):
364368
# export('yolov12', 's')
365369
# export('yolov12', 'm')
366370
# export('yolov12', 'l')
367-
# export('yolov12', 'x')
371+
# export('yolov12', 'x')

0 commit comments

Comments
 (0)