@@ -79,6 +79,9 @@ def get_variant_multiplesV11(variant: str):
7979def get_variant_multiplesV12 (variant : str ):
8080 return get_variant_multiplesV11 (variant )
8181
82+ def get_variant_multiplesV26 (variant : str ):
83+ return get_variant_multiplesV11 (variant )
84+
8285def batchnorms (n : nn .Module ):
8386 for m in n .modules ():
8487 if isinstance (m , (nn .BatchNorm1d , nn .BatchNorm2d )):
@@ -279,19 +282,20 @@ def forward(self, x):
279282 return x
280283
281284class SPPF (nn .Module ):
282- def __init__ (self , c1 , c2 , act = nn .SiLU (True )): # equivalent to SPP(k=(5, 9, 13))
285+ def __init__ (self , c1 , c2 , act = nn .SiLU (True ), shortcut = False ): # equivalent to SPP(k=(5, 9, 13))
283286 super ().__init__ ()
284287 c_ = c1 // 2 # hidden channels
285- conv = partial ( Conv , act = act )
286- self .cv1 = conv (c1 , c_ , 1 , 1 )
287- self .cv2 = conv (c_ * 4 , c2 , 1 , 1 )
288+ self . add = shortcut
289+ self .cv1 = Conv (c1 , c_ , 1 , 1 , act = act )
290+ self .cv2 = Conv (c_ * 4 , c2 , 1 , 1 , act = act )
288291 self .m = nn .MaxPool2d (kernel_size = 5 , stride = 1 , padding = 2 )
289292 def forward (self , x ):
290293 x = self .cv1 (x )
291294 y1 = self .m (x )
292295 y2 = self .m (y1 )
293296 y3 = self .m (y2 )
294- return self .cv2 (torch .cat ((x , y1 , y2 , y3 ), 1 ))
297+ y = self .cv2 (torch .cat ((x , y1 , y2 , y3 ), 1 ))
298+ return x + y if self .add else y
295299
296300class SPPCSPC (nn .Module ):
297301 def __init__ (self , c1 , c2 , e = 0.5 , act = nn .SiLU (True )):
@@ -570,7 +574,7 @@ def forward(self, x):
570574 return x4 , x6 , x10
571575
572576class BackboneV11 (nn .Module ):
573- def __init__ (self , w , r , d , variant ):
577+ def __init__ (self , w , r , d , variant , sppf_shortcut = False ):
574578 super ().__init__ ()
575579 c3k = variant in "mlx"
576580 self .b0 = Conv (c1 = 3 , c2 = int (64 * w ), k = 3 , s = 2 )
@@ -582,7 +586,7 @@ def __init__(self, w, r, d, variant):
582586 self .b6 = C3k2 (c1 = int (512 * w ), c2 = int (512 * w ), n = round (2 * d ), e = 0.50 , c3k = True )
583587 self .b7 = Conv (c1 = int (512 * w ), c2 = int (512 * w * r ), k = 3 , s = 2 )
584588 self .b8 = C3k2 (c1 = int (512 * w * r ), c2 = int (512 * w * r ), n = round (2 * d ), e = 0.50 , c3k = True )
585- self .b9 = SPPF (c1 = int (512 * w * r ), c2 = int (512 * w * r ))
589+ self .b9 = SPPF (c1 = int (512 * w * r ), c2 = int (512 * w * r ), shortcut = sppf_shortcut )
586590 self .b10 = PSA (int (512 * w * r ), n = round (2 * d ))
587591
588592 def forward (self , x ):
@@ -610,7 +614,11 @@ def forward(self, x):
610614 x6 = self .b6 (self .b5 (x4 )) # 6 P4/16
611615 x8 = self .b8 (self .b7 (x6 )) # 8 P5/32
612616 return x4 , x6 , x8
613-
617+
618+ class BackboneV26 (BackboneV11 ):
619+ def __init__ (self , w , r , d , variant ):
620+ super ().__init__ (w , r , d , variant , sppf_shortcut = True )
621+
614622class EfficientRep (nn .Module ):
615623 def __init__ (self , w , d , cspsppf = False ):
616624 super ().__init__ ()
0 commit comments