@@ -279,7 +279,7 @@ def forward(self, x):
279279 return x
280280
281281class SPPF (nn .Module ):
282- def __init__ (self , c1 , c2 , acts = [nn .Identity ( ), nn .SiLU (True )], shortcut = False ): # equivalent to SPP(k=(5, 9, 13))
282+ def __init__ (self , c1 , c2 , acts = [nn .SiLU ( True ), nn .SiLU (True )], shortcut = False ): # equivalent to SPP(k=(5, 9, 13))
283283 super ().__init__ ()
284284 c_ = c1 // 2 # hidden channels
285285 self .add = shortcut and c1 == c2
@@ -567,7 +567,7 @@ def forward(self, x):
567567 return x4 , x6 , x10
568568
569569class BackboneV11 (nn .Module ):
570- def __init__ (self , w , r , d , variant , sppf_shortcut = False ):
570+ def __init__ (self , w , r , d , variant , sppf_shortcut = False , sppf_acts = [ nn . SiLU ( True ), nn . SiLU ( True )] ):
571571 super ().__init__ ()
572572 c3k = variant in "mlx"
573573 self .b0 = Conv (c1 = 3 , c2 = int (64 * w ), k = 3 , s = 2 )
@@ -579,7 +579,7 @@ def __init__(self, w, r, d, variant, sppf_shortcut=False):
579579 self .b6 = C3k2 (c1 = int (512 * w ), c2 = int (512 * w ), n = round (2 * d ), e = 0.50 , c3k = True )
580580 self .b7 = Conv (c1 = int (512 * w ), c2 = int (512 * w * r ), k = 3 , s = 2 )
581581 self .b8 = C3k2 (c1 = int (512 * w * r ), c2 = int (512 * w * r ), n = round (2 * d ), e = 0.50 , c3k = True )
582- self .b9 = SPPF (c1 = int (512 * w * r ), c2 = int (512 * w * r ), shortcut = sppf_shortcut )
582+ self .b9 = SPPF (c1 = int (512 * w * r ), c2 = int (512 * w * r ), shortcut = sppf_shortcut , acts = sppf_acts )
583583 self .b10 = PSA (int (512 * w * r ), n = round (2 * d ))
584584
585585 def forward (self , x ):
@@ -988,6 +988,7 @@ def __init__(self, nc=80, ch=(), dfl=True, separable=False, end2end=False):
988988 def spconv (c1 , c2 , k ): return nn .Sequential (Conv (c1 ,c1 ,k ,g = c1 ),Conv (c1 ,c2 ,1 ))
989989 conv = spconv if separable else Conv
990990 self .nc = nc # number of classes
991+ self .dfl = dfl
991992 self .reg_max = 16 if dfl else 1 # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
992993 self .no = nc + self .reg_max * 4 # number of outputs per anchor
993994 self .strides = [8 , 16 , 32 ] # strides computed during build
@@ -1004,7 +1005,7 @@ def forward_private(self, xs, cv2, cv3, targets=None):
10041005 feats = [rearrange (torch .cat ((c1 (x ), c2 (x )), 1 ), 'b f h w -> b (h w) f' ) for x ,c1 ,c2 in zip (xs , cv2 , cv3 )]
10051006 dist , cls = torch .cat (feats , 1 ).split ((4 * self .reg_max , self .nc ), - 1 )
10061007 dist = rearrange (dist , 'b n (k r) -> b n k r' , k = 4 )
1007- ltrb = torch .einsum ('bnkr, r -> bnk' , dist .softmax (- 1 ), self .r )
1008+ ltrb = torch .einsum ('bnkr, r -> bnk' , dist .softmax (- 1 ), self .r ) if self . dfl else dist . squeeze ( - 1 )
10081009 box = dist2box (ltrb , sxy , strides )
10091010 pred = torch .cat ((box , cls .sigmoid ()), - 1 )
10101011
@@ -1177,7 +1178,7 @@ def __init__(self, variant, num_classes):
11771178class Yolov26 (YoloBase ):
11781179 def __init__ (self , variant , num_classes ):
11791180 d , w , r = get_variant_multiplesV26 (variant )
1180- super ().__init__ (BackboneV11 (w , r , d , variant , sppf_shortcut = True ),
1181+ super ().__init__ (BackboneV11 (w , r , d , variant , sppf_shortcut = True , sppf_acts = [ nn . Identity (), nn . SiLU ( True )] ),
11811182 HeadV11 (w , r , d , variant , is26 = True ),
11821183 Detect (num_classes , ch = (int (256 * w ), int (512 * w ), int (512 * w * r )), separable = True , dfl = False , end2end = True ),
11831184 variant )
0 commit comments