@@ -97,11 +97,6 @@ def copy_params(n1: nn.Module, n2: nn.Module):
9797 m1 .running_var .data .copy_ (m2 .running_var .data )
9898 m1 .eps = m2 .eps
9999 m1 .momentum = m2 .momentum
100-
101- def init_batchnorms (net : nn .Module ):
102- for m in batchnorms (net ):
103- m .eps = 1e-3
104- m .momentum = 0.03
105100
106101def count_parameters (net : torch .nn .Module , include_stats = True ):
107102 return sum (p .numel () for p in net .parameters ()) + (sum (m .running_mean .numel () + m .running_var .numel () for m in batchnorms (net )) if include_stats else 0 )
@@ -125,7 +120,7 @@ def forward(self, x):
125120class Conv (nn .Sequential ):
126121 def __init__ (self , c1 , c2 , k = 1 , s = 1 , p = None , g = 1 , act = nn .SiLU (True )):
127122 super ().__init__ (nn .Conv2d (c1 , c2 , k , s , default (p ,k // 2 ), groups = g , bias = False ),
128- nn .BatchNorm2d (c2 ),
123+ nn .BatchNorm2d (c2 , eps = 1e-3 , momentum = 0.03 ),
129124 act )
130125
131126def Con5 (c1 , c2 = None , spp = False , act = actV3 ):
@@ -220,9 +215,11 @@ def forward(self, x):
220215 b = self .m (self .cv2 (x ))
221216 return self .cv3 (torch .cat ((b , a ), 1 ))
222217
223- def C3k2 (c1 , c2 , n = 1 , shortcut = True , e = 0.5 , c3k = False ):
218+ def C3k2 (c1 , c2 , n = 1 , shortcut = True , e = 0.5 , c3k = False , attn = False ):
224219 net = C2f (c1 , c2 , n , shortcut , e )
225- blk = C3 (net .c_ , net .c_ , k = (3 ,3 ), shortcut = shortcut , n = 2 ) if c3k else Bottleneck (net .c_ , k = (3 ,3 ), shortcut = shortcut )
220+ if attn : blk = nn .Sequential (Bottleneck (net .c_ , k = (3 ,3 ), shortcut = shortcut ), PSABlock (net .c_ , num_heads = net .c_ // 64 , attn_ratio = 0.5 ))
221+ elif c3k : blk = C3 (net .c_ , net .c_ , k = (3 ,3 ), shortcut = shortcut , n = 2 )
222+ else : blk = Bottleneck (net .c_ , k = (3 ,3 ), shortcut = shortcut )
226223 net .m = nn .ModuleList (deepcopy (blk ) for _ in range (n ))
227224 return net
228225
@@ -282,20 +279,20 @@ def forward(self, x):
282279 return x
283280
284281class SPPF (nn .Module ):
285- def __init__ (self , c1 , c2 , act = nn .SiLU (True ), shortcut = False ): # equivalent to SPP(k=(5, 9, 13))
282+ def __init__ (self , c1 , c2 , acts = [ nn .Identity (), nn . SiLU (True )] , shortcut = False ): # equivalent to SPP(k=(5, 9, 13))
286283 super ().__init__ ()
287284 c_ = c1 // 2 # hidden channels
288- self .add = shortcut
289- self .cv1 = Conv (c1 , c_ , 1 , 1 , act = act )
290- self .cv2 = Conv (c_ * 4 , c2 , 1 , 1 , act = act )
285+ self .add = shortcut and c1 == c2
286+ self .cv1 = Conv (c1 , c_ , 1 , 1 , act = acts [ 0 ] )
287+ self .cv2 = Conv (c_ * 4 , c2 , 1 , 1 , act = acts [ 1 ] )
291288 self .m = nn .MaxPool2d (kernel_size = 5 , stride = 1 , padding = 2 )
292- def forward (self , x ):
293- x = self .cv1 (x )
289+ def forward (self , input ):
290+ x = self .cv1 (input )
294291 y1 = self .m (x )
295292 y2 = self .m (y1 )
296293 y3 = self .m (y2 )
297294 y = self .cv2 (torch .cat ((x , y1 , y2 , y3 ), 1 ))
298- return x + y if self .add else y
295+ return input + y if self .add else y
299296
300297class SPPCSPC (nn .Module ):
301298 def __init__ (self , c1 , c2 , e = 0.5 , act = nn .SiLU (True )):
@@ -341,35 +338,31 @@ def forward(self, x):
341338 x = self .proj (x + self .pe (v ))
342339 return x
343340
344- def PSABlock (c , num_heads = 4 , attn_ratio = 4 ):
345- return nn .Sequential (Residual (Attention (c , num_heads = num_heads , attn_ratio = attn_ratio )),
346- Residual (nn .Sequential (Conv (c , c * 2 , 1 ), Conv (c * 2 , c , 1 , act = nn .Identity ()))))
341+ def PSABlock (c , num_heads = 4 , attn_ratio = 0.5 , area = None , e = 2 ):
342+ c_ = int (c * e )
343+ return nn .Sequential (Residual (Attention (c , num_heads = num_heads , attn_ratio = attn_ratio , area = area )),
344+ Residual (nn .Sequential (Conv (c , c_ , 1 ), Conv (c_ , c , 1 , act = nn .Identity ()))))
347345
348346class PSA (nn .Module ):
349347 def __init__ (self , c , e = 0.5 , n = 1 ):
350348 super ().__init__ ()
351349 c_ = int (c * e ) # hidden channels
352350 self .cv1 = Conv (c , 2 * c_ , 1 )
353351 self .cv2 = Conv (2 * c_ , c , 1 )
354- self .net = Repeat (PSABlock (c_ , num_heads = c_ // 64 , attn_ratio = 0.5 ), n )
352+ self .net = Repeat (PSABlock (c_ , num_heads = c_ // 64 , attn_ratio = 0.5 ), n )
355353
356354 def forward (self , x ):
357355 a , b = self .cv1 (x ).chunk (2 , 1 )
358356 return self .cv2 (torch .cat ((a , self .net (b )), 1 ))
359357
360- def ABlock (c , num_heads , area = 1 , mlp_ratio = 1.2 ):
361- c_ = int (c * mlp_ratio )
362- return nn .Sequential (Residual (Attention (c , num_heads = num_heads , attn_ratio = 1 , area = area )),
363- Residual (nn .Sequential (Conv (c , c_ , 1 ), Conv (c_ , c , 1 , act = nn .Identity ()))))
364-
365358class A2C2f (nn .Module ):
366359 def __init__ (self , c1 , c2 , n = 1 , shortcut = True , e = 0.5 , a2 = True , residual = False , area = 1 , mlp_ratio = 2.0 ):
367360 super ().__init__ ()
368361 self .c_ = int (c2 * e ) # hidden channels
369362 self .cv1 = Conv (c1 , self .c_ , 1 )
370363 self .cv2 = Conv ((1 + n ) * self .c_ , c2 , 1 )
371364 self .g = nn .Parameter (0.01 * torch .ones (c2 )) if (a2 and residual ) else None
372- self .m = nn .ModuleList (Repeat (ABlock (self .c_ , self .c_ // 32 , area , mlp_ratio ), 2 ) if a2 else
365+ self .m = nn .ModuleList (Repeat (PSABlock (self .c_ , num_heads = self .c_ // 32 , attn_ratio = 1 , area = area , e = mlp_ratio ), 2 ) if a2 else
373366 C3 (self .c_ , self .c_ , k = (3 ,3 ), shortcut = shortcut , n = 2 ) for _ in range (n ))
374367
375368 def forward (self , x ):
@@ -615,14 +608,10 @@ def forward(self, x):
615608 x8 = self .b8 (self .b7 (x6 )) # 8 P5/32
616609 return x4 , x6 , x8
617610
618- class BackboneV26 (BackboneV11 ):
619- def __init__ (self , w , r , d , variant ):
620- super ().__init__ (w , r , d , variant , sppf_shortcut = True )
621-
622611class EfficientRep (nn .Module ):
623612 def __init__ (self , w , d , cspsppf = False ):
624613 super ().__init__ ()
625- sppf = partial (SPPCSPC , e = 0.25 ) if cspsppf else SPPF
614+ sppf = partial (SPPCSPC , e = 0.25 , act = nn . ReLU ( True )) if cspsppf else partial ( SPPF , acts = [ nn . ReLU ( True ), nn . ReLU ( True )])
626615 self .b0 = RepConv ( c1 = 3 , c2 = int (64 * w ), s = 2 , act = F .relu )
627616 self .b1 = RepConv ( c1 = int (64 * w ), c2 = int (128 * w ), s = 2 , act = F .relu )
628617 self .b2 = RepBlock (c1 = int (128 * w ), c2 = int (128 * w ), n = round (6 * d ))
@@ -632,7 +621,7 @@ def __init__(self, w, d, cspsppf=False):
632621 self .b6 = RepBlock (c1 = int (512 * w ), c2 = int (512 * w ), n = round (18 * d ))
633622 self .b7 = RepConv ( c1 = int (512 * w ), c2 = int (1024 * w ),s = 2 , act = F .relu )
634623 self .b8 = RepBlock (c1 = int (1024 * w ),c2 = int (1024 * w ),n = round (6 * d ))
635- self .b9 = sppf ( c1 = int (1024 * w ),c2 = int (1024 * w ), act = nn . ReLU ( True ))
624+ self .b9 = sppf (c1 = int (1024 * w ),c2 = int (1024 * w ))
636625
637626 def forward (self , x ):
638627 x4 = self .b2 (self .b1 (self .b0 (x ))) # p2/4
@@ -644,7 +633,7 @@ def forward(self, x):
644633class CSPBepBackbone (nn .Module ):
645634 def __init__ (self , w , d , csp_e = 1 / 2 , cspsppf = False ):
646635 super ().__init__ ()
647- sppf = partial (SPPCSPC , e = 0.25 ) if cspsppf else SPPF
636+ sppf = partial (SPPCSPC , e = 0.25 , act = nn . ReLU ( True )) if cspsppf else partial ( SPPF , acts = [ nn . SiLU (), nn . SiLU ()])
648637 self .b0 = RepConv (c1 = 3 , c2 = int (64 * w ), s = 2 , act = F .relu )
649638 self .b1 = RepConv (c1 = int (64 * w ), c2 = int (128 * w ), s = 2 , act = F .relu )
650639 self .b2 = BepC3 ( c1 = int (128 * w ), c2 = int (128 * w ), e = csp_e , n = round (6 * d ))
@@ -654,7 +643,7 @@ def __init__(self, w, d, csp_e=1/2, cspsppf=False):
654643 self .b6 = BepC3 ( c1 = int (512 * w ), c2 = int (512 * w ), e = csp_e , n = round (18 * d ))
655644 self .b7 = RepConv (c1 = int (512 * w ), c2 = int (1024 * w ),s = 2 , act = F .relu )
656645 self .b8 = BepC3 ( c1 = int (1024 * w ),c2 = int (1024 * w ),e = csp_e , n = round (6 * d ))
657- self .b9 = sppf ( c1 = int (1024 * w ),c2 = int (1024 * w ), act = nn . ReLU ( True ) )
646+ self .b9 = sppf ( c1 = int (1024 * w ),c2 = int (1024 * w ))
658647
659648 def forward (self , x ):
660649 x4 = self .b2 (self .b1 (self .b0 (x ))) # p2/4
@@ -790,16 +779,17 @@ def forward(self, x4, x6, x10):
790779 return [x16 , x19 , x22 ]
791780
792781class HeadV11 (nn .Module ):
793- def __init__ (self , w , r , d , variant ):
782+ def __init__ (self , w , r , d , variant , is26 = False ):
794783 super ().__init__ ()
795- c3k = variant in "mlx"
784+ c3k = True if is26 else variant in "mlx"
785+ n = 1 if is26 else 2
796786 self .up = nn .Upsample (scale_factor = 2 )
797- self .n1 = C3k2 (c1 = int (512 * w * (1 + r )), c2 = int (512 * w ), n = round (2 * d ), c3k = c3k )
798- self .n2 = C3k2 (c1 = int (512 * w * 2 ), c2 = int (256 * w ), n = round (2 * d ), c3k = c3k )
799- self .n3 = Conv (c1 = int (256 * w ), c2 = int (256 * w ), k = 3 , s = 2 )
800- self .n4 = C3k2 (c1 = int (768 * w ), c2 = int (512 * w ), n = round (2 * d ), c3k = c3k )
801- self .n5 = Conv (c1 = int (512 * w ), c2 = int (512 * w ), k = 3 , s = 2 )
802- self .n6 = C3k2 (c1 = int (512 * w * (1 + r )), c2 = int (512 * w * r ), n = round (2 * d ), c3k = True )
787+ self .n1 = C3k2 (c1 = int (512 * w * (1 + r )), c2 = int (512 * w ), n = round (2 * d ), c3k = c3k )
788+ self .n2 = C3k2 (c1 = int (512 * w * 2 ), c2 = int (256 * w ), n = round (2 * d ), c3k = c3k )
789+ self .n3 = Conv (c1 = int (256 * w ), c2 = int (256 * w ), k = 3 , s = 2 )
790+ self .n4 = C3k2 (c1 = int (768 * w ), c2 = int (512 * w ), n = round (2 * d ), c3k = c3k )
791+ self .n5 = Conv (c1 = int (512 * w ), c2 = int (512 * w ), k = 3 , s = 2 )
792+ self .n6 = C3k2 (c1 = int (512 * w * (1 + r )), c2 = int (512 * w * r ), n = max ( 1 , round (n * d )) , c3k = True , attn = is26 )
803793
804794 def forward (self , x4 , x6 , x10 ):
805795 x13 = self .n1 (torch .cat ([self .up (x10 ),x6 ], 1 )) # 13
@@ -993,20 +983,22 @@ def forward(self, xs, targets=None):
993983 return pred if not exists (targets ) else (pred , {'iou' : loss_iou , 'cls' : loss_cls , 'obj' : loss_obj })
994984
995985class Detect (nn .Module ):
996- def __init__ (self , nc = 80 , ch = (), v11 = False ):
986+ def __init__ (self , nc = 80 , ch = (), dfl = True , separable = False , end2end = False ):
997987 super ().__init__ ()
998988 def spconv (c1 , c2 , k ): return nn .Sequential (Conv (c1 ,c1 ,k ,g = c1 ),Conv (c1 ,c2 ,1 ))
999- conv = spconv if v11 else Conv
989+ conv = spconv if separable else Conv
1000990 self .nc = nc # number of classes
1001- self .reg_max = 16 # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
991+ self .reg_max = 16 if dfl else 1 # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
1002992 self .no = nc + self .reg_max * 4 # number of outputs per anchor
1003993 self .strides = [8 , 16 , 32 ] # strides computed during build
1004994 self .c2 = max ((16 , ch [0 ] // 4 , self .reg_max * 4 ))
1005995 self .c3 = max (ch [0 ], min (self .nc , 100 )) # channels
1006996 self .cv2 = nn .ModuleList (nn .Sequential (Conv (x , self .c2 , 3 ), Conv (self .c2 , self .c2 , 3 ), nn .Conv2d (self .c2 , 4 * self .reg_max , 1 )) for x in ch )
1007997 self .cv3 = nn .ModuleList (nn .Sequential (conv (x , self .c3 , 3 ), conv (self .c3 , self .c3 , 3 ), nn .Conv2d (self .c3 , self .nc , 1 )) for x in ch )
1008- self .r = nn .Parameter (torch .arange (self .reg_max ).float (), requires_grad = False )
1009-
998+ self .r = nn .Parameter (torch .arange (self .reg_max ).float (), requires_grad = False ) if dfl else None
999+ if end2end : self .one2one_cv2 = deepcopy (self .cv2 )
1000+ if end2end : self .one2one_cv3 = deepcopy (self .cv3 )
1001+
10101002 def forward_private (self , xs , cv2 , cv3 , targets = None ):
10111003 sxy , ps , strides = make_anchors (xs , self .strides )
10121004 feats = [rearrange (torch .cat ((c1 (x ), c2 (x )), 1 ), 'b f h w -> b (h w) f' ) for x ,c1 ,c2 in zip (xs , cv2 , cv3 )]
@@ -1179,15 +1171,23 @@ def __init__(self, variant, num_classes):
11791171 d , w , r = get_variant_multiplesV11 (variant )
11801172 super ().__init__ (BackboneV11 (w , r , d , variant ),
11811173 HeadV11 (w , r , d , variant ),
1182- Detect (num_classes , ch = (int (256 * w ), int (512 * w ), int (512 * w * r )), v11 = True ),
1174+ Detect (num_classes , ch = (int (256 * w ), int (512 * w ), int (512 * w * r )), separable = True ),
11831175 variant )
11841176
1177+ class Yolov26 (YoloBase ):
1178+ def __init__ (self , variant , num_classes ):
1179+ d , w , r = get_variant_multiplesV26 (variant )
1180+ super ().__init__ (BackboneV11 (w , r , d , variant , sppf_shortcut = True ),
1181+ HeadV11 (w , r , d , variant , is26 = True ),
1182+ Detect (num_classes , ch = (int (256 * w ), int (512 * w ), int (512 * w * r )), separable = True , dfl = False , end2end = True ),
1183+ variant )
1184+
11851185class Yolov12 (YoloBase ):
11861186 def __init__ (self , variant , num_classes ):
11871187 d , w , r = get_variant_multiplesV12 (variant )
11881188 super ().__init__ (BackboneV12 (w , r , d , variant ),
11891189 HeadV12 (w , r , d ),
1190- Detect (num_classes , ch = (int (256 * w ), int (512 * w ), int (512 * w * r )), v11 = True ),
1190+ Detect (num_classes , ch = (int (256 * w ), int (512 * w ), int (512 * w * r )), separable = True ),
11911191 variant )
11921192
11931193class Yolov6 (YoloBase ):
@@ -1230,4 +1230,4 @@ def barlow_loss(z1, z2, lambda_coeff):
12301230 mask = torch .eye (cross .shape [0 ], dtype = torch .bool , device = cross .device )
12311231 on_diag = (cross [mask ]- 1 ).pow (2 ).sum ()
12321232 off_diag = cross [~ mask ].pow (2 ).sum ()
1233- return (on_diag + lambda_coeff * off_diag , cross )
1233+ return (on_diag + lambda_coeff * off_diag , cross )
0 commit comments