@@ -366,11 +366,17 @@ def __init__(self, old_detect, use_rvc2: bool):
366366
367367 self .use_rvc2 = use_rvc2
368368
369- self .proj_conv = nn .Conv2d (old_detect .dfl .c1 , 1 , 1 , bias = False ).requires_grad_ (
370- False
371- )
372- x = torch .arange (old_detect .dfl .c1 , dtype = torch .float )
373- self .proj_conv .weight .data [:] = nn .Parameter (x .view (1 , old_detect .dfl .c1 , 1 , 1 ))
369+ # yolo26: dfl will be nn.Identity(), we set proj_conv = None and skip the DFL block in forward
370+ if hasattr (old_detect .dfl , "c1" ):
371+ self .proj_conv = nn .Conv2d (
372+ old_detect .dfl .c1 , 1 , 1 , bias = False
373+ ).requires_grad_ (False )
374+ x = torch .arange (old_detect .dfl .c1 , dtype = torch .float )
375+ self .proj_conv .weight .data [:] = nn .Parameter (
376+ x .view (1 , old_detect .dfl .c1 , 1 , 1 )
377+ )
378+ else :
379+ self .proj_conv = None
374380
375381 def forward (self , x ):
376382 bs = x [0 ].shape [0 ] # batch size
@@ -382,9 +388,10 @@ def forward(self, x):
382388
383389 # ------------------------------
384390 # DFL PART
385- box = box .view (bs , 4 , self .reg_max , h * w ).permute (0 , 2 , 1 , 3 )
386- box = self .proj_conv (F .softmax (box , dim = 1 ))[:, 0 ]
387- box = box .reshape ([bs , 4 , h , w ])
391+ if self .proj_conv is not None :
392+ box = box .view (bs , 4 , self .reg_max , h * w ).permute (0 , 2 , 1 , 3 )
393+ box = self .proj_conv (F .softmax (box , dim = 1 ))[:, 0 ]
394+ box = box .reshape ([bs , 4 , h , w ])
388395 # ------------------------------
389396
390397 cls = self .cv3 [i ](x [i ])
0 commit comments