Skip to content

Commit 467895b

Browse files
author
pfeatherstone
committed
remove Barlow Twins loss. You another repo if you want to do SSL pretraining. something like LightlySSL for example
1 parent cddc9ee commit 467895b

File tree

2 files changed

+1
-125
lines changed

2 files changed

+1
-125
lines changed

src/models.py

Lines changed: 1 addition & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1188,27 +1188,4 @@ def conf(): return preds[...,4] if has_objectness else preds[...,4:].max(-1)[0]
11881188
nms = torchvision.ops.batched_nms(preds[:,:4], conf(), batch, iou_threshold=nms_thresh)
11891189
preds = preds[nms]
11901190
batch = batch[nms]
1191-
return batch, preds
1192-
1193-
class BarlowTwinsHead(nn.Module):
1194-
def __init__(self, backbone, input_dim, hidden_dim=2048, output_dim=128):
1195-
super().__init__()
1196-
self.net = backbone
1197-
self.proj = nn.Sequential(nn.Linear(input_dim, hidden_dim, bias=True),
1198-
nn.LayerNorm(hidden_dim),
1199-
nn.ReLU(),
1200-
nn.Linear(hidden_dim, output_dim, bias=False))
1201-
1202-
def forward(self, x):
1203-
x = self.net(x)[-1]
1204-
x = x.flatten(2).mean(2)
1205-
x = self.proj(x)
1206-
return x
1207-
1208-
def barlow_loss(z1, z2, lambda_coeff):
1209-
z1, z2 = map(lambda z: (z - z.mean(0)) / z.std(0), (z1,z2))
1210-
cross = (z1.T @ z2) / z1.shape[0]
1211-
mask = torch.eye(cross.shape[0], dtype=torch.bool, device=cross.device)
1212-
on_diag = (cross[mask]-1).pow(2).sum()
1213-
off_diag = cross[~mask].pow(2).sum()
1214-
return (on_diag + lambda_coeff * off_diag, cross)
1191+
return batch, preds

src/pretrain.py

Lines changed: 0 additions & 101 deletions
This file was deleted.

0 commit comments

Comments
 (0)