diff --git a/docs/source/examples/dinov2.rst b/docs/source/examples/dinov2.rst new file mode 100644 index 000000000..718f6bc54 --- /dev/null +++ b/docs/source/examples/dinov2.rst @@ -0,0 +1,76 @@ +.. _dinov2: + +DINOv2 +====== + +DINOv2 (DIstillation with NO labels v2) [0]_ is an advanced self-supervised learning framework developed by Meta AI for robust visual representation learning without labeled data. Extending the original DINO [1]_ approach, DINOv2 trains a student network to match outputs from a momentum-averaged teacher network. By leveraging self-distillation objectives at both the image and patch level, it enhances both global and local feature learning. Combined with various other innovations in both the training recipe and an efficient training implementation, DINOv2 exhibits state-of-the-art performance across various computer vision tasks, including classification, segmentation, and depth estimation, without the need for task-specific fine-tuning. + +Key Components +-------------- + +- **Multi-level Objectives**: DINOv2 employs the DINO loss for the image-level objective and the iBOT [2]_ loss for the patch-level objective. This multi-level approach enhances both global and local feature representations, significantly improving performance on dense prediction tasks like segmentation and depth estimation. +- **KoLeo Regularizer**: DINOv2 introduces the KoLeo regularizer [3]_, which promotes uniform spreading of features within a batch, significantly enhancing performance on nearest-neighbor retrieval tasks without negatively affecting performance on dense downstream tasks. + +Good to Know +------------ + +- **SOTA out-of-the-box**: DINOv2 currently represents the state-of-the-art (SOTA) among self-supervised learning (SSL) methods in computer vision, outperforming existing frameworks on various benchmarks. +- **Relation to other SSL methods**: DINOv2 can be seen as a combination of the DINO and iBOT losses with the centering of SwAV [4]_. + +References: + + .. [0] `DINOv2: Learning Robust Visual Features without Supervision, 2023 <https://arxiv.org/abs/2304.07193>`_ + .. [1] `Emerging Properties in Self-Supervised Vision Transformers, 2021 <https://arxiv.org/abs/2104.14294>`_ + .. [2] `iBOT: Image BERT Pre-Training with Online Tokenizer, 2021 <https://arxiv.org/abs/2111.07832>`_ + .. [3] `Spreading vectors for similarity search, 2018 <https://arxiv.org/abs/1806.03198>`_ + .. [4] `Unsupervised Learning of Visual Features by Contrasting Cluster Assignments, 2020 <https://arxiv.org/abs/2006.09882>`_ + + +.. tabs:: + .. tab:: PyTorch + + .. image:: https://img.shields.io/badge/Open%20in%20Colab-blue?logo=googlecolab&label=%20&labelColor=5c5c5c + :target: https://colab.research.google.com/github/lightly-ai/lightly/blob/master/examples/notebooks/pytorch/dinov2.ipynb + + This example can be run from the command line with:: + + python lightly/examples/pytorch/dinov2.py + + .. literalinclude:: ../../../examples/pytorch/dinov2.py + + .. tab:: Lightning + + .. image:: https://img.shields.io/badge/Open%20in%20Colab-blue?logo=googlecolab&label=%20&labelColor=5c5c5c + :target: https://colab.research.google.com/github/lightly-ai/lightly/blob/master/examples/notebooks/pytorch_lightning/dinov2.ipynb + + This example can be run from the command line with:: + + python lightly/examples/pytorch_lightning/dinov2.py + + .. literalinclude:: ../../../examples/pytorch_lightning/dinov2.py + + .. tab:: Lightning Distributed + + .. 
image:: https://img.shields.io/badge/Open%20in%20Colab-blue?logo=googlecolab&label=%20&labelColor=5c5c5c + :target: https://colab.research.google.com/github/lightly-ai/lightly/blob/master/examples/notebooks/pytorch_lightning_distributed/dinov2.ipynb + + This example runs on multiple GPUs using Distributed Data Parallel (DDP) + training with PyTorch Lightning. At least one GPU must be available on + the system. The example can be run from the command line with:: + + python lightly/examples/pytorch_lightning_distributed/dinov2.py + + The model differs in the following ways from the non-distributed + implementation: + + - Distributed Data Parallel is enabled + - Synchronized Batch Norm is used in place of standard Batch Norm + - Distributed Sampling is used in the dataloader + + Note that Synchronized Batch Norm is optional and the model can also be + trained without it. Without Synchronized Batch Norm the batch norm + statistics for each GPU are computed only from the features on that + specific GPU. Distributed Sampling makes sure that each distributed + process sees only a subset of the data. + + .. literalinclude:: ../../../examples/pytorch_lightning_distributed/dinov2.py diff --git a/docs/source/examples/models.rst b/docs/source/examples/models.rst index b7bc92288..1f2322d35 100644 --- a/docs/source/examples/models.rst +++ b/docs/source/examples/models.rst @@ -16,6 +16,7 @@ for PyTorch and PyTorch Lightning to give you a headstart when implementing your dcl.rst densecl.rst dino.rst + dinov2.rst fastsiam.rst mae.rst mmcr.rst diff --git a/examples/notebooks/pytorch/dinov2.ipynb b/examples/notebooks/pytorch/dinov2.ipynb new file mode 100644 index 000000000..8af90bda0 --- /dev/null +++ b/examples/notebooks/pytorch/dinov2.ipynb @@ -0,0 +1,446 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU.\n",
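+ "\n", + "In addition to lightly, this example loads its ViT backbone from timm, so timm is assumed to be installed as well (pip install timm)."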
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import copy\n", + "from functools import partial" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torchvision\n", + "from timm.models.vision_transformer import vit_small_patch16_224\n", + "from torch import Tensor\n", + "from torch.nn import Module\n", + "from torch.optim import AdamW" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import DINOLoss, IBOTPatchLoss, KoLeoLoss\n", + "from lightly.models.modules import DINOv2ProjectionHead, MaskedVisionTransformerTIMM\n", + "from lightly.models.utils import (\n", + " random_block_mask,\n", + " update_drop_path_rate,\n", + " update_momentum,\n", + ")\n", + "from lightly.transforms.dino_transform import DINOTransform\n", + "from lightly.utils.scheduler import cosine_schedule, linear_warmup_schedule" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "def freeze_eval_module(module: Module) -> None:\n", + " \"\"\"Freeze the parameters of a module.\"\"\"\n", + " for param in module.parameters():\n", + " param.requires_grad = False\n", + " module.eval()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "class DINOv2Head(Module):\n", + " def __init__(\n", + " self, dino_head: DINOv2ProjectionHead, ibot_head: DINOv2ProjectionHead\n", + " ) -> None:\n", + " super().__init__()\n", + " self.dino_head = dino_head\n", + " self.ibot_head = ibot_head" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "class DINOv2(Module):\n", + " def __init__(\n", + " self,\n", + " ibot_separate_head: bool = False,\n", + " ) -> None:\n", + " super().__init__()\n", + "\n", + " # Backbones\n", + " vit_teacher = vit_small_patch16_224(\n", + " pos_embed=\"learn\",\n", + " dynamic_img_size=True,\n", + " init_values=1e-5,\n", + " )\n", + " self.teacher_backbone = MaskedVisionTransformerTIMM(\n", + " vit=vit_teacher,\n", + " antialias=False,\n", + " pos_embed_initialization=\"skip\",\n", + " )\n", + " self.student_backbone = copy.deepcopy(self.teacher_backbone)\n", + " update_drop_path_rate(\n", + " self.student_backbone.vit,\n", + " drop_path_rate=0.1, # we recommend using smaller rates like 0.1 for vit-s-14\n", + " mode=\"uniform\",\n", + " )\n", + "\n", + " freeze_eval_module(self.teacher_backbone)\n", + "\n", + " # Heads\n", + " dino_head = partial(\n", + " DINOv2ProjectionHead,\n", + " input_dim=384,\n", + " )\n", + "\n", + " teacher_dino_head = dino_head()\n", + " student_dino_head = dino_head()\n", + "\n", + " ibot_head = partial(\n", + " DINOv2ProjectionHead,\n", + " input_dim=384,\n", + " )\n", + "\n", + " if ibot_separate_head:\n", + " teacher_ibot_head = ibot_head()\n", + " student_ibot_head = ibot_head()\n", + " else:\n", + " teacher_ibot_head = teacher_dino_head\n", + " student_ibot_head = student_dino_head\n", + "\n", + " self.teacher_head = DINOv2Head(\n", + " dino_head=teacher_dino_head,\n", + " ibot_head=teacher_ibot_head,\n", + " )\n", + " self.student_head = DINOv2Head(\n", + " dino_head=student_dino_head,\n", + " ibot_head=student_ibot_head,\n", + " )\n", + "\n", + " freeze_eval_module(self.teacher_head)\n", 
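+ "\n", + " # Note: the teacher receives no gradients. Its backbone and head are\n", + " # frozen here and are only updated as a momentum (EMA) copy of the\n", + " # student during training.\n",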
+ "\n", + " def forward(self, x: Tensor) -> Tensor:\n", + " return self.teacher_backbone(x)\n", + "\n", + " def forward_teacher(self, x: Tensor) -> tuple[Tensor, Tensor]:\n", + " features = self.teacher_backbone.encode(x)\n", + " cls_tokens = features[:, 0]\n", + " return cls_tokens, features\n", + "\n", + " def forward_student(\n", + " self, x: Tensor, mask: Tensor | None\n", + " ) -> tuple[Tensor, Tensor | None]:\n", + " features = self.student_backbone.encode(x, mask=mask)\n", + " cls_tokens = features[:, 0]\n", + " masked_features = None if mask is None else features[mask]\n", + " return cls_tokens, masked_features" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "model = DINOv2()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "transform = DINOTransform(\n", + " global_crop_scale=(0.32, 1),\n", + " local_crop_scale=(0.05, 0.32),\n", + " n_local_views=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "# We ignore object detection annotations by setting target_transform to return 0.\n", + "def target_transform(t):\n", + " return 0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "device = \"cuda\" if torch.cuda.is_available() else \"mps\"\n", + "model.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = torchvision.datasets.VOCDetection(\n", + " \"datasets/pascal_voc\",\n", + " download=True,\n", + " transform=transform,\n", + " target_transform=target_transform,\n", + ")\n", + "# Or create a dataset from a folder containing images or videos.\n", + "# dataset = LightlyDataset(\"path/to/folder\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=64,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15", + "metadata": {}, + "outputs": [], + "source": [ + "# Create the loss functions.\n", + "dino_criterion = DINOLoss()\n", + "ibot_criterion = IBOTPatchLoss()\n", + "koleo_criterion = KoLeoLoss()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16", + "metadata": {}, + "outputs": [], + "source": [ + "# Move loss to correct device because it also contains parameters.\n", + "dino_criterion = dino_criterion.to(device)\n", + "ibot_criterion = ibot_criterion.to(device)\n", + "koleo_criterion = koleo_criterion.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "17", + "metadata": {}, + "outputs": [], + "source": [ + "optimizer = AdamW(model.parameters(), lr=0.001)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18", + "metadata": {}, + "outputs": [], + "source": [ + "epochs = 50\n", + "num_batches = len(dataloader)\n", + "total_steps = epochs * num_batches" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "19", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Starting Training\")\n", + "for epoch in range(epochs):\n", + " total_loss = 0\n", + " for batch_idx, batch in enumerate(dataloader):\n", + " views = 
batch[0]\n", + " views = [view.to(device) for view in views]\n", + " global_views = torch.cat(views[:2])\n", + " local_views = torch.cat(views[2:])\n", + "\n", + " # Masking\n", + " B = len(global_views)\n", + " sequence_length = model.teacher_backbone.sequence_length\n", + " mask = global_views.new_zeros((B, sequence_length), dtype=torch.bool)\n", + "\n", + " # Mask patches except class token.\n", + " H, W = model.teacher_backbone.vit.patch_embed.grid_size\n", + " assert (\n", + " H * W == sequence_length - 1\n", + " ), f\"Unexpected grid size: {H}x{W}, sequence_length {sequence_length}\"\n", + " block_mask = random_block_mask(size=(B, H, W), device=mask.device)\n", + " mask[:, 1:] = block_mask.flatten(start_dim=1)\n", + "\n", + " # Teacher forward\n", + " with torch.no_grad():\n", + " teacher_cls_token, teacher_features = model.forward_teacher(global_views)\n", + " teacher_cls_out = model.teacher_head.dino_head.forward(teacher_cls_token)\n", + " teacher_masked_out = model.teacher_head.ibot_head.forward(\n", + " teacher_features[mask]\n", + " )\n", + "\n", + " # Student forward\n", + " (\n", + " student_global_cls_token,\n", + " student_global_masked_features,\n", + " ) = model.forward_student(global_views, mask=mask)\n", + " student_global_cls_out = model.student_head.dino_head.forward(\n", + " student_global_cls_token\n", + " )\n", + " student_global_masked_out = model.student_head.ibot_head.forward(\n", + " student_global_masked_features\n", + " )\n", + " student_local_cls_token, _ = model.forward_student(local_views, mask=None)\n", + " student_local_cls_out = model.student_head.dino_head.forward(\n", + " student_local_cls_token\n", + " )\n", + " student_cls_out = torch.cat([student_global_cls_out, student_local_cls_out])\n", + "\n", + " # Calculate current global step based on epoch and batch index.\n", + " global_step = epoch * num_batches + batch_idx\n", + "\n", + " # Calculate the loss.\n", + " teacher_temp = linear_warmup_schedule(\n", + " step=global_step,\n", + " warmup_steps=int(30 / epochs * total_steps),\n", + " start_value=0.04,\n", + " end_value=0.07,\n", + " )\n", + " dino_loss = dino_criterion(\n", + " teacher_out=teacher_cls_out.chunk(2),\n", + " student_out=student_cls_out.chunk(len(views)),\n", + " teacher_temp=teacher_temp,\n", + " )\n", + " ibot_loss = ibot_criterion(\n", + " teacher_out=teacher_masked_out,\n", + " student_out=student_global_masked_out,\n", + " mask=block_mask,\n", + " teacher_temp=teacher_temp,\n", + " )\n", + " koleo_loss = 0.1 * sum(\n", + " koleo_criterion(t) for t in student_global_cls_token.chunk(2)\n", + " )\n", + " loss = dino_loss + ibot_loss + koleo_loss\n", + "\n", + " total_loss += loss.detach()\n", + " loss.backward()\n", + "\n", + " # Optionally zero out the learning rate of the last layer.\n", + " if epoch < 1:\n", + " for param_group in optimizer.param_groups:\n", + " if \"last_layer\" in param_group:\n", + " param_group[\"lr\"] = 0.0\n", + "\n", + " # Apply weight decay schedule.\n", + " weight_decay = cosine_schedule(\n", + " step=global_step,\n", + " max_steps=total_steps,\n", + " start_value=0.04,\n", + " end_value=0.4,\n", + " )\n", + "\n", + " # Update weight decay directly for all parameter groups.\n", + " for group in optimizer.param_groups:\n", + " if group[\"weight_decay\"] != 0.0:\n", + " group[\"weight_decay\"] = weight_decay\n", + "\n", + " optimizer.step()\n", + " optimizer.zero_grad()\n", + "\n", + " # Momentum update teacher.\n", + " momentum = cosine_schedule(\n", + " step=global_step,\n", + " 
max_steps=total_steps,\n", + " start_value=0.992,\n", + " end_value=1.0,\n", + " )\n", + " update_momentum(model.student_backbone, model.teacher_backbone, m=momentum)\n", + " update_momentum(model.student_head, model.teacher_head, m=momentum)\n", + "\n", + " avg_loss = total_loss / len(dataloader)\n", + " print(f\"epoch: {epoch:>02}, loss: {avg_loss:.5f}\")" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning/dinov2.ipynb b/examples/notebooks/pytorch_lightning/dinov2.ipynb new file mode 100644 index 000000000..c80bc5e16 --- /dev/null +++ b/examples/notebooks/pytorch_lightning/dinov2.ipynb @@ -0,0 +1,414 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import copy\n", + "from functools import partial" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from timm.models.vision_transformer import vit_small_patch16_224\n", + "from torch import Tensor\n", + "from torch.nn import Module\n", + "from torch.optim import AdamW" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import DINOLoss, IBOTPatchLoss, KoLeoLoss\n", + "from lightly.models.modules import DINOv2ProjectionHead, MaskedVisionTransformerTIMM\n", + "from lightly.models.utils import (\n", + " random_block_mask,\n", + " update_drop_path_rate,\n", + " update_momentum,\n", + ")\n", + "from lightly.transforms.dino_transform import DINOTransform\n", + "from lightly.utils.scheduler import cosine_schedule, linear_warmup_schedule" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "def freeze_eval_module(module: Module) -> None:\n", + " \"\"\"Freeze the parameters of a module.\"\"\"\n", + " for param in module.parameters():\n", + " param.requires_grad = False\n", + " module.eval()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "class DINOv2Head(Module):\n", + " def __init__(\n", + " self, dino_head: DINOv2ProjectionHead, ibot_head: DINOv2ProjectionHead\n", + " ) -> None:\n", + " super().__init__()\n", + " self.dino_head = dino_head\n", + " self.ibot_head = ibot_head" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "class DINOv2(pl.LightningModule):\n", + " def __init__(\n", + " self,\n", + " ibot_separate_head: bool = False,\n", + " ) -> None:\n", + " 
super().__init__()\n", + " self.save_hyperparameters()\n", + "\n", + " # Backbones\n", + " vit_teacher = vit_small_patch16_224(\n", + " pos_embed=\"learn\",\n", + " dynamic_img_size=True,\n", + " init_values=1e-5,\n", + " )\n", + " self.teacher_backbone = MaskedVisionTransformerTIMM(\n", + " vit=vit_teacher,\n", + " antialias=False,\n", + " pos_embed_initialization=\"skip\",\n", + " )\n", + " self.student_backbone = copy.deepcopy(self.teacher_backbone)\n", + " update_drop_path_rate(\n", + " self.student_backbone.vit,\n", + " drop_path_rate=0.1, # we recommend using smaller rates like 0.1 for vit-s-14\n", + " mode=\"uniform\",\n", + " )\n", + "\n", + " freeze_eval_module(self.teacher_backbone)\n", + "\n", + " # Heads\n", + " dino_head = partial(\n", + " DINOv2ProjectionHead,\n", + " input_dim=384,\n", + " )\n", + "\n", + " teacher_dino_head = dino_head()\n", + " student_dino_head = dino_head()\n", + "\n", + " ibot_head = partial(\n", + " DINOv2ProjectionHead,\n", + " input_dim=384,\n", + " )\n", + "\n", + " if ibot_separate_head:\n", + " teacher_ibot_head = ibot_head()\n", + " student_ibot_head = ibot_head()\n", + " else:\n", + " teacher_ibot_head = teacher_dino_head\n", + " student_ibot_head = student_dino_head\n", + "\n", + " self.teacher_head = DINOv2Head(\n", + " dino_head=teacher_dino_head,\n", + " ibot_head=teacher_ibot_head,\n", + " )\n", + " self.student_head = DINOv2Head(\n", + " dino_head=student_dino_head,\n", + " ibot_head=student_ibot_head,\n", + " )\n", + "\n", + " freeze_eval_module(self.teacher_head)\n", + "\n", + " # Losses\n", + " self.dino_criterion = DINOLoss()\n", + " self.ibot_criterion = IBOTPatchLoss()\n", + " self.koleo_criterion = KoLeoLoss()\n", + "\n", + " def forward(self, x: Tensor) -> Tensor:\n", + " pass\n", + "\n", + " def forward_teacher(self, x: Tensor) -> tuple[Tensor, Tensor]:\n", + " features = self.teacher_backbone.encode(x)\n", + " cls_tokens = features[:, 0]\n", + " return cls_tokens, features\n", + "\n", + " def forward_student(\n", + " self, x: Tensor, mask: Tensor | None\n", + " ) -> tuple[Tensor, Tensor | None]:\n", + " features = self.student_backbone.encode(x, mask=mask)\n", + " cls_tokens = features[:, 0]\n", + " masked_features = None if mask is None else features[mask]\n", + " return cls_tokens, masked_features\n", + "\n", + " def training_step(\n", + " self, batch: tuple[list[Tensor], Tensor, list[str]], batch_idx: int\n", + " ) -> Tensor:\n", + " views, targets = batch[0], batch[1]\n", + " global_views = torch.cat(views[:2])\n", + " local_views = torch.cat(views[2:])\n", + "\n", + " # Masking\n", + " B = len(global_views)\n", + " sequence_length = self.teacher_backbone.sequence_length\n", + " mask = global_views.new_zeros((B, sequence_length), dtype=torch.bool)\n", + " # Mask patches except class token.\n", + " H, W = self.teacher_backbone.vit.patch_embed.grid_size\n", + " assert (\n", + " H * W == sequence_length - 1\n", + " ), f\"Unexpected grid size: {H}x{W}, sequence_length {sequence_length}\"\n", + " block_mask = random_block_mask(size=(B, H, W), device=mask.device)\n", + " mask[:, 1:] = block_mask.flatten(start_dim=1)\n", + "\n", + " # Teacher forward\n", + " with torch.no_grad():\n", + " teacher_cls_token, teacher_features = self.forward_teacher(global_views)\n", + " teacher_cls_out = self.teacher_head.dino_head.forward(teacher_cls_token)\n", + " teacher_masked_out = self.teacher_head.ibot_head.forward(\n", + " teacher_features[mask]\n", + " )\n", + "\n", + " # Student forward\n", + " student_global_cls_token, 
student_global_masked_features = self.forward_student(\n", + " global_views, mask=mask\n", + " )\n", + " student_global_cls_out = self.student_head.dino_head.forward(\n", + " student_global_cls_token\n", + " )\n", + " student_global_masked_out = self.student_head.ibot_head.forward(\n", + " student_global_masked_features\n", + " )\n", + "\n", + " student_local_cls_token, _ = self.forward_student(local_views, mask=None)\n", + " student_local_cls_out = self.student_head.dino_head.forward(\n", + " student_local_cls_token\n", + " )\n", + " student_cls_out = torch.cat([student_global_cls_out, student_local_cls_out])\n", + "\n", + " teacher_temp = linear_warmup_schedule(\n", + " step=self.trainer.global_step,\n", + " warmup_steps=int(\n", + " 30 / self.trainer.max_epochs * self.trainer.estimated_stepping_batches\n", + " ),\n", + " start_value=0.04,\n", + " end_value=0.07,\n", + " )\n", + " dino_loss = self.dino_criterion(\n", + " teacher_out=teacher_cls_out.chunk(2),\n", + " student_out=student_cls_out.chunk(len(views)),\n", + " teacher_temp=teacher_temp,\n", + " )\n", + " ibot_loss = self.ibot_criterion(\n", + " teacher_out=teacher_masked_out,\n", + " student_out=student_global_masked_out,\n", + " mask=block_mask,\n", + " teacher_temp=teacher_temp,\n", + " )\n", + " koleo_loss = 0.1 * sum(\n", + " self.koleo_criterion(t) for t in student_global_cls_token.chunk(2)\n", + " )\n", + " loss = dino_loss + ibot_loss + koleo_loss\n", + "\n", + " self.log_dict(\n", + " {\n", + " \"train_loss\": loss,\n", + " \"train_dino_loss\": dino_loss,\n", + " \"train_ibot_loss\": ibot_loss,\n", + " \"train_koleo_loss\": koleo_loss,\n", + " \"teacher_temp\": teacher_temp,\n", + " },\n", + " prog_bar=True,\n", + " sync_dist=True,\n", + " batch_size=len(targets),\n", + " )\n", + "\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " optim = AdamW(self.parameters(), lr=0.001)\n", + " return optim\n", + "\n", + " def on_before_optimizer_step(self, optimizer: AdamW, *args) -> None:\n", + " # Optionally zero out the learning rate of the last layer.\n", + " if self.current_epoch < 1:\n", + " for param_group in optimizer.param_groups:\n", + " if \"last_layer\" in param_group:\n", + " param_group[\"lr\"] = 0.0\n", + "\n", + " # Apply weight decay schedule\n", + " weight_decay = cosine_schedule(\n", + " step=self.trainer.global_step,\n", + " max_steps=self.trainer.estimated_stepping_batches,\n", + " start_value=0.04,\n", + " end_value=0.4,\n", + " )\n", + " for group in optimizer.param_groups:\n", + " if group[\"weight_decay\"] != 0.0:\n", + " group[\"weight_decay\"] = weight_decay\n", + "\n", + " def on_train_batch_end(self, outputs, batch, batch_idx):\n", + " # Momentum update teacher.\n", + " momentum = cosine_schedule(\n", + " step=self.trainer.global_step,\n", + " max_steps=self.trainer.estimated_stepping_batches,\n", + " start_value=0.992,\n", + " end_value=1.0,\n", + " )\n", + " update_momentum(self.student_backbone, self.teacher_backbone, m=momentum)\n", + " update_momentum(self.student_head, self.teacher_head, m=momentum)\n", + "\n", + " return super().on_train_batch_end(outputs, batch, batch_idx)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "model = DINOv2()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "transform = DINOTransform(\n", + " global_crop_scale=(0.32, 1),\n", + " local_crop_scale=(0.05, 0.32),\n", + " 
n_local_views=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "# we ignore object detection annotations by setting target_transform to return 0\n", + "def target_transform(t):\n", + " return 0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = torchvision.datasets.VOCDetection(\n", + " \"datasets/pascal_voc\",\n", + " download=True,\n", + " transform=transform,\n", + " target_transform=target_transform,\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=64,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", + "metadata": {}, + "outputs": [], + "source": [ + "accelerator = \"gpu\" if torch.cuda.is_available() else \"cpu\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15", + "metadata": {}, + "outputs": [], + "source": [ + "trainer = pl.Trainer(max_epochs=50, devices=1, accelerator=accelerator)\n", + "trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning_distributed/dinov2.ipynb b/examples/notebooks/pytorch_lightning_distributed/dinov2.ipynb new file mode 100644 index 000000000..ac4e9cbae --- /dev/null +++ b/examples/notebooks/pytorch_lightning_distributed/dinov2.ipynb @@ -0,0 +1,414 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import copy\n", + "from functools import partial" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from timm.models.vision_transformer import vit_small_patch16_224\n", + "from torch import Tensor\n", + "from torch.nn import Module\n", + "from torch.optim import AdamW" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import DINOLoss, IBOTPatchLoss, KoLeoLoss\n", + "from lightly.models.modules import DINOv2ProjectionHead, MaskedVisionTransformerTIMM\n", + "from lightly.models.utils import (\n", + " random_block_mask,\n", + " update_drop_path_rate,\n", + " update_momentum,\n", + ")\n", + "from lightly.transforms.dino_transform import DINOTransform\n", + "from lightly.utils.optim import update_param_groups\n", + "from lightly.utils.scheduler import cosine_schedule, linear_warmup_schedule" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "def freeze_eval_module(module: Module) -> None:\n", + " \"\"\"Freeze the parameters of a module.\"\"\"\n", + " for param in module.parameters():\n", + " param.requires_grad = False\n", + " module.eval()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "class DINOv2Head(Module):\n", + " def __init__(\n", + " self, dino_head: DINOv2ProjectionHead, ibot_head: DINOv2ProjectionHead\n", + " ) -> None:\n", + " super().__init__()\n", + " self.dino_head = dino_head\n", + " self.ibot_head = ibot_head" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "class DINOv2(pl.LightningModule):\n", + " def __init__(\n", + " self,\n", + " ibot_separate_head: bool = False,\n", + " ) -> None:\n", + " super().__init__()\n", + " self.save_hyperparameters()\n", + "\n", + " # Backbones\n", + " vit_teacher = vit_small_patch16_224(\n", + " pos_embed=\"learn\",\n", + " dynamic_img_size=True,\n", + " init_values=1e-5,\n", + " )\n", + " self.teacher_backbone = MaskedVisionTransformerTIMM(\n", + " vit=vit_teacher,\n", + " antialias=False,\n", + " pos_embed_initialization=\"skip\",\n", + " )\n", + " self.student_backbone = copy.deepcopy(self.teacher_backbone)\n", + " update_drop_path_rate(\n", + " self.student_backbone.vit,\n", + " drop_path_rate=0.1, # we recommend using smaller rates like 0.1 for vit-s-14\n", + " mode=\"uniform\",\n", + " )\n", + "\n", + " freeze_eval_module(self.teacher_backbone)\n", + "\n", + " # Heads\n", + " dino_head = partial(\n", + " DINOv2ProjectionHead,\n", + " input_dim=384,\n", + " )\n", + "\n", + " teacher_dino_head = dino_head()\n", + " student_dino_head = dino_head()\n", + "\n", + " ibot_head = partial(\n", + " DINOv2ProjectionHead,\n", + " input_dim=384,\n", + " )\n", + "\n", + " if ibot_separate_head:\n", + " teacher_ibot_head = ibot_head()\n", + " student_ibot_head = ibot_head()\n", + " else:\n", + " teacher_ibot_head = teacher_dino_head\n", + " student_ibot_head = student_dino_head\n", + "\n", + " self.teacher_head = DINOv2Head(\n", + " dino_head=teacher_dino_head,\n", + " ibot_head=teacher_ibot_head,\n", + " )\n", + " self.student_head = 
DINOv2Head(\n", + " dino_head=student_dino_head,\n", + " ibot_head=student_ibot_head,\n", + " )\n", + "\n", + " freeze_eval_module(self.teacher_head)\n", + "\n", + " # Losses\n", + " self.dino_criterion = DINOLoss()\n", + " self.ibot_criterion = IBOTPatchLoss()\n", + " self.koleo_criterion = KoLeoLoss()\n", + "\n", + " def forward(self, x: Tensor) -> Tensor:\n", + " pass\n", + "\n", + " def forward_teacher(self, x: Tensor) -> tuple[Tensor, Tensor]:\n", + " features = self.teacher_backbone.encode(x)\n", + " cls_tokens = features[:, 0]\n", + " return cls_tokens, features\n", + "\n", + " def forward_student(\n", + " self, x: Tensor, mask: Tensor | None\n", + " ) -> tuple[Tensor, Tensor | None]:\n", + " features = self.student_backbone.encode(x, mask=mask)\n", + " cls_tokens = features[:, 0]\n", + " masked_features = None if mask is None else features[mask]\n", + " return cls_tokens, masked_features\n", + "\n", + " def training_step(\n", + " self, batch: tuple[list[Tensor], Tensor, list[str]], batch_idx: int\n", + " ) -> Tensor:\n", + " views, targets = batch[0], batch[1]\n", + " global_views = torch.cat(views[:2])\n", + " local_views = torch.cat(views[2:])\n", + "\n", + " # Masking\n", + " B = len(global_views)\n", + " sequence_length = self.teacher_backbone.sequence_length\n", + " mask = global_views.new_zeros((B, sequence_length), dtype=torch.bool)\n", + " # Mask patches except class token.\n", + " H, W = self.teacher_backbone.vit.patch_embed.grid_size\n", + " assert (\n", + " H * W == sequence_length - 1\n", + " ), f\"Unexpected grid size: {H}x{W}, sequence_length {sequence_length}\"\n", + " block_mask = random_block_mask(size=(B, H, W), device=mask.device)\n", + " mask[:, 1:] = block_mask.flatten(start_dim=1)\n", + "\n", + " # Teacher forward\n", + " with torch.no_grad():\n", + " teacher_cls_token, teacher_features = self.forward_teacher(global_views)\n", + " teacher_cls_out = self.teacher_head.dino_head.forward(teacher_cls_token)\n", + " teacher_masked_out = self.teacher_head.ibot_head.forward(\n", + " teacher_features[mask]\n", + " )\n", + "\n", + " # Student forward\n", + " student_global_cls_token, student_global_masked_features = self.forward_student(\n", + " global_views, mask=mask\n", + " )\n", + " student_global_cls_out = self.student_head.dino_head.forward(\n", + " student_global_cls_token\n", + " )\n", + " student_global_masked_out = self.student_head.ibot_head.forward(\n", + " student_global_masked_features\n", + " )\n", + "\n", + " student_local_cls_token, _ = self.forward_student(local_views, mask=None)\n", + " student_local_cls_out = self.student_head.dino_head.forward(\n", + " student_local_cls_token\n", + " )\n", + " student_cls_out = torch.cat([student_global_cls_out, student_local_cls_out])\n", + "\n", + " teacher_temp = linear_warmup_schedule(\n", + " step=self.trainer.global_step,\n", + " warmup_steps=int(\n", + " 30 / self.trainer.max_epochs * self.trainer.estimated_stepping_batches\n", + " ),\n", + " start_value=0.04,\n", + " end_value=0.07,\n", + " )\n", + " dino_loss = self.dino_criterion(\n", + " teacher_out=teacher_cls_out.chunk(2),\n", + " student_out=student_cls_out.chunk(len(views)),\n", + " teacher_temp=teacher_temp,\n", + " )\n", + " ibot_loss = self.ibot_criterion(\n", + " teacher_out=teacher_masked_out,\n", + " student_out=student_global_masked_out,\n", + " mask=block_mask,\n", + " teacher_temp=teacher_temp,\n", + " )\n", + " koleo_loss = 0.1 * sum(\n", + " self.koleo_criterion(t) for t in student_global_cls_token.chunk(2)\n", + " )\n", + " loss = 
dino_loss + ibot_loss + koleo_loss\n", + "\n", + " self.log_dict(\n", + " {\n", + " \"train_loss\": loss,\n", + " \"train_dino_loss\": dino_loss,\n", + " \"train_ibot_loss\": ibot_loss,\n", + " \"train_koleo_loss\": koleo_loss,\n", + " \"teacher_temp\": teacher_temp,\n", + " },\n", + " prog_bar=True,\n", + " sync_dist=True,\n", + " batch_size=len(targets),\n", + " )\n", + "\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " optim = AdamW(self.parameters(), lr=0.001)\n", + " return optim\n", + "\n", + " def on_before_optimizer_step(self, optimizer: AdamW, *args) -> None:\n", + " # Optionally zero out the learning rate of the last layer.\n", + " if self.current_epoch < 1:\n", + " for param_group in optimizer.param_groups:\n", + " if \"last_layer\" in param_group:\n", + " param_group[\"lr\"] = 0.0\n", + "\n", + " # Apply weight decay schedule\n", + " weight_decay = cosine_schedule(\n", + " step=self.trainer.global_step,\n", + " max_steps=self.trainer.estimated_stepping_batches,\n", + " start_value=0.04,\n", + " end_value=0.4,\n", + " )\n", + " for group in optimizer.param_groups:\n", + " if group[\"weight_decay\"] != 0.0:\n", + " group[\"weight_decay\"] = weight_decay\n", + "\n", + " def on_train_batch_end(self, outputs, batch, batch_idx):\n", + " # Momentum update teacher.\n", + " momentum = cosine_schedule(\n", + " step=self.trainer.global_step,\n", + " max_steps=self.trainer.estimated_stepping_batches,\n", + " start_value=0.992,\n", + " end_value=1.0,\n", + " )\n", + " update_momentum(self.student_backbone, self.teacher_backbone, m=momentum)\n", + " update_momentum(self.student_head, self.teacher_head, m=momentum)\n", + "\n", + " return super().on_train_batch_end(outputs, batch, batch_idx)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "model = DINOv2()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "transform = DINOTransform(\n", + " global_crop_scale=(0.32, 1),\n", + " local_crop_scale=(0.05, 0.32),\n", + " n_local_views=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "# we ignore object detection annotations by setting target_transform to return 0\n", + "def target_transform(t):\n", + " return 0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = torchvision.datasets.VOCDetection(\n", + " \"datasets/pascal_voc\",\n", + " download=True,\n", + " transform=transform,\n", + " target_transform=target_transform,\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=64,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", + "metadata": {}, + "outputs": [], + "source": [ + "# Train with DDP and use Synchronized Batch Norm for a more accurate batch norm\n", + "# calculation. 
Distributed sampling is also enabled with use_distributed_sampler=True.\n", + "trainer = pl.Trainer(\n", + " max_epochs=50,\n", + " devices=\"auto\",\n", + " accelerator=\"gpu\",\n", + " strategy=\"ddp_find_unused_parameters_true\",\n", + " sync_batchnorm=True,\n", + " use_distributed_sampler=True, # or replace_sampler_ddp=True for PyTorch Lightning <2.0\n", + ")\n", + "trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/pytorch/dinov2.py b/examples/pytorch/dinov2.py new file mode 100644 index 000000000..56ac0953e --- /dev/null +++ b/examples/pytorch/dinov2.py @@ -0,0 +1,280 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + +# Note: The model and training settings do not follow the reference settings +# from the paper. The settings are chosen such that the example can easily be +# run on a small dataset with a single GPU. + +import copy +from functools import partial + +import torch +import torchvision +from timm.models.vision_transformer import vit_small_patch16_224 +from torch import Tensor +from torch.nn import Module +from torch.optim import AdamW + +from lightly.loss import DINOLoss, IBOTPatchLoss, KoLeoLoss +from lightly.models.modules import DINOv2ProjectionHead, MaskedVisionTransformerTIMM +from lightly.models.utils import ( + random_block_mask, + update_drop_path_rate, + update_momentum, +) +from lightly.transforms.dino_transform import DINOTransform +from lightly.utils.scheduler import cosine_schedule, linear_warmup_schedule + + +def freeze_eval_module(module: Module) -> None: + """Freeze the parameters of a module.""" + for param in module.parameters(): + param.requires_grad = False + module.eval() + + +class DINOv2Head(Module): + def __init__( + self, dino_head: DINOv2ProjectionHead, ibot_head: DINOv2ProjectionHead + ) -> None: + super().__init__() + self.dino_head = dino_head + self.ibot_head = ibot_head + + +class DINOv2(Module): + def __init__( + self, + ibot_separate_head: bool = False, + ) -> None: + super().__init__() + + # Backbones + vit_teacher = vit_small_patch16_224( + pos_embed="learn", + dynamic_img_size=True, + init_values=1e-5, + ) + self.teacher_backbone = MaskedVisionTransformerTIMM( + vit=vit_teacher, + antialias=False, + pos_embed_initialization="skip", + ) + self.student_backbone = copy.deepcopy(self.teacher_backbone) + update_drop_path_rate( + self.student_backbone.vit, + drop_path_rate=0.1, # we recommend using smaller rates like 0.1 for vit-s-14 + mode="uniform", + ) + + freeze_eval_module(self.teacher_backbone) + + # Heads + dino_head = partial( + DINOv2ProjectionHead, + input_dim=384, + ) + + teacher_dino_head = dino_head() + student_dino_head = dino_head() + + ibot_head = partial( + DINOv2ProjectionHead, + input_dim=384, + ) + + if ibot_separate_head: + teacher_ibot_head = ibot_head() + student_ibot_head = ibot_head() + else: + teacher_ibot_head = teacher_dino_head + student_ibot_head = student_dino_head + + self.teacher_head = DINOv2Head( + dino_head=teacher_dino_head, + ibot_head=teacher_ibot_head, + ) + self.student_head = DINOv2Head( + dino_head=student_dino_head, + ibot_head=student_ibot_head, + ) + + freeze_eval_module(self.teacher_head) + + def forward(self, x: Tensor) -> Tensor: + return self.teacher_backbone(x) + + def forward_teacher(self, x: Tensor) -> tuple[Tensor, Tensor]: + 
features = self.teacher_backbone.encode(x) + cls_tokens = features[:, 0] + return cls_tokens, features + + def forward_student( + self, x: Tensor, mask: Tensor | None + ) -> tuple[Tensor, Tensor | None]: + features = self.student_backbone.encode(x, mask=mask) + cls_tokens = features[:, 0] + masked_features = None if mask is None else features[mask] + return cls_tokens, masked_features + + +model = DINOv2() + +transform = DINOTransform( + global_crop_scale=(0.32, 1), + local_crop_scale=(0.05, 0.32), + n_local_views=8, +) + + +# We ignore object detection annotations by setting target_transform to return 0. +def target_transform(t): + return 0 + + +# Use CUDA if available, otherwise fall back to MPS or CPU. +device = ( + "cuda" + if torch.cuda.is_available() + else "mps" if torch.backends.mps.is_available() else "cpu" +) +model.to(device) + +dataset = torchvision.datasets.VOCDetection( + "datasets/pascal_voc", + download=True, + transform=transform, + target_transform=target_transform, +) +# Or create a dataset from a folder containing images or videos. +# from lightly.data import LightlyDataset +# dataset = LightlyDataset("path/to/folder") + +dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=64, + shuffle=True, + drop_last=True, + num_workers=8, +) + +# Create the loss functions. +dino_criterion = DINOLoss() +ibot_criterion = IBOTPatchLoss() +koleo_criterion = KoLeoLoss() + +# Move the losses to the correct device because they also contain parameters. +dino_criterion = dino_criterion.to(device) +ibot_criterion = ibot_criterion.to(device) +koleo_criterion = koleo_criterion.to(device) + +optimizer = AdamW(model.parameters(), lr=0.001) + +epochs = 50 +num_batches = len(dataloader) +total_steps = epochs * num_batches + +print("Starting Training") +for epoch in range(epochs): + total_loss = 0 + for batch_idx, batch in enumerate(dataloader): + views = batch[0] + views = [view.to(device) for view in views] + global_views = torch.cat(views[:2]) + local_views = torch.cat(views[2:]) + + # Masking + B = len(global_views) + sequence_length = model.teacher_backbone.sequence_length + mask = global_views.new_zeros((B, sequence_length), dtype=torch.bool) + + # Mask patches except class token. + H, W = model.teacher_backbone.vit.patch_embed.grid_size + assert ( + H * W == sequence_length - 1 + ), f"Unexpected grid size: {H}x{W}, sequence_length {sequence_length}" + block_mask = random_block_mask(size=(B, H, W), device=mask.device) + mask[:, 1:] = block_mask.flatten(start_dim=1) + + # Teacher forward + with torch.no_grad(): + teacher_cls_token, teacher_features = model.forward_teacher(global_views) + teacher_cls_out = model.teacher_head.dino_head.forward(teacher_cls_token) + teacher_masked_out = model.teacher_head.ibot_head.forward( + teacher_features[mask] + ) + + # Student forward + ( + student_global_cls_token, + student_global_masked_features, + ) = model.forward_student(global_views, mask=mask) + student_global_cls_out = model.student_head.dino_head.forward( + student_global_cls_token + ) + student_global_masked_out = model.student_head.ibot_head.forward( + student_global_masked_features + ) + student_local_cls_token, _ = model.forward_student(local_views, mask=None) + student_local_cls_out = model.student_head.dino_head.forward( + student_local_cls_token + ) + student_cls_out = torch.cat([student_global_cls_out, student_local_cls_out]) + + # Calculate current global step based on epoch and batch index. + global_step = epoch * num_batches + batch_idx + + # Calculate the loss.
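+ # The total loss combines three terms computed below: the image-level DINO + # loss between teacher CLS outputs (two global crops) and student CLS + # outputs (global and local crops), the patch-level iBOT loss on the masked + # patch tokens, and a KoLeo regularizer weighted by 0.1 on each global-view + # CLS token. The teacher temperature is warmed up linearly from 0.04 to + # 0.07 to stabilize the early phase of training.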
+ teacher_temp = linear_warmup_schedule( + step=global_step, + warmup_steps=int(30 / epochs * total_steps), + start_value=0.04, + end_value=0.07, + ) + dino_loss = dino_criterion( + teacher_out=teacher_cls_out.chunk(2), + student_out=student_cls_out.chunk(len(views)), + teacher_temp=teacher_temp, + ) + ibot_loss = ibot_criterion( + teacher_out=teacher_masked_out, + student_out=student_global_masked_out, + mask=block_mask, + teacher_temp=teacher_temp, + ) + koleo_loss = 0.1 * sum( + koleo_criterion(t) for t in student_global_cls_token.chunk(2) + ) + loss = dino_loss + ibot_loss + koleo_loss + + total_loss += loss.detach() + loss.backward() + + # Optionally zero out the learning rate of the last layer. + if epoch < 1: + for param_group in optimizer.param_groups: + if "last_layer" in param_group: + param_group["lr"] = 0.0 + + # Apply weight decay schedule. + weight_decay = cosine_schedule( + step=global_step, + max_steps=total_steps, + start_value=0.04, + end_value=0.4, + ) + + # Update weight decay directly for all parameter groups. + for group in optimizer.param_groups: + if group["weight_decay"] != 0.0: + group["weight_decay"] = weight_decay + + optimizer.step() + optimizer.zero_grad() + + # Momentum update teacher. + momentum = cosine_schedule( + step=global_step, + max_steps=total_steps, + start_value=0.992, + end_value=1.0, + ) + update_momentum(model.student_backbone, model.teacher_backbone, m=momentum) + update_momentum(model.student_head, model.teacher_head, m=momentum) + + avg_loss = total_loss / len(dataloader) + print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}") diff --git a/examples/pytorch_lightning/dinov2.py b/examples/pytorch_lightning/dinov2.py new file mode 100644 index 000000000..ec5cfe09d --- /dev/null +++ b/examples/pytorch_lightning/dinov2.py @@ -0,0 +1,280 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + +# Note: The model and training settings do not follow the reference settings +# from the paper. The settings are chosen such that the example can easily be +# run on a small dataset with a single GPU. 
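+ # Additional assumptions worth noting: the ViT backbone below is loaded from + # timm, so timm must be installed in addition to lightly (pip install timm), + # and the "Tensor | None" annotations require Python 3.10 or newer.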
+ +import copy +from functools import partial + +import pytorch_lightning as pl +import torch +import torchvision +from timm.models.vision_transformer import vit_small_patch16_224 +from torch import Tensor +from torch.nn import Module +from torch.optim import AdamW + +from lightly.loss import DINOLoss, IBOTPatchLoss, KoLeoLoss +from lightly.models.modules import DINOv2ProjectionHead, MaskedVisionTransformerTIMM +from lightly.models.utils import ( + random_block_mask, + update_drop_path_rate, + update_momentum, +) +from lightly.transforms.dino_transform import DINOTransform +from lightly.utils.scheduler import cosine_schedule, linear_warmup_schedule + + +def freeze_eval_module(module: Module) -> None: + """Freeze the parameters of a module.""" + for param in module.parameters(): + param.requires_grad = False + module.eval() + + +class DINOv2Head(Module): + def __init__( + self, dino_head: DINOv2ProjectionHead, ibot_head: DINOv2ProjectionHead + ) -> None: + super().__init__() + self.dino_head = dino_head + self.ibot_head = ibot_head + + +class DINOv2(pl.LightningModule): + def __init__( + self, + ibot_separate_head: bool = False, + ) -> None: + super().__init__() + self.save_hyperparameters() + + # Backbones + vit_teacher = vit_small_patch16_224( + pos_embed="learn", + dynamic_img_size=True, + init_values=1e-5, + ) + self.teacher_backbone = MaskedVisionTransformerTIMM( + vit=vit_teacher, + antialias=False, + pos_embed_initialization="skip", + ) + self.student_backbone = copy.deepcopy(self.teacher_backbone) + update_drop_path_rate( + self.student_backbone.vit, + drop_path_rate=0.1, # we recommend using smaller rates like 0.1 for vit-s-14 + mode="uniform", + ) + + freeze_eval_module(self.teacher_backbone) + + # Heads + dino_head = partial( + DINOv2ProjectionHead, + input_dim=384, + ) + + teacher_dino_head = dino_head() + student_dino_head = dino_head() + + ibot_head = partial( + DINOv2ProjectionHead, + input_dim=384, + ) + + if ibot_separate_head: + teacher_ibot_head = ibot_head() + student_ibot_head = ibot_head() + else: + teacher_ibot_head = teacher_dino_head + student_ibot_head = student_dino_head + + self.teacher_head = DINOv2Head( + dino_head=teacher_dino_head, + ibot_head=teacher_ibot_head, + ) + self.student_head = DINOv2Head( + dino_head=student_dino_head, + ibot_head=student_ibot_head, + ) + + freeze_eval_module(self.teacher_head) + + # Losses + self.dino_criterion = DINOLoss() + self.ibot_criterion = IBOTPatchLoss() + self.koleo_criterion = KoLeoLoss() + + def forward(self, x: Tensor) -> Tensor: + pass + + def forward_teacher(self, x: Tensor) -> tuple[Tensor, Tensor]: + features = self.teacher_backbone.encode(x) + cls_tokens = features[:, 0] + return cls_tokens, features + + def forward_student( + self, x: Tensor, mask: Tensor | None + ) -> tuple[Tensor, Tensor | None]: + features = self.student_backbone.encode(x, mask=mask) + cls_tokens = features[:, 0] + masked_features = None if mask is None else features[mask] + return cls_tokens, masked_features + + def training_step( + self, batch: tuple[list[Tensor], Tensor, list[str]], batch_idx: int + ) -> Tensor: + views, targets = batch[0], batch[1] + global_views = torch.cat(views[:2]) + local_views = torch.cat(views[2:]) + + # Masking + B = len(global_views) + sequence_length = self.teacher_backbone.sequence_length + mask = global_views.new_zeros((B, sequence_length), dtype=torch.bool) + # Mask patches except class token. 
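+ # The mask has one entry per token. Index 0 is the class token and is left + # unmasked, so only the H * W patch tokens are candidates. random_block_mask + # samples contiguous rectangular blocks per image, in the style of iBOT.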
+ H, W = self.teacher_backbone.vit.patch_embed.grid_size + assert ( + H * W == sequence_length - 1 + ), f"Unexpected grid size: {H}x{W}, sequence_length {sequence_length}" + block_mask = random_block_mask(size=(B, H, W), device=mask.device) + mask[:, 1:] = block_mask.flatten(start_dim=1) + + # Teacher forward + with torch.no_grad(): + teacher_cls_token, teacher_features = self.forward_teacher(global_views) + teacher_cls_out = self.teacher_head.dino_head.forward(teacher_cls_token) + teacher_masked_out = self.teacher_head.ibot_head.forward( + teacher_features[mask] + ) + + # Student forward + student_global_cls_token, student_global_masked_features = self.forward_student( + global_views, mask=mask + ) + student_global_cls_out = self.student_head.dino_head.forward( + student_global_cls_token + ) + student_global_masked_out = self.student_head.ibot_head.forward( + student_global_masked_features + ) + + student_local_cls_token, _ = self.forward_student(local_views, mask=None) + student_local_cls_out = self.student_head.dino_head.forward( + student_local_cls_token + ) + student_cls_out = torch.cat([student_global_cls_out, student_local_cls_out]) + + teacher_temp = linear_warmup_schedule( + step=self.trainer.global_step, + warmup_steps=int( + 30 / self.trainer.max_epochs * self.trainer.estimated_stepping_batches + ), + start_value=0.04, + end_value=0.07, + ) + dino_loss = self.dino_criterion( + teacher_out=teacher_cls_out.chunk(2), + student_out=student_cls_out.chunk(len(views)), + teacher_temp=teacher_temp, + ) + ibot_loss = self.ibot_criterion( + teacher_out=teacher_masked_out, + student_out=student_global_masked_out, + mask=block_mask, + teacher_temp=teacher_temp, + ) + koleo_loss = 0.1 * sum( + self.koleo_criterion(t) for t in student_global_cls_token.chunk(2) + ) + loss = dino_loss + ibot_loss + koleo_loss + + self.log_dict( + { + "train_loss": loss, + "train_dino_loss": dino_loss, + "train_ibot_loss": ibot_loss, + "train_koleo_loss": koleo_loss, + "teacher_temp": teacher_temp, + }, + prog_bar=True, + sync_dist=True, + batch_size=len(targets), + ) + + return loss + + def configure_optimizers(self): + optim = AdamW(self.parameters(), lr=0.001) + return optim + + def on_before_optimizer_step(self, optimizer: AdamW, *args) -> None: + # Optionally zero out the learning rate of the last layer. + if self.current_epoch < 1: + for param_group in optimizer.param_groups: + if "last_layer" in param_group: + param_group["lr"] = 0.0 + + # Apply weight decay schedule + weight_decay = cosine_schedule( + step=self.trainer.global_step, + max_steps=self.trainer.estimated_stepping_batches, + start_value=0.04, + end_value=0.4, + ) + for group in optimizer.param_groups: + if group["weight_decay"] != 0.0: + group["weight_decay"] = weight_decay + + def on_train_batch_end(self, outputs, batch, batch_idx): + # Momentum update teacher. 
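+ # Exponential moving average (EMA) update: for every teacher parameter, + # teacher = m * teacher + (1 - m) * student. With m following a cosine + # schedule from 0.992 to 1.0, the teacher changes ever more slowly and is + # effectively frozen by the end of training.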
+ momentum = cosine_schedule( + step=self.trainer.global_step, + max_steps=self.trainer.estimated_stepping_batches, + start_value=0.992, + end_value=1.0, + ) + update_momentum(self.student_backbone, self.teacher_backbone, m=momentum) + update_momentum(self.student_head, self.teacher_head, m=momentum) + + return super().on_train_batch_end(outputs, batch, batch_idx) + + +model = DINOv2() + +transform = DINOTransform( + global_crop_scale=(0.32, 1), + local_crop_scale=(0.05, 0.32), + n_local_views=8, +) + + +# we ignore object detection annotations by setting target_transform to return 0 +def target_transform(t): + return 0 + + +dataset = torchvision.datasets.VOCDetection( + "datasets/pascal_voc", + download=True, + transform=transform, + target_transform=target_transform, +) +# or create a dataset from a folder containing images or videos: +# dataset = LightlyDataset("path/to/folder") + +dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=64, + shuffle=True, + drop_last=True, + num_workers=8, +) + +accelerator = "gpu" if torch.cuda.is_available() else "cpu" + +trainer = pl.Trainer(max_epochs=50, devices=1, accelerator=accelerator) +trainer.fit(model=model, train_dataloaders=dataloader) diff --git a/examples/pytorch_lightning_distributed/dinov2.py b/examples/pytorch_lightning_distributed/dinov2.py new file mode 100644 index 000000000..fe3f08e82 --- /dev/null +++ b/examples/pytorch_lightning_distributed/dinov2.py @@ -0,0 +1,288 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + +# Note: The model and training settings do not follow the reference settings +# from the paper. The settings are chosen such that the example can easily be +# run on a small dataset with a single GPU. + +import copy +from functools import partial + +import pytorch_lightning as pl +import torch +import torchvision +from timm.models.vision_transformer import vit_small_patch16_224 +from torch import Tensor +from torch.nn import Module +from torch.optim import AdamW + +from lightly.loss import DINOLoss, IBOTPatchLoss, KoLeoLoss +from lightly.models.modules import DINOv2ProjectionHead, MaskedVisionTransformerTIMM +from lightly.models.utils import ( + random_block_mask, + update_drop_path_rate, + update_momentum, +) +from lightly.transforms.dino_transform import DINOTransform +from lightly.utils.optim import update_param_groups +from lightly.utils.scheduler import cosine_schedule, linear_warmup_schedule + + +def freeze_eval_module(module: Module) -> None: + """Freeze the parameters of a module.""" + for param in module.parameters(): + param.requires_grad = False + module.eval() + + +class DINOv2Head(Module): + def __init__( + self, dino_head: DINOv2ProjectionHead, ibot_head: DINOv2ProjectionHead + ) -> None: + super().__init__() + self.dino_head = dino_head + self.ibot_head = ibot_head + + +class DINOv2(pl.LightningModule): + def __init__( + self, + ibot_separate_head: bool = False, + ) -> None: + super().__init__() + self.save_hyperparameters() + + # Backbones + vit_teacher = vit_small_patch16_224( + pos_embed="learn", + dynamic_img_size=True, + init_values=1e-5, + ) + self.teacher_backbone = MaskedVisionTransformerTIMM( + vit=vit_teacher, + antialias=False, + pos_embed_initialization="skip", + ) + self.student_backbone = copy.deepcopy(self.teacher_backbone) + update_drop_path_rate( + self.student_backbone.vit, + drop_path_rate=0.1, # we recommend using smaller rates like 0.1 for vit-s-14 + mode="uniform", + ) + + freeze_eval_module(self.teacher_backbone) + + # 
Heads + dino_head = partial( + DINOv2ProjectionHead, + input_dim=384, + ) + + teacher_dino_head = dino_head() + student_dino_head = dino_head() + + ibot_head = partial( + DINOv2ProjectionHead, + input_dim=384, + ) + + if ibot_separate_head: + teacher_ibot_head = ibot_head() + student_ibot_head = ibot_head() + else: + teacher_ibot_head = teacher_dino_head + student_ibot_head = student_dino_head + + self.teacher_head = DINOv2Head( + dino_head=teacher_dino_head, + ibot_head=teacher_ibot_head, + ) + self.student_head = DINOv2Head( + dino_head=student_dino_head, + ibot_head=student_ibot_head, + ) + + freeze_eval_module(self.teacher_head) + + # Losses + self.dino_criterion = DINOLoss() + self.ibot_criterion = IBOTPatchLoss() + self.koleo_criterion = KoLeoLoss() + + def forward(self, x: Tensor) -> Tensor: + pass + + def forward_teacher(self, x: Tensor) -> tuple[Tensor, Tensor]: + features = self.teacher_backbone.encode(x) + cls_tokens = features[:, 0] + return cls_tokens, features + + def forward_student( + self, x: Tensor, mask: Tensor | None + ) -> tuple[Tensor, Tensor | None]: + features = self.student_backbone.encode(x, mask=mask) + cls_tokens = features[:, 0] + masked_features = None if mask is None else features[mask] + return cls_tokens, masked_features + + def training_step( + self, batch: tuple[list[Tensor], Tensor, list[str]], batch_idx: int + ) -> Tensor: + views, targets = batch[0], batch[1] + global_views = torch.cat(views[:2]) + local_views = torch.cat(views[2:]) + + # Masking + B = len(global_views) + sequence_length = self.teacher_backbone.sequence_length + mask = global_views.new_zeros((B, sequence_length), dtype=torch.bool) + # Mask patches except class token. + H, W = self.teacher_backbone.vit.patch_embed.grid_size + assert ( + H * W == sequence_length - 1 + ), f"Unexpected grid size: {H}x{W}, sequence_length {sequence_length}" + block_mask = random_block_mask(size=(B, H, W), device=mask.device) + mask[:, 1:] = block_mask.flatten(start_dim=1) + + # Teacher forward + with torch.no_grad(): + teacher_cls_token, teacher_features = self.forward_teacher(global_views) + teacher_cls_out = self.teacher_head.dino_head.forward(teacher_cls_token) + teacher_masked_out = self.teacher_head.ibot_head.forward( + teacher_features[mask] + ) + + # Student forward + student_global_cls_token, student_global_masked_features = self.forward_student( + global_views, mask=mask + ) + student_global_cls_out = self.student_head.dino_head.forward( + student_global_cls_token + ) + student_global_masked_out = self.student_head.ibot_head.forward( + student_global_masked_features + ) + + student_local_cls_token, _ = self.forward_student(local_views, mask=None) + student_local_cls_out = self.student_head.dino_head.forward( + student_local_cls_token + ) + student_cls_out = torch.cat([student_global_cls_out, student_local_cls_out]) + + teacher_temp = linear_warmup_schedule( + step=self.trainer.global_step, + warmup_steps=int( + 30 / self.trainer.max_epochs * self.trainer.estimated_stepping_batches + ), + start_value=0.04, + end_value=0.07, + ) + dino_loss = self.dino_criterion( + teacher_out=teacher_cls_out.chunk(2), + student_out=student_cls_out.chunk(len(views)), + teacher_temp=teacher_temp, + ) + ibot_loss = self.ibot_criterion( + teacher_out=teacher_masked_out, + student_out=student_global_masked_out, + mask=block_mask, + teacher_temp=teacher_temp, + ) + koleo_loss = 0.1 * sum( + self.koleo_criterion(t) for t in student_global_cls_token.chunk(2) + ) + loss = dino_loss + ibot_loss + koleo_loss + + self.log_dict( 
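+            # Log each loss term separately so that the contributions of the
+            # DINO, iBOT, and KoLeo objectives can be monitored individually.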
+            {
+                "train_loss": loss,
+                "train_dino_loss": dino_loss,
+                "train_ibot_loss": ibot_loss,
+                "train_koleo_loss": koleo_loss,
+                "teacher_temp": teacher_temp,
+            },
+            prog_bar=True,
+            sync_dist=True,
+            batch_size=len(targets),
+        )
+
+        return loss
+
+    def configure_optimizers(self):
+        optim = AdamW(self.parameters(), lr=0.001)
+        return optim
+
+    def on_before_optimizer_step(self, optimizer: AdamW, *args) -> None:
+        # Optionally zero out the learning rate of the last layer during the
+        # first epoch. Note that this only takes effect if the optimizer was
+        # created with parameter groups that carry a "last_layer" key (e.g.
+        # via lightly.utils.optim.update_param_groups); with the plain
+        # AdamW(self.parameters()) used here it is a no-op.
+        if self.current_epoch < 1:
+            for param_group in optimizer.param_groups:
+                if "last_layer" in param_group:
+                    param_group["lr"] = 0.0
+
+        # Apply a cosine weight decay schedule from 0.04 to 0.4, skipping
+        # parameter groups where weight decay is disabled.
+        weight_decay = cosine_schedule(
+            step=self.trainer.global_step,
+            max_steps=self.trainer.estimated_stepping_batches,
+            start_value=0.04,
+            end_value=0.4,
+        )
+        for group in optimizer.param_groups:
+            if group["weight_decay"] != 0.0:
+                group["weight_decay"] = weight_decay
+
+    def on_train_batch_end(self, outputs, batch, batch_idx):
+        # Momentum update teacher. The teacher weights are an exponential
+        # moving average (EMA) of the student weights; the momentum follows a
+        # cosine schedule from 0.992 to 1.0.
+        momentum = cosine_schedule(
+            step=self.trainer.global_step,
+            max_steps=self.trainer.estimated_stepping_batches,
+            start_value=0.992,
+            end_value=1.0,
+        )
+        update_momentum(self.student_backbone, self.teacher_backbone, m=momentum)
+        update_momentum(self.student_head, self.teacher_head, m=momentum)
+
+        return super().on_train_batch_end(outputs, batch, batch_idx)
+
+
+model = DINOv2()
+
+transform = DINOTransform(
+    global_crop_scale=(0.32, 1),
+    local_crop_scale=(0.05, 0.32),
+    n_local_views=8,
+)
+
+
+# we ignore object detection annotations by setting target_transform to return 0
+def target_transform(t):
+    return 0
+
+
+dataset = torchvision.datasets.VOCDetection(
+    "datasets/pascal_voc",
+    download=True,
+    transform=transform,
+    target_transform=target_transform,
+)
+# or create a dataset from a folder containing images or videos:
+# dataset = LightlyDataset("path/to/folder")
+
+dataloader = torch.utils.data.DataLoader(
+    dataset,
+    batch_size=64,
+    shuffle=True,
+    drop_last=True,
+    num_workers=8,
+)
+
+# Train with DDP and use Synchronized Batch Norm for a more accurate batch norm
+# calculation. Distributed sampling is also enabled via use_distributed_sampler=True
+# (replace_sampler_ddp=True on PyTorch Lightning <2.0).
+trainer = pl.Trainer(
+    max_epochs=50,
+    devices="auto",
+    accelerator="gpu",
+    strategy="ddp_find_unused_parameters_true",
+    sync_batchnorm=True,
+    use_distributed_sampler=True,  # or replace_sampler_ddp=True for PyTorch Lightning <2.0
+)
+trainer.fit(model=model, train_dataloaders=dataloader)
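+
+# After training, the frozen teacher backbone can serve as a feature
+# extractor. A minimal sketch, assuming `images` is a batch of images with
+# shape (B, 3, 224, 224) on the same device as the model:
+#
+#   model.eval()
+#   with torch.no_grad():
+#       cls_tokens, features = model.forward_teacher(images)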