Skip to content

Commit 5f882cc

Browse files
Add DirectCLR (#781) (#1874)
* Move DirectCLR to loss-based implementation (#1874)
* Pass both views in a single forward pass for DirectCLR (#1874)
* Add example notebook for DirectCLR (#1874)
1 parent e201e00 commit 5f882cc

File tree

5 files changed

+422
-0
lines changed

5 files changed

+422
-0
lines changed
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "0",
6+
"metadata": {},
7+
"source": [
8+
"This example requires the following dependencies to be installed:\n",
9+
"pip install lightly"
10+
]
11+
},
12+
{
13+
"cell_type": "code",
14+
"execution_count": null,
15+
"id": "1",
16+
"metadata": {},
17+
"outputs": [],
18+
"source": [
19+
"!pip install lightly"
20+
]
21+
},
22+
{
23+
"cell_type": "markdown",
24+
"id": "2",
25+
"metadata": {},
26+
"source": [
27+
"Note: The model and training settings do not follow the reference settings\n",
28+
"from the paper. The settings are chosen such that the example can easily be\n",
29+
"run on a small dataset with a single GPU."
30+
]
31+
},
32+
{
33+
"cell_type": "code",
34+
"execution_count": null,
35+
"id": "3",
36+
"metadata": {},
37+
"outputs": [],
38+
"source": [
39+
"import torch\n",
40+
"from torch.nn import Sequential\n",
41+
"from torch.optim import SGD\n",
42+
"from torch.utils.data import DataLoader\n",
43+
"from torchvision import models\n",
44+
"from torchvision.datasets import CIFAR10"
45+
]
46+
},
47+
{
48+
"cell_type": "code",
49+
"execution_count": null,
50+
"id": "4",
51+
"metadata": {},
52+
"outputs": [],
53+
"source": [
54+
"from lightly.loss import DirectCLRLoss\n",
55+
"from lightly.transforms.simclr_transform import SimCLRTransform"
56+
]
57+
},
58+
{
59+
"cell_type": "code",
60+
"execution_count": null,
61+
"id": "5",
62+
"metadata": {},
63+
"outputs": [],
64+
"source": [
65+
"resnet = models.resnet18()\n",
66+
"model = Sequential(*list(resnet.children())[:-1])"
67+
]
68+
},
69+
{
70+
"cell_type": "code",
71+
"execution_count": null,
72+
"id": "6",
73+
"metadata": {},
74+
"outputs": [],
75+
"source": [
76+
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
77+
"model.to(device)"
78+
]
79+
},
80+
{
81+
"cell_type": "code",
82+
"execution_count": null,
83+
"id": "7",
84+
"metadata": {},
85+
"outputs": [],
86+
"source": [
87+
"transform = SimCLRTransform(input_size=32, gaussian_blur=0.0)\n",
88+
"dataset = CIFAR10(\"datasets/cifar10\", download=True, transform=transform)\n",
89+
"# or create a dataset from a folder containing images or videos:\n",
90+
"# dataset = LightlyDataset(\"path/to/folder\", transform=transform)"
91+
]
92+
},
93+
{
94+
"cell_type": "code",
95+
"execution_count": null,
96+
"id": "8",
97+
"metadata": {},
98+
"outputs": [],
99+
"source": [
100+
"dataloader = DataLoader(\n",
101+
" dataset,\n",
102+
" batch_size=256,\n",
103+
" shuffle=True,\n",
104+
" drop_last=True,\n",
105+
" num_workers=8,\n",
106+
")"
107+
]
108+
},
109+
{
110+
"cell_type": "code",
111+
"execution_count": null,
112+
"id": "9",
113+
"metadata": {},
114+
"outputs": [],
115+
"source": [
116+
"criterion = DirectCLRLoss(loss_dim=32)\n",
117+
"optimizer = SGD(model.parameters(), lr=0.06)"
118+
]
119+
},
120+
{
121+
"cell_type": "code",
122+
"execution_count": null,
123+
"id": "10",
124+
"metadata": {},
125+
"outputs": [],
126+
"source": [
127+
"print(\"Starting Training\")\n",
128+
"for epoch in range(10):\n",
129+
" total_loss = 0\n",
130+
" for batch in dataloader:\n",
131+
" x0, x1 = batch[0]\n",
132+
" x = torch.cat([x0, x1]).to(device)\n",
133+
" z0, z1 = model(x).chunk(2, dim=0)\n",
134+
" loss = criterion(z0, z1)\n",
135+
" total_loss += loss.detach()\n",
136+
" loss.backward()\n",
137+
" optimizer.step()\n",
138+
" optimizer.zero_grad()\n",
139+
" avg_loss = total_loss / len(dataloader)\n",
140+
" print(f\"epoch: {epoch:>02}, loss: {avg_loss:.5f}\")"
141+
]
142+
}
143+
],
144+
"metadata": {
145+
"jupytext": {
146+
"cell_metadata_filter": "-all",
147+
"main_language": "python",
148+
"notebook_metadata_filter": "-all"
149+
}
150+
},
151+
"nbformat": 4,
152+
"nbformat_minor": 5
153+
}

examples/pytorch/directclr.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# This example requires the following dependencies to be installed:
2+
# pip install lightly
3+
4+
# Note: The model and training settings do not follow the reference settings
5+
# from the paper. The settings are chosen such that the example can easily be
6+
# run on a small dataset with a single GPU.
7+
8+
import torch
9+
from torch.nn import Sequential
10+
from torch.optim import SGD
11+
from torch.utils.data import DataLoader
12+
from torchvision import models
13+
from torchvision.datasets import CIFAR10
14+
15+
from lightly.loss import DirectCLRLoss
16+
from lightly.transforms.simclr_transform import SimCLRTransform
17+
18+
resnet = models.resnet18()
19+
model = Sequential(*list(resnet.children())[:-1])
20+
21+
device = "cuda" if torch.cuda.is_available() else "cpu"
22+
model.to(device)
23+
24+
transform = SimCLRTransform(input_size=32, gaussian_blur=0.0)
25+
dataset = CIFAR10("datasets/cifar10", download=True, transform=transform)
26+
# or create a dataset from a folder containing images or videos:
27+
# dataset = LightlyDataset("path/to/folder", transform=transform)
28+
29+
dataloader = DataLoader(
30+
dataset,
31+
batch_size=256,
32+
shuffle=True,
33+
drop_last=True,
34+
num_workers=8,
35+
)
36+
37+
criterion = DirectCLRLoss(loss_dim=32)
38+
optimizer = SGD(model.parameters(), lr=0.06)
39+
40+
print("Starting Training")
41+
for epoch in range(10):
42+
total_loss = 0
43+
for batch in dataloader:
44+
x0, x1 = batch[0]
45+
x = torch.cat([x0, x1]).to(device)
46+
z0, z1 = model(x).chunk(2, dim=0)
47+
loss = criterion(z0, z1)
48+
total_loss += loss.detach()
49+
loss.backward()
50+
optimizer.step()
51+
optimizer.zero_grad()
52+
avg_loss = total_loss / len(dataloader)
53+
print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")

lightly/loss/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from lightly.loss.dcl_loss import DCLLoss, DCLWLoss
77
from lightly.loss.detcon_loss import DetConBLoss, DetConSLoss
88
from lightly.loss.dino_loss import DINOLoss
9+
from lightly.loss.directclr_loss import DirectCLRLoss
910
from lightly.loss.emp_ssl_loss import EMPSSLLoss
1011
from lightly.loss.ibot_loss import IBOTPatchLoss
1112
from lightly.loss.koleo_loss import KoLeoLoss

lightly/loss/directclr_loss.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
""" Contrastive Loss Functions """
2+
3+
# Copyright (c) 2020. Lightly AG and its affiliates.
4+
# All Rights Reserved
5+
6+
7+
from typing import Sequence, Union
8+
9+
from torch import Tensor
10+
11+
from lightly.loss.ntx_ent_loss import NTXentLoss
12+
13+
14+
class DirectCLRLoss(NTXentLoss):
    """NT-Xent based implementation of the DirectCLR loss.

    As described in the DirectCLR paper [0], this loss is applied directly to
    the backbone output, without a projection head: each representation is
    flattened and truncated to its first `loss_dim` entries before the
    standard NT-Xent loss is computed. All remaining parameters are inherited
    from NTXentLoss.

    - [0] DirectCLR, 2021, https://arxiv.org/abs/2110.09348

    Attributes:
        loss_dim:
            Number of leading entries of each encoding used for the loss.
        temperature:
            From NTXentLoss: scale logits by the inverse of the temperature.
        memory_bank_size:
            From NTXentLoss: size of the memory bank as (num_features, dim)
            tuple. num_features is the number of negative samples stored in
            the memory bank. If num_features is 0, the memory bank is
            disabled. Use 0 for SimCLR. For MoCo we typically use numbers
            like 4096 or 65536.
            Deprecated: If only a single integer is passed, it is interpreted
            as the number of features and the feature dimension is inferred
            from the first batch stored in the memory bank. Leaving out the
            feature dimension might lead to errors in distributed training.
        gather_distributed:
            From NTXentLoss: if True then negatives from all GPUs are
            gathered before the loss calculation. If a memory bank is used
            and gather_distributed is True, then tensors from all gpus are
            gathered before the memory bank is updated.

    Examples:
        >>> # initialize loss function
        >>> loss_fn = DirectCLRLoss()
        >>>
        >>> # generate two random transforms of images
        >>> t0 = transforms(images)
        >>> t1 = transforms(images)
        >>>
        >>> # feed through backbone without projection head
        >>> out0, out1 = model(t0), model(t1)
        >>>
        >>> # calculate loss
        >>> loss = loss_fn(out0, out1)

    """

    def __init__(
        self,
        loss_dim: int = 64,
        temperature: float = 0.5,
        memory_bank_size: Union[int, Sequence[int]] = 0,
        gather_distributed: bool = False,
    ):
        """Initializes the DirectCLRLoss module with the specified parameters.

        Args:
            loss_dim:
                Computes the loss only on the first `loss_dim` values of the
                encoding.
            temperature:
                Scale logits by the inverse of the temperature.
            memory_bank_size:
                Size of the memory bank.
            gather_distributed:
                If True, negatives from all GPUs are gathered before the loss
                calculation.
        """
        super().__init__(
            temperature=temperature,
            memory_bank_size=memory_bank_size,
            gather_distributed=gather_distributed,
        )
        self.loss_dim = loss_dim

    def forward(self, out0: Tensor, out1: Tensor) -> Tensor:
        """Forward pass through DirectCLR Loss.

        Intended to be applied directly to backbone encodings (no projection
        head). Each input is flattened per sample, truncated to the first
        `loss_dim` entries, and passed to the NT-Xent loss.

        Args:
            out0:
                Encodings of the first set of transformed images.
                Shape: (batch_size, embedding_size)
            out1:
                Encodings of the second set of transformed images.
                Shape: (batch_size, embedding_size)

        Returns:
            DirectCLR Loss value.
        """
        # Flatten any trailing dimensions (e.g. spatial dims from a conv
        # backbone) so slicing operates on a (batch, features) view.
        z0 = out0.flatten(start_dim=1)
        z1 = out1.flatten(start_dim=1)

        # Truncate to the leading loss_dim features and delegate to NT-Xent.
        loss: Tensor = super().forward(
            z0[:, : self.loss_dim],
            z1[:, : self.loss_dim],
        )
        return loss

0 commit comments

Comments
 (0)