Skip to content

Commit 0fc9c0d

Browse files
author
pfeatherstone
committed
- initialize batchnorm just like everyone else does. so no need for init_batchnorms anymore
- C3k2() can take the attn=True argument which uses repeated Bottleneck + Attention blocks - SPPF can now shortcut for Yolo26 and can configure both activations - ABlock and PSABlock were pretty much the same so combined - Yolo26 WIP
1 parent 446c1c3 commit 0fc9c0d

File tree

1 file changed

+49
-49
lines changed

1 file changed

+49
-49
lines changed

src/models.py

Lines changed: 49 additions & 49 deletions
Original file line number · Diff line number · Diff line change
@@ -97,11 +97,6 @@ def copy_params(n1: nn.Module, n2: nn.Module):
9797
m1.running_var.data.copy_(m2.running_var.data)
9898
m1.eps = m2.eps
9999
m1.momentum = m2.momentum
100-
101-
def init_batchnorms(net: nn.Module):
102-
for m in batchnorms(net):
103-
m.eps = 1e-3
104-
m.momentum = 0.03
105100

106101
def count_parameters(net: torch.nn.Module, include_stats=True):
107102
return sum(p.numel() for p in net.parameters()) + (sum(m.running_mean.numel() + m.running_var.numel() for m in batchnorms(net)) if include_stats else 0)
@@ -125,7 +120,7 @@ def forward(self, x):
125120
class Conv(nn.Sequential):
126121
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=nn.SiLU(True)):
127122
super().__init__(nn.Conv2d(c1, c2, k, s, default(p,k//2), groups=g, bias=False),
128-
nn.BatchNorm2d(c2),
123+
nn.BatchNorm2d(c2, eps=1e-3, momentum=0.03),
129124
act)
130125

131126
def Con5(c1, c2=None, spp=False, act=actV3):
@@ -220,9 +215,11 @@ def forward(self, x):
220215
b = self.m(self.cv2(x))
221216
return self.cv3(torch.cat((b, a), 1))
222217

223-
def C3k2(c1, c2, n=1, shortcut=True, e=0.5, c3k=False):
218+
def C3k2(c1, c2, n=1, shortcut=True, e=0.5, c3k=False, attn=False):
224219
net = C2f(c1, c2, n, shortcut, e)
225-
blk = C3(net.c_, net.c_, k=(3,3), shortcut=shortcut, n=2) if c3k else Bottleneck(net.c_, k=(3,3), shortcut=shortcut)
220+
if attn: blk = nn.Sequential(Bottleneck(net.c_, k=(3,3), shortcut=shortcut), PSABlock(net.c_, num_heads=net.c_//64, attn_ratio=0.5))
221+
elif c3k: blk = C3(net.c_, net.c_, k=(3,3), shortcut=shortcut, n=2)
222+
else: blk = Bottleneck(net.c_, k=(3,3), shortcut=shortcut)
226223
net.m = nn.ModuleList(deepcopy(blk) for _ in range(n))
227224
return net
228225

@@ -282,20 +279,20 @@ def forward(self, x):
282279
return x
283280

284281
class SPPF(nn.Module):
285-
def __init__(self, c1, c2, act=nn.SiLU(True), shortcut=False): # equivalent to SPP(k=(5, 9, 13))
282+
def __init__(self, c1, c2, acts=[nn.Identity(), nn.SiLU(True)], shortcut=False): # equivalent to SPP(k=(5, 9, 13))
286283
super().__init__()
287284
c_ = c1 // 2 # hidden channels
288-
self.add = shortcut
289-
self.cv1 = Conv(c1, c_, 1, 1, act=act)
290-
self.cv2 = Conv(c_*4, c2, 1, 1, act=act)
285+
self.add = shortcut and c1==c2
286+
self.cv1 = Conv(c1, c_, 1, 1, act=acts[0])
287+
self.cv2 = Conv(c_*4, c2, 1, 1, act=acts[1])
291288
self.m = nn.MaxPool2d(kernel_size=5, stride=1, padding=2)
292-
def forward(self, x):
293-
x = self.cv1(x)
289+
def forward(self, input):
290+
x = self.cv1(input)
294291
y1 = self.m(x)
295292
y2 = self.m(y1)
296293
y3 = self.m(y2)
297294
y = self.cv2(torch.cat((x, y1, y2, y3), 1))
298-
return x+y if self.add else y
295+
return input+y if self.add else y
299296

300297
class SPPCSPC(nn.Module):
301298
def __init__(self, c1, c2, e=0.5, act=nn.SiLU(True)):
@@ -341,35 +338,31 @@ def forward(self, x):
341338
x = self.proj(x + self.pe(v))
342339
return x
343340

344-
def PSABlock(c, num_heads=4, attn_ratio=4):
345-
return nn.Sequential(Residual(Attention(c, num_heads=num_heads, attn_ratio=attn_ratio)),
346-
Residual(nn.Sequential(Conv(c, c*2, 1), Conv(c*2, c, 1, act=nn.Identity()))))
341+
def PSABlock(c, num_heads=4, attn_ratio=0.5, area=None, e=2):
342+
c_ = int(c*e)
343+
return nn.Sequential(Residual(Attention(c, num_heads=num_heads, attn_ratio=attn_ratio, area=area)),
344+
Residual(nn.Sequential(Conv(c, c_, 1), Conv(c_, c, 1, act=nn.Identity()))))
347345

348346
class PSA(nn.Module):
349347
def __init__(self, c, e=0.5, n=1):
350348
super().__init__()
351349
c_ = int(c * e) # hidden channels
352350
self.cv1 = Conv(c, 2*c_, 1)
353351
self.cv2 = Conv(2*c_, c, 1)
354-
self.net = Repeat(PSABlock(c_, num_heads=c_ // 64, attn_ratio=0.5), n)
352+
self.net = Repeat(PSABlock(c_, num_heads=c_//64, attn_ratio=0.5), n)
355353

356354
def forward(self, x):
357355
a, b = self.cv1(x).chunk(2, 1)
358356
return self.cv2(torch.cat((a, self.net(b)), 1))
359357

360-
def ABlock(c, num_heads, area=1, mlp_ratio=1.2):
361-
c_ = int(c*mlp_ratio)
362-
return nn.Sequential(Residual(Attention(c, num_heads=num_heads, attn_ratio=1, area=area)),
363-
Residual(nn.Sequential(Conv(c, c_, 1), Conv(c_, c, 1, act=nn.Identity()))))
364-
365358
class A2C2f(nn.Module):
366359
def __init__(self, c1, c2, n=1, shortcut=True, e=0.5, a2=True, residual=False, area=1, mlp_ratio=2.0):
367360
super().__init__()
368361
self.c_ = int(c2 * e) # hidden channels
369362
self.cv1 = Conv(c1, self.c_, 1)
370363
self.cv2 = Conv((1 + n) * self.c_, c2, 1)
371364
self.g = nn.Parameter(0.01 * torch.ones(c2)) if (a2 and residual) else None
372-
self.m = nn.ModuleList(Repeat(ABlock(self.c_, self.c_//32, area, mlp_ratio), 2) if a2 else
365+
self.m = nn.ModuleList(Repeat(PSABlock(self.c_, num_heads=self.c_//32, attn_ratio=1, area=area, e=mlp_ratio), 2) if a2 else
373366
C3(self.c_, self.c_, k=(3,3), shortcut=shortcut, n=2) for _ in range(n))
374367

375368
def forward(self, x):
@@ -615,14 +608,10 @@ def forward(self, x):
615608
x8 = self.b8(self.b7(x6)) # 8 P5/32
616609
return x4, x6, x8
617610

618-
class BackboneV26(BackboneV11):
619-
def __init__(self, w, r, d, variant):
620-
super().__init__(w, r, d, variant, sppf_shortcut=True)
621-
622611
class EfficientRep(nn.Module):
623612
def __init__(self, w, d, cspsppf=False):
624613
super().__init__()
625-
sppf = partial(SPPCSPC, e=0.25) if cspsppf else SPPF
614+
sppf = partial(SPPCSPC, e=0.25, act=nn.ReLU(True)) if cspsppf else partial(SPPF, acts=[nn.ReLU(True), nn.ReLU(True)])
626615
self.b0 = RepConv( c1=3, c2=int(64*w), s=2, act=F.relu)
627616
self.b1 = RepConv( c1=int(64*w), c2=int(128*w), s=2, act=F.relu)
628617
self.b2 = RepBlock(c1=int(128*w), c2=int(128*w), n=round(6*d))
@@ -632,7 +621,7 @@ def __init__(self, w, d, cspsppf=False):
632621
self.b6 = RepBlock(c1=int(512*w), c2=int(512*w), n=round(18*d))
633622
self.b7 = RepConv( c1=int(512*w), c2=int(1024*w),s=2, act=F.relu)
634623
self.b8 = RepBlock(c1=int(1024*w),c2=int(1024*w),n=round(6*d))
635-
self.b9 = sppf( c1=int(1024*w),c2=int(1024*w), act=nn.ReLU(True))
624+
self.b9 = sppf(c1=int(1024*w),c2=int(1024*w))
636625

637626
def forward(self, x):
638627
x4 = self.b2(self.b1(self.b0(x))) # p2/4
@@ -644,7 +633,7 @@ def forward(self, x):
644633
class CSPBepBackbone(nn.Module):
645634
def __init__(self, w, d, csp_e=1/2, cspsppf=False):
646635
super().__init__()
647-
sppf = partial(SPPCSPC, e=0.25) if cspsppf else SPPF
636+
sppf = partial(SPPCSPC, e=0.25, act=nn.ReLU(True)) if cspsppf else partial(SPPF, acts=[nn.SiLU(), nn.SiLU()])
648637
self.b0 = RepConv(c1=3, c2=int(64*w), s=2, act=F.relu)
649638
self.b1 = RepConv(c1=int(64*w), c2=int(128*w), s=2, act=F.relu)
650639
self.b2 = BepC3( c1=int(128*w), c2=int(128*w), e=csp_e, n=round(6*d))
@@ -654,7 +643,7 @@ def __init__(self, w, d, csp_e=1/2, cspsppf=False):
654643
self.b6 = BepC3( c1=int(512*w), c2=int(512*w), e=csp_e, n=round(18*d))
655644
self.b7 = RepConv(c1=int(512*w), c2=int(1024*w),s=2, act=F.relu)
656645
self.b8 = BepC3( c1=int(1024*w),c2=int(1024*w),e=csp_e, n=round(6*d))
657-
self.b9 = sppf( c1=int(1024*w),c2=int(1024*w), act=nn.ReLU(True))
646+
self.b9 = sppf( c1=int(1024*w),c2=int(1024*w))
658647

659648
def forward(self, x):
660649
x4 = self.b2(self.b1(self.b0(x))) # p2/4
@@ -790,16 +779,17 @@ def forward(self, x4, x6, x10):
790779
return [x16, x19, x22]
791780

792781
class HeadV11(nn.Module):
793-
def __init__(self, w, r, d, variant):
782+
def __init__(self, w, r, d, variant, is26=False):
794783
super().__init__()
795-
c3k = variant in "mlx"
784+
c3k = True if is26 else variant in "mlx"
785+
n = 1 if is26 else 2
796786
self.up = nn.Upsample(scale_factor=2)
797-
self.n1 = C3k2(c1=int(512*w*(1+r)), c2=int(512*w), n=round(2*d), c3k=c3k)
798-
self.n2 = C3k2(c1=int(512*w*2), c2=int(256*w), n=round(2*d), c3k=c3k)
799-
self.n3 = Conv(c1=int(256*w), c2=int(256*w), k=3, s=2)
800-
self.n4 = C3k2(c1=int(768*w), c2=int(512*w), n=round(2*d), c3k=c3k)
801-
self.n5 = Conv(c1=int(512*w), c2=int(512*w), k=3, s=2)
802-
self.n6 = C3k2(c1=int(512*w*(1+r)), c2=int(512*w*r), n=round(2*d), c3k=True)
787+
self.n1 = C3k2(c1=int(512*w*(1+r)), c2=int(512*w), n=round(2*d), c3k=c3k)
788+
self.n2 = C3k2(c1=int(512*w*2), c2=int(256*w), n=round(2*d), c3k=c3k)
789+
self.n3 = Conv(c1=int(256*w), c2=int(256*w), k=3, s=2)
790+
self.n4 = C3k2(c1=int(768*w), c2=int(512*w), n=round(2*d), c3k=c3k)
791+
self.n5 = Conv(c1=int(512*w), c2=int(512*w), k=3, s=2)
792+
self.n6 = C3k2(c1=int(512*w*(1+r)), c2=int(512*w*r), n=max(1,round(n*d)), c3k=True, attn=is26)
803793

804794
def forward(self, x4, x6, x10):
805795
x13 = self.n1(torch.cat([self.up(x10),x6], 1)) # 13
@@ -993,20 +983,22 @@ def forward(self, xs, targets=None):
993983
return pred if not exists(targets) else (pred, {'iou': loss_iou, 'cls': loss_cls, 'obj': loss_obj})
994984

995985
class Detect(nn.Module):
996-
def __init__(self, nc=80, ch=(), v11=False):
986+
def __init__(self, nc=80, ch=(), dfl=True, separable=False, end2end=False):
997987
super().__init__()
998988
def spconv(c1, c2, k): return nn.Sequential(Conv(c1,c1,k,g=c1),Conv(c1,c2,1))
999-
conv = spconv if v11 else Conv
989+
conv = spconv if separable else Conv
1000990
self.nc = nc # number of classes
1001-
self.reg_max = 16 # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
991+
self.reg_max = 16 if dfl else 1 # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
1002992
self.no = nc + self.reg_max * 4 # number of outputs per anchor
1003993
self.strides = [8, 16, 32] # strides computed during build
1004994
self.c2 = max((16, ch[0] // 4, self.reg_max * 4))
1005995
self.c3 = max(ch[0], min(self.nc, 100)) # channels
1006996
self.cv2 = nn.ModuleList(nn.Sequential(Conv(x, self.c2, 3), Conv(self.c2, self.c2, 3), nn.Conv2d(self.c2, 4 * self.reg_max, 1)) for x in ch)
1007997
self.cv3 = nn.ModuleList(nn.Sequential(conv(x, self.c3, 3), conv(self.c3, self.c3, 3), nn.Conv2d(self.c3, self.nc, 1)) for x in ch)
1008-
self.r = nn.Parameter(torch.arange(self.reg_max).float(), requires_grad=False)
1009-
998+
self.r = nn.Parameter(torch.arange(self.reg_max).float(), requires_grad=False) if dfl else None
999+
if end2end: self.one2one_cv2 = deepcopy(self.cv2)
1000+
if end2end: self.one2one_cv3 = deepcopy(self.cv3)
1001+
10101002
def forward_private(self, xs, cv2, cv3, targets=None):
10111003
sxy, ps, strides = make_anchors(xs, self.strides)
10121004
feats = [rearrange(torch.cat((c1(x), c2(x)), 1), 'b f h w -> b (h w) f') for x,c1,c2 in zip(xs, cv2, cv3)]
@@ -1179,15 +1171,23 @@ def __init__(self, variant, num_classes):
11791171
d, w, r = get_variant_multiplesV11(variant)
11801172
super().__init__(BackboneV11(w, r, d, variant),
11811173
HeadV11(w, r, d, variant),
1182-
Detect(num_classes, ch=(int(256*w), int(512*w), int(512*w*r)), v11=True),
1174+
Detect(num_classes, ch=(int(256*w), int(512*w), int(512*w*r)), separable=True),
11831175
variant)
11841176

1177+
class Yolov26(YoloBase):
1178+
def __init__(self, variant, num_classes):
1179+
d, w, r = get_variant_multiplesV26(variant)
1180+
super().__init__(BackboneV11(w, r, d, variant, sppf_shortcut=True),
1181+
HeadV11(w, r, d, variant, is26=True),
1182+
Detect(num_classes, ch=(int(256*w), int(512*w), int(512*w*r)), separable=True, dfl=False, end2end=True),
1183+
variant)
1184+
11851185
class Yolov12(YoloBase):
11861186
def __init__(self, variant, num_classes):
11871187
d, w, r = get_variant_multiplesV12(variant)
11881188
super().__init__(BackboneV12(w, r, d, variant),
11891189
HeadV12(w, r, d),
1190-
Detect(num_classes, ch=(int(256*w), int(512*w), int(512*w*r)), v11=True),
1190+
Detect(num_classes, ch=(int(256*w), int(512*w), int(512*w*r)), separable=True),
11911191
variant)
11921192

11931193
class Yolov6(YoloBase):
@@ -1230,4 +1230,4 @@ def barlow_loss(z1, z2, lambda_coeff):
12301230
mask = torch.eye(cross.shape[0], dtype=torch.bool, device=cross.device)
12311231
on_diag = (cross[mask]-1).pow(2).sum()
12321232
off_diag = cross[~mask].pow(2).sum()
1233-
return (on_diag + lambda_coeff * off_diag, cross)
1233+
return (on_diag + lambda_coeff * off_diag, cross)

0 commit comments

Comments (0)