Skip to content

Commit 4b563b1

Browse files
author
pfeatherstone
committed
- The ONLY difference between BackboneV11 and BackboneV26 is the shortcut in SPPF
1 parent 6571c6f commit 4b563b1

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

src/models.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,9 @@ def get_variant_multiplesV11(variant: str):
7979
def get_variant_multiplesV12(variant: str):
8080
return get_variant_multiplesV11(variant)
8181

82+
def get_variant_multiplesV26(variant: str):
83+
return get_variant_multiplesV11(variant)
84+
8285
def batchnorms(n: nn.Module):
8386
for m in n.modules():
8487
if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
@@ -279,19 +282,20 @@ def forward(self, x):
279282
return x
280283

281284
class SPPF(nn.Module):
282-
def __init__(self, c1, c2, act=nn.SiLU(True)): # equivalent to SPP(k=(5, 9, 13))
285+
def __init__(self, c1, c2, act=nn.SiLU(True), shortcut=False): # equivalent to SPP(k=(5, 9, 13))
283286
super().__init__()
284287
c_ = c1 // 2 # hidden channels
285-
conv = partial(Conv, act=act)
286-
self.cv1 = conv(c1, c_, 1, 1)
287-
self.cv2 = conv(c_ * 4, c2, 1, 1)
288+
self.add = shortcut
289+
self.cv1 = Conv(c1, c_, 1, 1, act=act)
290+
self.cv2 = Conv(c_*4, c2, 1, 1, act=act)
288291
self.m = nn.MaxPool2d(kernel_size=5, stride=1, padding=2)
289292
def forward(self, x):
290293
x = self.cv1(x)
291294
y1 = self.m(x)
292295
y2 = self.m(y1)
293296
y3 = self.m(y2)
294-
return self.cv2(torch.cat((x, y1, y2, y3), 1))
297+
y = self.cv2(torch.cat((x, y1, y2, y3), 1))
298+
return x+y if self.add else y
295299

296300
class SPPCSPC(nn.Module):
297301
def __init__(self, c1, c2, e=0.5, act=nn.SiLU(True)):
@@ -570,7 +574,7 @@ def forward(self, x):
570574
return x4, x6, x10
571575

572576
class BackboneV11(nn.Module):
573-
def __init__(self, w, r, d, variant):
577+
def __init__(self, w, r, d, variant, sppf_shortcut=False):
574578
super().__init__()
575579
c3k = variant in "mlx"
576580
self.b0 = Conv(c1=3, c2=int(64*w), k=3, s=2)
@@ -582,7 +586,7 @@ def __init__(self, w, r, d, variant):
582586
self.b6 = C3k2(c1=int(512*w), c2=int(512*w), n=round(2*d), e=0.50, c3k=True)
583587
self.b7 = Conv(c1=int(512*w), c2=int(512*w*r), k=3, s=2)
584588
self.b8 = C3k2(c1=int(512*w*r), c2=int(512*w*r), n=round(2*d), e=0.50, c3k=True)
585-
self.b9 = SPPF(c1=int(512*w*r), c2=int(512*w*r))
589+
self.b9 = SPPF(c1=int(512*w*r), c2=int(512*w*r), shortcut=sppf_shortcut)
586590
self.b10 = PSA(int(512*w*r), n=round(2*d))
587591

588592
def forward(self, x):
@@ -610,7 +614,11 @@ def forward(self, x):
610614
x6 = self.b6(self.b5(x4)) # 6 P4/16
611615
x8 = self.b8(self.b7(x6)) # 8 P5/32
612616
return x4, x6, x8
613-
617+
618+
class BackboneV26(BackboneV11):
619+
def __init__(self, w, r, d, variant):
620+
super().__init__(w, r, d, variant, sppf_shortcut=True)
621+
614622
class EfficientRep(nn.Module):
615623
def __init__(self, w, d, cspsppf=False):
616624
super().__init__()

0 commit comments

Comments
 (0)