@@ -1188,27 +1188,4 @@ def conf(): return preds[...,4] if has_objectness else preds[...,4:].max(-1)[0]
11881188 nms = torchvision .ops .batched_nms (preds [:,:4 ], conf (), batch , iou_threshold = nms_thresh )
11891189 preds = preds [nms ]
11901190 batch = batch [nms ]
1191- return batch , preds
1192-
1193- class BarlowTwinsHead (nn .Module ):
1194- def __init__ (self , backbone , input_dim , hidden_dim = 2048 , output_dim = 128 ):
1195- super ().__init__ ()
1196- self .net = backbone
1197- self .proj = nn .Sequential (nn .Linear (input_dim , hidden_dim , bias = True ),
1198- nn .LayerNorm (hidden_dim ),
1199- nn .ReLU (),
1200- nn .Linear (hidden_dim , output_dim , bias = False ))
1201-
1202- def forward (self , x ):
1203- x = self .net (x )[- 1 ]
1204- x = x .flatten (2 ).mean (2 )
1205- x = self .proj (x )
1206- return x
1207-
1208- def barlow_loss (z1 , z2 , lambda_coeff ):
1209- z1 , z2 = map (lambda z : (z - z .mean (0 )) / z .std (0 ), (z1 ,z2 ))
1210- cross = (z1 .T @ z2 ) / z1 .shape [0 ]
1211- mask = torch .eye (cross .shape [0 ], dtype = torch .bool , device = cross .device )
1212- on_diag = (cross [mask ]- 1 ).pow (2 ).sum ()
1213- off_diag = cross [~ mask ].pow (2 ).sum ()
1214- return (on_diag + lambda_coeff * off_diag , cross )
1191+ return batch , preds
0 commit comments