diff --git a/.gitignore b/.gitignore index 24f67f2e..9317223a 100644 --- a/.gitignore +++ b/.gitignore @@ -87,9 +87,6 @@ target/ profile_default/ ipython_config.py -# pyenv -.python-version - # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies @@ -133,3 +130,13 @@ dmypy.json # Pyre type checker .pyre/ + +# Apex +apex/ + +# Logging +runs/ +wandb/ + +# host-specific values +scripts/env.fish diff --git a/.python-version b/.python-version new file mode 100644 index 00000000..1281604a --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.10.7 diff --git a/README.md b/README.md index 816bf731..e333743d 100644 --- a/README.md +++ b/README.md @@ -1,310 +1,114 @@ # Swin Transformer -[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/swin-transformer-v2-scaling-up-capacity-and/object-detection-on-coco)](https://paperswithcode.com/sota/object-detection-on-coco?p=swin-transformer-v2-scaling-up-capacity-and) -[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/swin-transformer-v2-scaling-up-capacity-and/instance-segmentation-on-coco)](https://paperswithcode.com/sota/instance-segmentation-on-coco?p=swin-transformer-v2-scaling-up-capacity-and) -[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/swin-transformer-v2-scaling-up-capacity-and/semantic-segmentation-on-ade20k)](https://paperswithcode.com/sota/semantic-segmentation-on-ade20k?p=swin-transformer-v2-scaling-up-capacity-and) -[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/swin-transformer-v2-scaling-up-capacity-and/action-classification-on-kinetics-400)](https://paperswithcode.com/sota/action-classification-on-kinetics-400?p=swin-transformer-v2-scaling-up-capacity-and) +[Link to original Swin Transformer project](https://github.com/microsoft/Swin-Transformer) -This repo is the official implementation of ["Swin Transformer: Hierarchical Vision Transformer using Shifted Windows"](https://arxiv.org/pdf/2103.14030.pdf) as well as the follow-ups. It currently includes code and models for the following tasks: +## Installation Instructions -> **Image Classification**: Included in this repo. See [get_started.md](get_started.md) for a quick start. +1. Set up python packages -> **Object Detection and Instance Segmentation**: See [Swin Transformer for Object Detection](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection). - -> **Semantic Segmentation**: See [Swin Transformer for Semantic Segmentation](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation). - -> **Video Action Recognition**: See [Video Swin Transformer](https://github.com/SwinTransformer/Video-Swin-Transformer). - -> **Semi-Supervised Object Detection**: See [Soft Teacher](https://github.com/microsoft/SoftTeacher). - -> **SSL: Contrasitive Learning**: See [Transformer-SSL](https://github.com/SwinTransformer/Transformer-SSL). - -> **SSL: Masked Image Modeling**: See [get_started.md#simmim-support](https://github.com/microsoft/Swin-Transformer/blob/main/get_started.md#simmim-support). - -> **Mixture-of-Experts**: See [get_started](get_started.md#mixture-of-experts-support) for more instructions. - -> **Feature-Distillation**: Will appear in [Feature-Distillation](https://github.com/SwinTransformer/Feature-Distillation). 
- -## Activity notification - -* 09/18/2022: Organizing ECCV Workshop [*Computer Vision in the Wild (CVinW)*](https://computer-vision-in-the-wild.github.io/eccv-2022/), where two challenges are hosted to evaluate the zero-shot, few-shot and full-shot performance of pre-trained vision models in downstream tasks: - - [``*Image Classification in the Wild (ICinW)*''](https://eval.ai/web/challenges/challenge-page/1832/overview) Challenge evaluates on 20 image classification tasks. - - [``*Object Detection in the Wild (ODinW)*''](https://eval.ai/web/challenges/challenge-page/1839/overview) Challenge evaluates on 35 object detection tasks. - - -$\qquad$ [ [Workshop]](https://computer-vision-in-the-wild.github.io/eccv-2022/) $\qquad$ [ [IC Challenge] ](https://eval.ai/web/challenges/challenge-page/1832/overview) -$\qquad$ [ [OD Challenge] ](https://eval.ai/web/challenges/challenge-page/1839/overview) - -## Updates - -***09/24/2022*** - -1. Merged [SimMIM](https://github.com/microsoft/SimMIM), which is a **Masked Image Modeling** based pre-training approach applicable to Swin and SwinV2 (and also applicable for ViT and ResNet). Please refer to [get started with SimMIM](get_started.md#simmim-support) to play with SimMIM pre-training. - -2. Released a series of Swin and SwinV2 models pre-trained using the SimMIM approach (see [MODELHUB for SimMIM](MODELHUB.md#simmim-pretrained-swin-v2-models)), with model size ranging from SwinV2-Small-50M to SwinV2-giant-1B, data size ranging from ImageNet-1K-10% to ImageNet-22K, and iterations from 125k to 500k. You may leverage these models to study the properties of MIM methods. Please look into the [data scaling](https://arxiv.org/abs/2206.04664) paper for more details. - -***07/09/2022*** - -`News`: - -1. SwinV2-G achieves `61.4 mIoU` on ADE20K semantic segmentation (+1.5 mIoU over the previous SwinV2-G model), using an additional [feature distillation (FD)](https://github.com/SwinTransformer/Feature-Distillation) approach, **setting a new recrod** on this benchmark. FD is an approach that can generally improve the fine-tuning performance of various pre-trained models, including DeiT, DINO, and CLIP. Particularly, it improves CLIP pre-trained ViT-L by +1.6% to reach `89.0%` on ImageNet-1K image classification, which is **the most accurate ViT-L model**. -2. Merged a PR from **Nvidia** that links to faster Swin Transformer inference that have significant speed improvements on `T4 and A100 GPUs`. -3. Merged a PR from **Nvidia** that enables an option to use `pure FP16 (Apex O2)` in training, while almost maintaining the accuracy. - -***06/03/2022*** - -1. Added **Swin-MoE**, the Mixture-of-Experts variant of Swin Transformer implemented using [Tutel](https://github.com/microsoft/tutel) (an optimized Mixture-of-Experts implementation). **Swin-MoE** is introduced in the [TuTel](https://arxiv.org/abs/2206.03382) paper. - -***05/12/2022*** - -1. Pretrained models of [Swin Transformer V2](https://arxiv.org/abs/2111.09883) on ImageNet-1K and ImageNet-22K are released. -2. ImageNet-22K pretrained models for Swin-V1-Tiny and Swin-V2-Small are released. - -***03/02/2022*** - -1. Swin Transformer V2 and SimMIM got accepted by CVPR 2022. [SimMIM](https://github.com/microsoft/SimMIM) is a self-supervised pre-training approach based on masked image modeling, a key technique that works out the 3-billion-parameter Swin V2 model using `40x less labelled data` than that of previous billion-scale models based on JFT-3B. - -***02/09/2022*** - -1. 
Integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio). Try out the Web Demo [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/akhaliq/Swin-Transformer) - -***10/12/2021*** - -1. Swin Transformer received ICCV 2021 best paper award (Marr Prize). - -***08/09/2021*** -1. [Soft Teacher](https://arxiv.org/pdf/2106.09018v2.pdf) will appear at ICCV2021. The code will be released at [GitHub Repo](https://github.com/microsoft/SoftTeacher). `Soft Teacher` is an end-to-end semi-supervisd object detection method, achieving a new record on the COCO test-dev: `61.3 box AP` and `53.0 mask AP`. - -***07/03/2021*** -1. Add **Swin MLP**, which is an adaption of `Swin Transformer` by replacing all multi-head self-attention (MHSA) blocks by MLP layers (more precisely it is a group linear layer). The shifted window configuration can also significantly improve the performance of vanilla MLP architectures. - -***06/25/2021*** -1. [Video Swin Transformer](https://arxiv.org/abs/2106.13230) is released at [Video-Swin-Transformer](https://github.com/SwinTransformer/Video-Swin-Transformer). -`Video Swin Transformer` achieves state-of-the-art accuracy on a broad range of video recognition benchmarks, including action recognition (`84.9` top-1 accuracy on Kinetics-400 and `86.1` top-1 accuracy on Kinetics-600 with `~20x` less pre-training data and `~3x` smaller model size) and temporal modeling (`69.6` top-1 accuracy on Something-Something v2). - -***05/12/2021*** -1. Used as a backbone for `Self-Supervised Learning`: [Transformer-SSL](https://github.com/SwinTransformer/Transformer-SSL) - -Using Swin-Transformer as the backbone for self-supervised learning enables us to evaluate the transferring performance of the learnt representations on down-stream tasks, which is missing in previous works due to the use of ViT/DeiT, which has not been well tamed for down-stream tasks. +```sh +python -m venv venv +# Activate your virtual environment somehow +source venv/bin/activate.fish +``` -***04/12/2021*** +CUDA 11.6 -Initial commits: +```sh +pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 +``` -1. Pretrained models on ImageNet-1K ([Swin-T-IN1K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth), [Swin-S-IN1K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth), [Swin-B-IN1K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth)) and ImageNet-22K ([Swin-B-IN22K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth), [Swin-L-IN22K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth)) are provided. -2. The supported code and models for ImageNet-1K image classification, COCO object detection and ADE20K semantic segmentation are provided. -3. The cuda kernel implementation for the [local relation layer](https://arxiv.org/pdf/1904.11491.pdf) is provided in branch [LR-Net](https://github.com/microsoft/Swin-Transformer/tree/LR-Net). 
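+
+Pick the wheel that matches the CUDA version your driver supports (`nvidia-smi` reports it). After installing, a quick sanity check (this snippet is not part of the repo) confirms the build sees a CUDA runtime and your GPUs:
+
+```python
+# Hypothetical check; run it inside the activated virtual environment.
+import torch
+
+print(torch.__version__)          # e.g. 1.12.1+cu116
+print(torch.version.cuda)         # CUDA version the wheel was built against
+print(torch.cuda.is_available())  # should be True on a GPU machine
+print(torch.cuda.device_count())  # number of visible GPUs
+```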
+CUDA 11.3 -## Introduction +```sh +pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113 +``` -**Swin Transformer** (the name `Swin` stands for **S**hifted **win**dow) is initially described in [arxiv](https://arxiv.org/abs/2103.14030), which capably serves as a -general-purpose backbone for computer vision. It is basically a hierarchical Transformer whose representation is -computed with shifted windows. The shifted windowing scheme brings greater efficiency by limiting self-attention -computation to non-overlapping local windows while also allowing for cross-window connection. +Python packages -Swin Transformer achieves strong performance on COCO object detection (`58.7 box AP` and `51.1 mask AP` on test-dev) and -ADE20K semantic segmentation (`53.5 mIoU` on val), surpassing previous models by a large margin. +```sh +pip install matplotlib yacs timm einops black isort flake8 flake8-bugbear termcolor wandb preface opencv-python +``` -![teaser](figures/teaser.png) +2. Install Apex -## Main Results on ImageNet with Pretrained Models +```sh +git clone https://github.com/NVIDIA/apex.git +cd apex +pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ +``` -**ImageNet-1K and ImageNet-22K Pretrained Swin-V1 Models** +```sh +cd kernels/window_process +python setup.py install +``` -| name | pretrain | resolution |acc@1 | acc@5 | #params | FLOPs | FPS| 22K model | 1K model | -| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |:---: |:---: | -| Swin-T | ImageNet-1K | 224x224 | 81.2 | 95.5 | 28M | 4.5G | 755 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/156nWJy4Q28rDlrX-rRbI3w)/[config](configs/swin/swin_tiny_patch4_window7_224.yaml)/[log](https://github.com/SwinTransformer/storage/files/7745562/log_swin_tiny_patch4_window7_224.txt) | -| Swin-S | ImageNet-1K | 224x224 | 83.2 | 96.2 | 50M | 8.7G | 437 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/1KFjpj3Efey3LmtE1QqPeQg)/[config](configs/swin/swin_small_patch4_window7_224.yaml)/[log](https://github.com/SwinTransformer/storage/files/7745563/log_swin_small_patch4_window7_224.txt) | -| Swin-B | ImageNet-1K | 224x224 | 83.5 | 96.5 | 88M | 15.4G | 278 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/16bqCTEc70nC_isSsgBSaqQ)/[config](configs/swin/swin_base_patch4_window7_224.yaml)/[log](https://github.com/SwinTransformer/storage/files/7745564/log_swin_base_patch4_window7_224.txt) | -| Swin-B | ImageNet-1K | 384x384 | 84.5 | 97.0 | 88M | 47.1G | 85 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth)/[baidu](https://pan.baidu.com/s/1xT1cu740-ejW7htUdVLnmw)/[config](configs/swin/swin_base_patch4_window12_384_finetune.yaml) | -| Swin-T | ImageNet-22K | 224x224 | 80.9 | 96.0 | 28M | 4.5G | 755 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_tiny_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1vct0VYwwQQ8PYkBjwSSBZQ?pwd=swin)/[config](configs/swin/swin_tiny_patch4_window7_224_22k.yaml) | 
[github](https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_tiny_patch4_window7_224_22kto1k_finetune.pth)/[baidu](https://pan.baidu.com/s/1K0OO-nGZDPkR8fm_r83e8Q?pwd=swin)/[config](configs/swin/swin_tiny_patch4_window7_224_22kto1k_finetune.yaml) | -| Swin-S | ImageNet-22K | 224x224 | 83.2 | 97.0 | 50M | 8.7G | 437 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_small_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/11NC1xdT5BAGBgazdTme5Sg?pwd=swin)/[config](configs/swin/swin_small_patch4_window7_224_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_small_patch4_window7_224_22kto1k_finetune.pth)/[baidu](https://pan.baidu.com/s/10RFVfjQJhwPfeHrmxQUaLw?pwd=swin)/[config](configs/swin/swin_small_patch4_window7_224_22kto1k_finetune.yaml) | -| Swin-B | ImageNet-22K | 224x224 | 85.2 | 97.5 | 88M | 15.4G | 278 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1y1Ec3UlrKSI8IMtEs-oBXA)/[config](configs/swin/swin_base_patch4_window7_224_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1n_wNkcbRxVXit8r_KrfAVg)/[config](configs/swin/swin_base_patch4_window7_224_22kto1k_finetune.yaml) | -| Swin-B | ImageNet-22K | 384x384 | 86.4 | 98.0 | 88M | 47.1G | 85 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth)/[baidu](https://pan.baidu.com/s/1vwJxnJcVqcLZAw9HaqiR6g) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1caKTSdoLJYoi4WBcnmWuWg)/[config](configs/swin/swin_base_patch4_window12_384_22kto1k_finetune.yaml) | -| Swin-L | ImageNet-22K | 224x224 | 86.3 | 97.9 | 197M | 34.5G | 141 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1pws3rOTFuOebBYP3h6Kx8w)/[config](configs/swin/swin_large_patch4_window7_224_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1NkQApMWUhxBGjk1ne6VqBQ)/[config](configs/swin/swin_large_patch4_window7_224_22kto1k_finetune.yaml) | -| Swin-L | ImageNet-22K | 384x384 | 87.3 | 98.2 | 197M | 103.9G | 42 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth)/[baidu](https://pan.baidu.com/s/1sl7o_bJA143OD7UqSLAMoA) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1X0FLHQyPOC6Kmv2CmgxJvA)/[config](configs/swin/swin_large_patch4_window12_384_22kto1k_finetune.yaml) | +3. 
Download Data -**ImageNet-1K and ImageNet-22K Pretrained Swin-V2 Models** +We use the iNat21 dataseta available on [GitHub](https://github.com/visipedia/inat_comp/tree/master/2021) -| name | pretrain | resolution | window |acc@1 | acc@5 | #params | FLOPs | FPS |22K model | 1K model | -|:---------------------:| :---: | :---: | :---: | :---: | :---: | :---: | :---: |:---:|:---: |:---: | -| SwinV2-T | ImageNet-1K | 256x256 | 8x8 | 81.8 | 95.9 | 28M | 5.9G | 572 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/1RzLkAH_5OtfRCJe6Vlg6rg?pwd=swin)/[config](configs/swinv2/swinv2_tiny_patch4_window8_256.yaml) | -| SwinV2-S | ImageNet-1K | 256x256 | 8x8 | 83.7 | 96.6 | 50M | 11.5G | 327 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/195PdA41szEduW3jEtRSa4Q?pwd=swin)/[config](configs/swinv2/swinv2_small_patch4_window8_256.yaml) | -| SwinV2-B | ImageNet-1K | 256x256 | 8x8 | 84.2 | 96.9 | 88M | 20.3G | 217 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/18AfMSz3dPyzIvP1dKuERvQ?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window8_256.yaml) | -| SwinV2-T | ImageNet-1K | 256x256 | 16x16 | 82.8 | 96.2 | 28M | 6.6G | 437 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window16_256.pth)/[baidu](https://pan.baidu.com/s/1dyK3cK9Xipmv6RnTtrPocw?pwd=swin)/[config](configs/swinv2/swinv2_tiny_patch4_window16_256.yaml) | -| SwinV2-S | ImageNet-1K | 256x256 | 16x16 | 84.1 | 96.8 | 50M | 12.6G | 257 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window16_256.pth)/[baidu](https://pan.baidu.com/s/1ZIPiSfWNKTPp821Ka-Mifw?pwd=swin)/[config](configs/swinv2/swinv2_small_patch4_window16_256.yaml) | -| SwinV2-B | ImageNet-1K | 256x256 | 16x16 | 84.6 | 97.0 | 88M | 21.8G | 174 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window16_256.pth)/[baidu](https://pan.baidu.com/s/1dlDQGn8BXCmnh7wQSM5Nhw?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window16_256.yaml) | -| SwinV2-B\* | ImageNet-22K | 256x256 | 16x16 | 86.2 | 97.9 | 88M | 21.8G | 174 | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth)/[baidu](https://pan.baidu.com/s/1Xc2rsSsRQz_sy5mjgfxrMQ?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window12_192_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.pth)/[baidu](https://pan.baidu.com/s/1sgstld4MgGsZxhUAW7MlmQ?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.yaml) | -| SwinV2-B\* | ImageNet-22K | 384x384 | 24x24 | 87.1 | 98.2 | 88M | 54.7G | 57 | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth)/[baidu](https://pan.baidu.com/s/1Xc2rsSsRQz_sy5mjgfxrMQ?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window12_192_22k.yaml) | 
[github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.pth)/[baidu](https://pan.baidu.com/s/17u3sEQaUYlvfL195rrORzQ?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.yaml) | -| SwinV2-L\* | ImageNet-22K | 256x256 | 16x16 | 86.9 | 98.0 | 197M | 47.5G | 95 | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth)/[baidu](https://pan.baidu.com/s/11PhCV7qAGXtZ8dXNgyiGOw?pwd=swin)/[config](configs/swinv2/swinv2_large_patch4_window12_192_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to16_192to256_22kto1k_ft.pth)/[baidu](https://pan.baidu.com/s/1pqp31N80qIWjFPbudzB6Bw?pwd=swin)/[config](configs/swinv2/swinv2_large_patch4_window12to16_192to256_22kto1k_ft.yaml) | -| SwinV2-L\* | ImageNet-22K | 384x384 | 24x24 | 87.6 | 98.3 | 197M | 115.4G | 33 | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth)/[baidu](https://pan.baidu.com/s/11PhCV7qAGXtZ8dXNgyiGOw?pwd=swin)/[config](configs/swinv2/swinv2_large_patch4_window12_192_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.pth)/[baidu](https://pan.baidu.com/s/13URdNkygr3Xn0N3e6IwjgA?pwd=swin)/[config](configs/swinv2/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.yaml) | +``` +cd /mnt/10tb +mkdir -p data/inat21 +cd data/inat21 +mkdir compressed raw +cd compressed +wget https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train.tar.gz +wget https://ml-inat-competition-datasets.s3.amazonaws.com/2021/val.tar.gz + +# pv is just a progress bar +pv val.tar.gz | tar -xz +mv val ../raw/ # if I knew how tar worked I could have it extract to raw/ + +pv train.tar.gz | tar -xz +mv train ../raw/ +``` -Note: -- SwinV2-B\* (SwinV2-L\*) with input resolution of 256x256 and 384x384 both fine-tuned from the same pre-training model using a smaller input resolution of 192x192. -- SwinV2-B\* (384x384) achieves 78.08 acc@1 on ImageNet-1K-V2 while SwinV2-L\* (384x384) achieves 78.31. +4. Preprocess iNat 21 -**ImageNet-1K Pretrained Swin MLP Models** +Use your root data folder and your size of choice. 
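+
+The commands below write resized copies of the raw images into `resize-192/` and `resize-256/` (the paths the configs point at). The `data.inat` CLI itself is not included in this diff; conceptually, each pass does something like the following sketch (the file layout and the shorter-side resize strategy are assumptions, not the repo's actual code):
+
+```python
+# Illustrative sketch only; the real logic lives in data/inat.py.
+import pathlib
+import sys
+
+from PIL import Image
+
+
+def resize_split(root: pathlib.Path, split: str, size: int) -> None:
+    # Mirrors raw/<split>/<class>/<img>.jpg into resize-<size>/<split>/...
+    src = root / "raw" / split
+    dst = root / f"resize-{size}" / split
+    for img_path in src.rglob("*.jpg"):
+        out_path = dst / img_path.relative_to(src)
+        out_path.parent.mkdir(parents=True, exist_ok=True)
+        with Image.open(img_path) as img:
+            w, h = img.size
+            scale = size / min(w, h)  # assumed: shorter side becomes `size`
+            resized = img.convert("RGB").resize((round(w * scale), round(h * scale)), Image.BICUBIC)
+            resized.save(out_path)
+
+
+if __name__ == "__main__":
+    # e.g. python resize_sketch.py /mnt/10tb/data/inat21 val 192
+    resize_split(pathlib.Path(sys.argv[1]), sys.argv[2], int(sys.argv[3]))
+```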
-| name | pretrain | resolution |acc@1 | acc@5 | #params | FLOPs | FPS | 1K model | -| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | -| [Mixer-B/16](https://arxiv.org/pdf/2105.01601.pdf) | ImageNet-1K | 224x224 | 76.4 | - | 59M | 12.7G | - | [official repo](https://github.com/google-research/vision_transformer) | -| [ResMLP-S24](https://arxiv.org/abs/2105.03404) | ImageNet-1K | 224x224 | 79.4 | - | 30M | 6.0G | 715 | [timm](https://github.com/rwightman/pytorch-image-models) | -| [ResMLP-B24](https://arxiv.org/abs/2105.03404) | ImageNet-1K | 224x224 | 81.0 | - | 116M | 23.0G | 231 | [timm](https://github.com/rwightman/pytorch-image-models) | -| Swin-T/C24 | ImageNet-1K | 256x256 | 81.6 | 95.7 | 28M | 5.9G | 563 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_tiny_c24_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/17k-7l6Sxt7uZ7IV0f26GNQ)/[config](configs/swin/swin_tiny_c24_patch4_window8_256.yaml) | -| SwinMLP-T/C24 | ImageNet-1K | 256x256 | 79.4 | 94.6 | 20M | 4.0G | 807 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_mlp_tiny_c24_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/1Sa4vP5R0M2RjfIe9HIga-Q)/[config](configs/swin/swin_mlp_tiny_c24_patch4_window8_256.yaml) | -| SwinMLP-T/C12 | ImageNet-1K | 256x256 | 79.6 | 94.7 | 21M | 4.0G | 792 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_mlp_tiny_c12_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/1mM9J2_DEVZHUB5ASIpFl0w)/[config](configs/swin/swin_mlp_tiny_c12_patch4_window8_256.yaml) | -| SwinMLP-T/C6 | ImageNet-1K | 256x256 | 79.7 | 94.9 | 23M | 4.0G | 766 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_mlp_tiny_c6_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/1hUTYVT2W1CsjICw-3W-Vjg)/[config](configs/swin/swin_mlp_tiny_c6_patch4_window8_256.yaml) | -| SwinMLP-B | ImageNet-1K | 224x224 | 81.3 | 95.3 | 61M | 10.4G | 409 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_mlp_base_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/1zww3dnbX3GxNiGfb-GwyUg)/[config](configs/swin/swin_mlp_base_patch4_window7_224.yaml) | +``` +export DATA_DIR=/mnt/10tb/data/inat21/ +python -m data.inat preprocess $DATA_DIR val resize 192 +python -m data.inat preprocess $DATA_DIR train resize 192 +python -m data.inat preprocess $DATA_DIR val resize 256 +python -m data.inat preprocess $DATA_DIR train resize 256 +``` -Note: access code for `baidu` is `swin`. C24 means each head has 24 channels. +5. Login to Wandb -**ImageNet-22K Pretrained Swin-MoE Models** +``` +wandb login +``` -- Please refer to [get_started](get_started.md#mixture-of-experts-support) for instructions on running Swin-MoE. -- Pretrained models for Swin-MoE can be found in [MODEL HUB](MODELHUB.md#imagenet-22k-pretrained-swin-moe-models) +6. Set up an `env.fish` file: -## Main Results on Downstream Tasks +You need to provide `$VENV` and a `$RUN_OUTPUT` environment variables. +I recommend using a file to save these variables. 
-**COCO Object Detection (2017 val)** +In fish: -| Backbone | Method | pretrain | Lr Schd | box mAP | mask mAP | #params | FLOPs | -| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | -| Swin-T | Mask R-CNN | ImageNet-1K | 3x | 46.0 | 41.6 | 48M | 267G | -| Swin-S | Mask R-CNN | ImageNet-1K | 3x | 48.5 | 43.3 | 69M | 359G | -| Swin-T | Cascade Mask R-CNN | ImageNet-1K | 3x | 50.4 | 43.7 | 86M | 745G | -| Swin-S | Cascade Mask R-CNN | ImageNet-1K | 3x | 51.9 | 45.0 | 107M | 838G | -| Swin-B | Cascade Mask R-CNN | ImageNet-1K | 3x | 51.9 | 45.0 | 145M | 982G | -| Swin-T | RepPoints V2 | ImageNet-1K | 3x | 50.0 | - | 45M | 283G | -| Swin-T | Mask RepPoints V2 | ImageNet-1K | 3x | 50.3 | 43.6 | 47M | 292G | -| Swin-B | HTC++ | ImageNet-22K | 6x | 56.4 | 49.1 | 160M | 1043G | -| Swin-L | HTC++ | ImageNet-22K | 3x | 57.1 | 49.5 | 284M | 1470G | -| Swin-L | HTC++* | ImageNet-22K | 3x | 58.0 | 50.4 | 284M | - | +```fish +# scripts/env.fish +set -gx VENV venv +set -gx RUN_OUTPUT /mnt/10tb/models/hierarchical-vision +``` -Note: * indicates multi-scale testing. +Then run `source scripts/env.fish` -**ADE20K Semantic Segmentation (val)** +## AWS Helpers -| Backbone | Method | pretrain | Crop Size | Lr Schd | mIoU | mIoU (ms+flip) | #params | FLOPs | -| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | -| Swin-T | UPerNet | ImageNet-1K | 512x512 | 160K | 44.51 | 45.81 | 60M | 945G | -| Swin-S | UperNet | ImageNet-1K | 512x512 | 160K | 47.64 | 49.47 | 81M | 1038G | -| Swin-B | UperNet | ImageNet-1K | 512x512 | 160K | 48.13 | 49.72 | 121M | 1188G | -| Swin-B | UPerNet | ImageNet-22K | 640x640 | 160K | 50.04 | 51.66 | 121M | 1841G | -| Swin-L | UperNet | ImageNet-22K | 640x640 | 160K | 52.05 | 53.53 | 234M | 3230G | - -## Citing Swin Transformer +Uninstall v1 of awscli: ``` -@inproceedings{liu2021Swin, - title={Swin Transformer: Hierarchical Vision Transformer using Shifted Windows}, - author={Liu, Ze and Lin, Yutong and Cao, Yue and Hu, Han and Wei, Yixuan and Zhang, Zheng and Lin, Stephen and Guo, Baining}, - booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, - year={2021} -} -``` -## Citing Local Relation Networks (the first full-attention visual backbone) -``` -@inproceedings{hu2019local, - title={Local Relation Networks for Image Recognition}, - author={Hu, Han and Zhang, Zheng and Xie, Zhenda and Lin, Stephen}, - booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, - pages={3464--3473}, - year={2019} -} -``` -## Citing Swin Transformer V2 +sudo /usr/local/bin/pip uninstall awscli ``` -@inproceedings{liu2021swinv2, - title={Swin Transformer V2: Scaling Up Capacity and Resolution}, - author={Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo}, - booktitle={International Conference on Computer Vision and Pattern Recognition (CVPR)}, - year={2022} -} -``` -## Citing SimMIM (a self-supervised approach that enables SwinV2-G) -``` -@inproceedings{xie2021simmim, - title={SimMIM: A Simple Framework for Masked Image Modeling}, - author={Xie, Zhenda and Zhang, Zheng and Cao, Yue and Lin, Yutong and Bao, Jianmin and Yao, Zhuliang and Dai, Qi and Hu, Han}, - booktitle={International Conference on Computer Vision and Pattern Recognition (CVPR)}, - year={2022} -} -``` -## Citing SimMIM-data-scaling -``` -@article{xie2022data, - title={On Data Scaling in Masked Image Modeling}, - author={Xie, 
Zhenda and Zhang, Zheng and Cao, Yue and Lin, Yutong and Wei, Yixuan and Dai, Qi and Hu, Han}, - journal={arXiv preprint arXiv:2206.04664}, - year={2022} -} -``` -## Citing Swin-MoE + +Install v2: ``` -@misc{hwang2022tutel, - title={Tutel: Adaptive Mixture-of-Experts at Scale}, - author={Changho Hwang and Wei Cui and Yifan Xiong and Ziyue Yang and Ze Liu and Han Hu and Zilong Wang and Rafael Salas and Jithin Jose and Prabhat Ram and Joe Chau and Peng Cheng and Fan Yang and Mao Yang and Yongqiang Xiong}, - year={2022}, - eprint={2206.03382}, - archivePrefix={arXiv} -} +cd ~/pkg +curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip" +unzip awscliv2.zip +./aws/install --bin-dir ~/.local/bin --install-dir ~/.local/aws-cli ``` - -## Getting Started - -- For **Image Classification**, please see [get_started.md](get_started.md) for detailed instructions. -- For **Object Detection and Instance Segmentation**, please see [Swin Transformer for Object Detection](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection). -- For **Semantic Segmentation**, please see [Swin Transformer for Semantic Segmentation](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation). -- For **Self-Supervised Learning**, please see [Transformer-SSL](https://github.com/SwinTransformer/Transformer-SSL). -- For **Video Recognition**, please see [Video Swin Transformer](https://github.com/SwinTransformer/Video-Swin-Transformer). - -## Third-party Usage and Experiments - -***In this pargraph, we cross link third-party repositories which use Swin and report results. You can let us know by raising an issue*** - -(`Note please report accuracy numbers and provide trained models in your new repository to facilitate others to get sense of correctness and model behavior`) - -[06/30/2022] Swin Transformers (V1) inference implemented in FasterTransformer: [FasterTransformer](https://github.com/NVIDIA/FasterTransformer/blob/main/docs/swin_guide.md) - -[05/12/2022] Swin Transformers (V1) implemented in TensorFlow with the pre-trained parameters ported into them. Find the implementation, -TensorFlow weights, code example here in [this repository](https://github.com/sayakpaul/swin-transformers-tf/). - -[04/06/2022] Swin Transformer for Audio Classification: [Hierarchical Token Semantic Audio Transformer](https://github.com/RetroCirce/HTS-Audio-Transformer). - -[12/21/2021] Swin Transformer for StyleGAN: [StyleSwin](https://github.com/microsoft/StyleSwin) - -[12/13/2021] Swin Transformer for Face Recognition: [FaceX-Zoo](https://github.com/JDAI-CV/FaceX-Zoo) - -[08/29/2021] Swin Transformer for Image Restoration: [SwinIR](https://github.com/JingyunLiang/SwinIR) - -[08/12/2021] Swin Transformer for person reID: [https://github.com/layumi/Person_reID_baseline_pytorch](https://github.com/layumi/Person_reID_baseline_pytorch) - -[06/29/2021] Swin-Transformer in PaddleClas and inference based on whl package: [https://github.com/PaddlePaddle/PaddleClas](https://github.com/PaddlePaddle/PaddleClas) - -[04/14/2021] Swin for RetinaNet in Detectron: https://github.com/xiaohu2015/SwinT_detectron2. - -[04/16/2021] Included in a famous model zoo: https://github.com/rwightman/pytorch-image-models. - -[04/20/2021] Swin-Transformer classifier inference using TorchServe: https://github.com/kamalkraj/Swin-Transformer-Serve - -## Contributing - -This project welcomes contributions and suggestions. 
Most contributions require you to agree to a -Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us -the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. - -When you submit a pull request, a CLA bot will automatically determine whether you need to provide -a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions -provided by the bot. You will only need to do this once across all repos using our CLA. - -This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). -For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or -contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. - -## Trademarks - -This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft -trademarks or logos is subject to and must follow -[Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). -Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. -Any use of third-party trademarks or logos are subject to those third-party's policies. diff --git a/config.py b/config.py index 1671ec34..27e3d295 100644 --- a/config.py +++ b/config.py @@ -6,33 +6,34 @@ # --------------------------------------------------------' import os + import yaml from yacs.config import CfgNode as CN + + _C = CN() # Base config files -_C.BASE = [''] +_C.BASE = [""] # ----------------------------------------------------------------------------- # Data settings # ----------------------------------------------------------------------------- _C.DATA = CN() -# Batch size for a single GPU, could be overwritten by command line argument -_C.DATA.BATCH_SIZE = 128 # Path to dataset, could be overwritten by command line argument -_C.DATA.DATA_PATH = '' +_C.DATA.DATA_PATH = "" # Dataset name -_C.DATA.DATASET = 'imagenet' +_C.DATA.DATASET = "imagenet" # Input image size _C.DATA.IMG_SIZE = 224 # Interpolation to resize image (random, bilinear, bicubic) -_C.DATA.INTERPOLATION = 'bicubic' +_C.DATA.INTERPOLATION = "bicubic" # Use zipped dataset instead of folder dataset # could be overwritten by command line argument _C.DATA.ZIP_MODE = False # Cache Data in Memory, could be overwritten by command line argument -_C.DATA.CACHE_MODE = 'part' +_C.DATA.CACHE_MODE = "part" # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU. 
_C.DATA.PIN_MEMORY = True # Number of data loading threads @@ -48,14 +49,14 @@ # ----------------------------------------------------------------------------- _C.MODEL = CN() # Model type -_C.MODEL.TYPE = 'swin' +_C.MODEL.TYPE = "swin" # Model name -_C.MODEL.NAME = 'swin_tiny_patch4_window7_224' +_C.MODEL.NAME = "swin_tiny_patch4_window7_224" # Pretrained weight from checkpoint, could be imagenet22k pretrained weight # could be overwritten by command line argument -_C.MODEL.PRETRAINED = '' +_C.MODEL.PRETRAINED = "" # Checkpoint to resume, could be overwritten by command line argument -_C.MODEL.RESUME = '' +_C.MODEL.RESUME = "" # Number of classes, overwritten in data preparation _C.MODEL.NUM_CLASSES = 1000 # Dropout rate @@ -73,7 +74,7 @@ _C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2] _C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24] _C.MODEL.SWIN.WINDOW_SIZE = 7 -_C.MODEL.SWIN.MLP_RATIO = 4. +_C.MODEL.SWIN.MLP_RATIO = 4.0 _C.MODEL.SWIN.QKV_BIAS = True _C.MODEL.SWIN.QK_SCALE = None _C.MODEL.SWIN.APE = False @@ -87,7 +88,7 @@ _C.MODEL.SWINV2.DEPTHS = [2, 2, 6, 2] _C.MODEL.SWINV2.NUM_HEADS = [3, 6, 12, 24] _C.MODEL.SWINV2.WINDOW_SIZE = 7 -_C.MODEL.SWINV2.MLP_RATIO = 4. +_C.MODEL.SWINV2.MLP_RATIO = 4.0 _C.MODEL.SWINV2.QKV_BIAS = True _C.MODEL.SWINV2.APE = False _C.MODEL.SWINV2.PATCH_NORM = True @@ -101,7 +102,7 @@ _C.MODEL.SWIN_MOE.DEPTHS = [2, 2, 6, 2] _C.MODEL.SWIN_MOE.NUM_HEADS = [3, 6, 12, 24] _C.MODEL.SWIN_MOE.WINDOW_SIZE = 7 -_C.MODEL.SWIN_MOE.MLP_RATIO = 4. +_C.MODEL.SWIN_MOE.MLP_RATIO = 4.0 _C.MODEL.SWIN_MOE.QKV_BIAS = True _C.MODEL.SWIN_MOE.QK_SCALE = None _C.MODEL.SWIN_MOE.APE = False @@ -131,7 +132,7 @@ _C.MODEL.SWIN_MLP.DEPTHS = [2, 2, 6, 2] _C.MODEL.SWIN_MLP.NUM_HEADS = [3, 6, 12, 24] _C.MODEL.SWIN_MLP.WINDOW_SIZE = 7 -_C.MODEL.SWIN_MLP.MLP_RATIO = 4. +_C.MODEL.SWIN_MLP.MLP_RATIO = 4.0 _C.MODEL.SWIN_MLP.APE = False _C.MODEL.SWIN_MLP.PATCH_NORM = True @@ -145,6 +146,10 @@ # Training settings # ----------------------------------------------------------------------------- _C.TRAIN = CN() +# Batch size for a single GPU, could be overwritten by command line argument +_C.TRAIN.DEVICE_BATCH_SIZE = 128 +# Global batch size = DEVICE_BATCH_SIZE * N_PROCS * ACCUMULATION_STEPS +_C.TRAIN.GLOBAL_BATCH_SIZE = 1024 _C.TRAIN.START_EPOCH = 0 _C.TRAIN.EPOCHS = 300 _C.TRAIN.WARMUP_EPOCHS = 20 @@ -165,7 +170,7 @@ # LR scheduler _C.TRAIN.LR_SCHEDULER = CN() -_C.TRAIN.LR_SCHEDULER.NAME = 'cosine' +_C.TRAIN.LR_SCHEDULER.NAME = "cosine" # Epoch interval to decay LR, used in StepLRScheduler _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30 # LR decay rate, used in StepLRScheduler @@ -178,7 +183,7 @@ # Optimizer _C.TRAIN.OPTIMIZER = CN() -_C.TRAIN.OPTIMIZER.NAME = 'adamw' +_C.TRAIN.OPTIMIZER.NAME = "adamw" # Optimizer Epsilon _C.TRAIN.OPTIMIZER.EPS = 1e-8 # Optimizer Betas @@ -193,6 +198,16 @@ _C.TRAIN.MOE = CN() # Only save model on master device _C.TRAIN.MOE.SAVE_MASTER = False + +# Hierarchical coefficients for loss +_C.TRAIN.HIERARCHICAL_COEFFS = (1,) + +# [Debugging] How many batches of the training data to overfit. +_C.TRAIN.OVERFIT_BATCHES = 0 + +# Percentage of data for low data regieme +_C.TRAIN.DATA_PERCENTAGE = 1 + # ----------------------------------------------------------------------------- # Augmentation settings # ----------------------------------------------------------------------------- @@ -200,11 +215,11 @@ # Color jitter factor _C.AUG.COLOR_JITTER = 0.4 # Use AutoAugment policy. 
"v0" or "original" -_C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1' +_C.AUG.AUTO_AUGMENT = "rand-m9-mstd0.5-inc1" # Random erase prob _C.AUG.REPROB = 0.25 # Random erase mode -_C.AUG.REMODE = 'pixel' +_C.AUG.REMODE = "pixel" # Random erase count _C.AUG.RECOUNT = 1 # Mixup alpha, mixup enabled if > 0 @@ -218,7 +233,7 @@ # Probability of switching to cutmix when both mixup and cutmix enabled _C.AUG.MIXUP_SWITCH_PROB = 0.5 # How to apply mixup/cutmix params. Per "batch", "pair", or "elem" -_C.AUG.MIXUP_MODE = 'batch' +_C.AUG.MIXUP_MODE = "batch" # ----------------------------------------------------------------------------- # Testing settings @@ -230,20 +245,32 @@ _C.TEST.SEQUENTIAL = False _C.TEST.SHUFFLE = False +# ----------------------------------------------------------------------------- +# Experiment Settings +# ----------------------------------------------------------------------------- +_C.EXPERIMENT = CN() +# The experiment name. This is a human-readable name that is easy to read. +_C.EXPERIMENT.NAME = "default-dragonfruit" +# The wandb id for logging. +# Generate this id with scripts/generate_wandb_id +_C.EXPERIMENT.WANDB_ID = "" + # ----------------------------------------------------------------------------- # Misc # ----------------------------------------------------------------------------- + +# Whether we are doing hierarchical classification +_C.HIERARCHICAL = False + # [SimMIM] Whether to enable pytorch amp, overwritten by command line argument _C.ENABLE_AMP = False # Enable Pytorch automatic mixed precision (amp). _C.AMP_ENABLE = True # [Deprecated] Mixed precision opt level of apex, if O0, no apex amp is used ('O0', 'O1', 'O2') -_C.AMP_OPT_LEVEL = '' +_C.AMP_OPT_LEVEL = "" # Path to output folder, overwritten by command line argument -_C.OUTPUT = '' -# Tag of experiment, overwritten by command line argument -_C.TAG = 'default' +_C.OUTPUT = "" # Frequency to save checkpoint _C.SAVE_FREQ = 1 # Frequency to logging info @@ -263,15 +290,15 @@ def _update_config_from_file(config, cfg_file): config.defrost() - with open(cfg_file, 'r') as f: + with open(cfg_file, "r") as f: yaml_cfg = yaml.load(f, Loader=yaml.FullLoader) - for cfg in yaml_cfg.setdefault('BASE', ['']): + for cfg in yaml_cfg.setdefault("BASE", [""]): if cfg: _update_config_from_file( config, os.path.join(os.path.dirname(cfg_file), cfg) ) - print('=> merge config from {}'.format(cfg_file)) + print("=> merge config from {}".format(cfg_file)) config.merge_from_file(cfg_file) config.freeze() @@ -284,60 +311,91 @@ def update_config(config, args): config.merge_from_list(args.opts) def _check_args(name): - if hasattr(args, name) and eval(f'args.{name}'): + if hasattr(args, name) and eval(f"args.{name}"): return True return False # merge from specific arguments - if _check_args('batch_size'): - config.DATA.BATCH_SIZE = args.batch_size - if _check_args('data_path'): - config.DATA.DATA_PATH = args.data_path - if _check_args('zip'): + if _check_args("batch_size"): + config.TRAIN.DEVICE_BATCH_SIZE = args.batch_size + if _check_args("data_path"): + config.DATA.DATA_PATH = os.path.abspath(args.data_path) + if _check_args("zip"): config.DATA.ZIP_MODE = True - if _check_args('cache_mode'): + if _check_args("cache_mode"): config.DATA.CACHE_MODE = args.cache_mode - if _check_args('pretrained'): + if _check_args("pretrained"): config.MODEL.PRETRAINED = args.pretrained - if _check_args('resume'): + if _check_args("resume"): config.MODEL.RESUME = args.resume - if _check_args('accumulation_steps'): - config.TRAIN.ACCUMULATION_STEPS = 
args.accumulation_steps - if _check_args('use_checkpoint'): + if _check_args("use_checkpoint"): config.TRAIN.USE_CHECKPOINT = True - if _check_args('amp_opt_level'): + if _check_args("amp_opt_level"): print("[warning] Apex amp has been deprecated, please use pytorch amp instead!") - if args.amp_opt_level == 'O0': + if args.amp_opt_level == "O0": config.AMP_ENABLE = False - if _check_args('disable_amp'): + if _check_args("disable_amp"): config.AMP_ENABLE = False - if _check_args('output'): + if _check_args("output"): config.OUTPUT = args.output - if _check_args('tag'): - config.TAG = args.tag - if _check_args('eval'): + if _check_args("eval"): config.EVAL_MODE = True - if _check_args('throughput'): + if _check_args("throughput"): config.THROUGHPUT_MODE = True # [SimMIM] - if _check_args('enable_amp'): + if _check_args("enable_amp"): config.ENABLE_AMP = args.enable_amp # for acceleration - if _check_args('fused_window_process'): + if _check_args("fused_window_process"): config.FUSED_WINDOW_PROCESS = True - if _check_args('fused_layernorm'): + if _check_args("fused_layernorm"): config.FUSED_LAYERNORM = True - ## Overwrite optimizer if not None, currently we use it for [fused_adam, fused_lamb] - if _check_args('optim'): + # Overwrite optimizer if not None, currently we use it for [fused_adam, fused_lamb] + if _check_args("optim"): config.TRAIN.OPTIMIZER.NAME = args.optim - # set local rank for distributed training - config.LOCAL_RANK = args.local_rank + if _check_args("low_data"): + config.TRAIN.DATA_PERCENTAGE = args.low_data + + # Use os.environ["LOCAL_RANK"] rather than --local_rank + if "LOCAL_RANK" in os.environ: + # set local rank for distributed training + config.LOCAL_RANK = int(os.environ["LOCAL_RANK"]) + + # Use this to calculate accumulation steps + if "LOCAL_WORLD_SIZE" in os.environ: + + def divide_cleanly(a, b): + assert a % b == 0, f"{a} / {b} has remainder {a % b}" + return a // b + + n_procs = int(os.environ["LOCAL_WORLD_SIZE"]) + desired_device_batch_size = divide_cleanly( + config.TRAIN.GLOBAL_BATCH_SIZE, n_procs + ) + actual_device_batch_size = config.TRAIN.DEVICE_BATCH_SIZE + + if actual_device_batch_size > desired_device_batch_size: + print( + f"Decreasing device batch size from {actual_device_batch_size} to {desired_device_batch_size} so your global bath size is {config.TRAIN.GLOBAL_BATCH_SIZE}, not {desired_device_batch_size * n_procs}!" + ) + config.TRAIN.ACCUMULATION_STEPS = 1 + config.TRAIN.DEVICE_BATCH_SIZE = desired_device_batch_size + elif desired_device_batch_size == actual_device_batch_size: + config.TRAIN.ACCUMULATION_STEPS = 1 + else: + assert desired_device_batch_size > actual_device_batch_size + config.TRAIN.ACCUMULATION_STEPS = divide_cleanly( + desired_device_batch_size, actual_device_batch_size + ) + print( + f"Using {config.TRAIN.ACCUMULATION_STEPS} accumulation steps so your global batch size is {config.TRAIN.GLOBAL_BATCH_SIZE}, not {actual_device_batch_size * n_procs}!" 
+ ) # output folder - config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG) + config.OUTPUT = os.path.join(config.OUTPUT, config.EXPERIMENT.NAME) config.freeze() @@ -347,6 +405,231 @@ def get_config(args): # Return a clone so that the defaults will not be altered # This is for the "local variable" use pattern config = _C.clone() + update_config(config, args) + return config + + + +#################################### Added sript for yaml file generation ############################################# + +# Distribution = Literal["normal", "uniform", "loguniform"] +import tomli +import utils +import dataclasses +import copy +from typing_extensions import Literal + +from typing import ( + Any, + Dict, + Iterator, + List, + Optional, + Type, + TypeVar, + Union, +) +Distribution = Literal["normal", "uniform", "loguniform"] +T = TypeVar("T", bound="Config") + +class Config: + @classmethod + def from_dict(cls: Type[T], dct: Dict[str, Any]) -> T: + for field in dataclasses.fields(cls): + if ( + isinstance(field.type, type) + and issubclass(field.type, Config) + and field.name in dct + and not isinstance(dct[field.name], field.type) + ): + if not isinstance(dct[field.name], dict): + logger.warn( + "Subdict is not a dict! [cls: %s, field name: %s, field type: %s, actual type: %s]", + cls, + field.name, + field.type, + type(dct[field.name]), + ) + dct[field.name] = field.type.from_dict(dct[field.name]) + + return cls(**dct) + + @classmethod + def get_toml_name(cls) -> str: + # Because I'm a bad programmer and I do hacky things. + return cls.__name__[: cls.__name__.lower().find("config")].lower() + + @classmethod + def from_existing(cls: Type[T], other: Type[T], **overrides) -> T: + kwargs = {**dataclasses.asdict(other), **overrides} + + return cls(**kwargs) + + @property + def pretty(self) -> str: + return json.dumps(dataclasses.asdict(self), indent=4) + + def __str__(self) -> str: + return json.dumps(dataclasses.asdict(self)) + + def validate_field(self, fname: str, ftype) -> None: + choices = get_args(ftype) + if getattr(self, fname) not in choices: + raise ValueError(f"self.{fname} must be one of {', '.join(choices)}") + +@dataclasses.dataclass(frozen=True) +class RandomVecConfig(Config): + distribution: Optional[Distribution] = None + + # Distribution keyword args. + dist_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + + def __post_init__(self) -> None: + if self.distribution is not None: + self.validate_field("distribution", Distribution) +Layer = Union[ + int, + Literal[ + "sigmoid", + "tanh", + "output", + "cos", + "sine", + "layernorm", + "groupnorm", + "1/x", + "nonlinear-wht", + "dropout", + ], +] + +@dataclasses.dataclass(frozen=True) +class ProjectionConfig(Config): + layers: List[Layer] = dataclasses.field(default_factory=lambda: ["output"]) + layer_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + + @classmethod + def from_dict(cls, dct) -> "ProjectionConfig": + """ + I reimplement this method because the toml dict will have a string for layers that needs to be evaluated to a real Python list. + """ + for key in dct: + if key == "layers" and isinstance(dct[key], str): + dct[key] = eval(dct[key]) + + return cls(**dct) + +PromptType = Literal["uuid", "token", "vocab", "chunk-n", "natural-n"] + +@dataclasses.dataclass(frozen=True) +class DataConfig(Config): + file: str + overwrite_cache: bool = False + """ + Can be one of 'uuid', 'token', or 'vocab'. + * uuid: encodes a uuid as the prompt (typically between 20-30 tokens for GPT2). 
+ * token: adds a new token to the vocabulary for each chunk (<|start0|>, <|start1|>, etc.) + * vocab: finds an existing token in the vocabulary that's not in any of the examples and uses it as th e prompt. + * chunk-n: "Chunk 1: ", "Chunk 2: ", ... + """ + prompt_type: PromptType = "uuid" + + chunk_length: Union[Literal["longest"], int] = "longest" + + def __post_init__(self) -> None: + if not os.path.exists(self.file): + raise ValueError(f"{self.file} does not exist!") + + self.validate_field("prompt_type", PromptType) + + if self.chunk_length != "longest": + assert isinstance(self.chunk_length, int) + + def get_text(self) -> str: + assert self.file is not None + + with open(self.file, "r") as file: + return file.read() + +@dataclasses.dataclass(frozen=True) +class ModelConfig(Config): + language_model_name_or_path: str + intrinsic_dimension: Optional[int] = None + + # Structure-aware intrinsic dimension (SAID) + # Has no effect when intrinsic_dimension is None. + intrinsic_dimension_said: bool = False + + # temperature of 1.0 has no effect, lower tend toward greedy sampling + temperature: float = 1.0 + + # The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_k: int = 0 + + # If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation. + top_p: float = 0.9 + + # primarily useful for CTRL model; in that case, use 1.2 + repetition_penalty: float = 1.0 + + # optional stop token (ignore text generated after this token) + stop_token: Optional[str] = None + + # context window size + context_window: int = 1024 + + # dropout probability for fully connected layers in embeddings, encoder, and pooler, embeddings, and attention. + dropout: float = 0.0 + + # dropout probability for the intrinsic dimension layer(s) + int_dim_dropout: float = 0.0 + + # Whether to use pre-trained weights. + pretrained: bool = True + + random_vector: RandomVecConfig = dataclasses.field(default_factory=RandomVecConfig) + + projection: ProjectionConfig = dataclasses.field(default_factory=ProjectionConfig) + + normalized: bool = True + scaled: bool = False + scaling_factor: float = 1 + + def __post_init__(self) -> None: + assert isinstance(self.random_vector, RandomVecConfig), str( + type(self.random_vector) + ) + assert isinstance(self.projection, ProjectionConfig), str(type(self.projection)) + +SeedSource = Literal["trial", "config", "random"] + +@dataclasses.dataclass(frozen=True) +class ExperimentConfig(Config): + model: ModelConfig + # tokenizer: TokenizerConfig + + ####Below two lines commented by me ############## + data: DataConfig + # training: TrainingConfig + + trials: int = 3 + save_weights: bool = True + seed_source: SeedSource = "trial" + seed: int = 0 + + def __post_init__(self) -> None: + self.validate_field("seed_source", SeedSource) + + +# def load_configs(config_file: str) -> Iterator[ExperimentConfig]: +# """ +# A config file could contain many experiments. For any field in a config file, if it is a list, then it turns into multiple experiments. If there are multiple lists, then each combination of elements from each list forms a new experiment. 
+# """ +# with open(config_file, "r") as file: +# config_dict = tomli.loads(file.read()) + +# for flat in utils.flattened(config_dict): +# yield ExperimentConfig.from_dict(copy.deepcopy(flat)) \ No newline at end of file diff --git a/configs/hierarchical-vision-project/bold-banana-192.yaml b/configs/hierarchical-vision-project/bold-banana-192.yaml new file mode 100644 index 00000000..208d989e --- /dev/null +++ b/configs/hierarchical-vision-project/bold-banana-192.yaml @@ -0,0 +1,29 @@ +DATA: + DATASET: inat21 + IMG_SIZE: 192 + NUM_WORKERS: 32 +MODEL: + TYPE: swinv2 + NAME: swinv2_base_window12 + DROP_PATH_RATE: 0.2 + SWINV2: + EMBED_DIM: 128 + DEPTHS: [ 2, 2, 18, 2 ] + NUM_HEADS: [ 4, 8, 16, 32 ] + WINDOW_SIZE: 12 +TRAIN: + # Want a global batch size of 2048 because SwinV2 was trained on 16 V100s with batch size 128 (I think) + # But we are going to use a global batch size of 1024 because it's faster (throughput). + GLOBAL_BATCH_SIZE: 1024 + + # We are using limited epochs based on pre-training configs for imagenet22k + # Then we will pre-train on 256x256 for 30 epochs + EPOCHS: 90 + WARMUP_EPOCHS: 5 + WEIGHT_DECAY: 0.1 + +SAVE_FREQ: 4 + +EXPERIMENT: + NAME: bold-banana-192 + WANDB_ID: 12907g41 diff --git a/configs/hierarchical-vision-project/fuzzy-fig-192.yaml b/configs/hierarchical-vision-project/fuzzy-fig-192.yaml new file mode 100644 index 00000000..d68232a5 --- /dev/null +++ b/configs/hierarchical-vision-project/fuzzy-fig-192.yaml @@ -0,0 +1,32 @@ +DATA: + DATASET: inat21 + IMG_SIZE: 192 + NUM_WORKERS: 32 +MODEL: + TYPE: swinv2 + NAME: swinv2_base_window12 + DROP_PATH_RATE: 0.2 + SWINV2: + EMBED_DIM: 128 + DEPTHS: [ 2, 2, 18, 2 ] + NUM_HEADS: [ 4, 8, 16, 32 ] + WINDOW_SIZE: 12 +TRAIN: + # Want a global batch size of 2048 because SwinV2 was trained on 16 V100s with batch size 128 (I think) + # But we are going to use a global batch size of 1024 because it's faster (throughput). + GLOBAL_BATCH_SIZE: 1024 + + # We are using limited epochs based on pre-training configs for imagenet22k + # Then we will pre-train on 256x256 for 30 epochs + EPOCHS: 90 + WARMUP_EPOCHS: 5 + WEIGHT_DECAY: 0.1 + + # Use 1/4 of original learning rates + BASE_LR: 1.25e-4 + WARMUP_LR: 1.25e-7 + MIN_LR: 1.25e-6 + +EXPERIMENT: + NAME: fuzzy-fig-192 + WANDB_ID: 2c0bq1h2 diff --git a/configs/hierarchical-vision-project/groovy-grape-192.yaml b/configs/hierarchical-vision-project/groovy-grape-192.yaml new file mode 100644 index 00000000..02673338 --- /dev/null +++ b/configs/hierarchical-vision-project/groovy-grape-192.yaml @@ -0,0 +1,37 @@ + +DATA: + DATASET: inat21 + IMG_SIZE: 192 + NUM_WORKERS: 32 +MODEL: + TYPE: swinv2 + NAME: swinv2_base_window12 + DROP_PATH_RATE: 0.2 + SWINV2: + EMBED_DIM: 128 + DEPTHS: [ 2, 2, 18, 2 ] + NUM_HEADS: [ 4, 8, 16, 32 ] + WINDOW_SIZE: 12 +TRAIN: + # Want a global batch size of 2048 because SwinV2 was trained on 16 V100s with batch size 128 (I think) + # But we are going to use a global batch size of 1024 because it's faster (throughput). 
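+  # For reference, global batch = device batch x #GPUs x accumulation steps
+  # (config.py derives ACCUMULATION_STEPS from GLOBAL_BATCH_SIZE), so 1024 could be
+  # e.g. 8 GPUs x 64 images x 2 accumulation steps.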
+ GLOBAL_BATCH_SIZE: 1024 + + # We are using limited epochs based on pre-training configs for imagenet22k + # Then we will pre-train on 256x256 for 30 epochs + EPOCHS: 90 + WARMUP_EPOCHS: 5 + WEIGHT_DECAY: 0.1 + + # Use 1/4 of original learning rates + BASE_LR: 1.25e-4 + WARMUP_LR: 1.25e-7 + MIN_LR: 1.25e-6 + + HIERARCHICAL_COEFFS: [ 8, 5.65, 4, 2.82, 2, 1.41, 1 ] + +EXPERIMENT: + NAME: groovy-grape-192 + WANDB_ID: 3jcq2v9b + +HIERARCHICAL: true diff --git a/configs/hierarchical-vision-project/groovy-grape-256.yaml b/configs/hierarchical-vision-project/groovy-grape-256.yaml new file mode 100644 index 00000000..d000dccd --- /dev/null +++ b/configs/hierarchical-vision-project/groovy-grape-256.yaml @@ -0,0 +1,34 @@ +DATA: + DATASET: inat21 + IMG_SIZE: 256 + NUM_WORKERS: 32 +MODEL: + TYPE: swinv2 + NAME: swinv2_base_window16 + DROP_PATH_RATE: 0.2 + SWINV2: + EMBED_DIM: 128 + DEPTHS: [ 2, 2, 18, 2 ] + NUM_HEADS: [ 4, 8, 16, 32 ] + WINDOW_SIZE: 16 + PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ] +TRAIN: + GLOBAL_BATCH_SIZE: 1024 + EPOCHS: 30 + WARMUP_EPOCHS: 5 + + # Should weight decay be this low? + # I don't think so, but I am sticking with the default for now. + WEIGHT_DECAY: 1.0e-8 + + BASE_LR: 2.0e-05 + WARMUP_LR: 2.0e-08 + MIN_LR: 2.0e-07 + + HIERARCHICAL_COEFFS: [ 8, 5.65, 4, 2.82, 2, 1.41, 1 ] + +EXPERIMENT: + NAME: groovy-grape-256 + WANDB_ID: 11o47wpm + +HIERARCHICAL: true diff --git a/configs/hierarchical-vision-project/outrageous-orange-192.yaml b/configs/hierarchical-vision-project/outrageous-orange-192.yaml new file mode 100644 index 00000000..15d835e7 --- /dev/null +++ b/configs/hierarchical-vision-project/outrageous-orange-192.yaml @@ -0,0 +1,31 @@ +DATA: + DATASET: inat21 + IMG_SIZE: 192 + NUM_WORKERS: 32 +MODEL: + TYPE: swinv2 + NAME: swinv2_base_window12 + DROP_PATH_RATE: 0.2 + SWINV2: + EMBED_DIM: 128 + DEPTHS: [ 2, 2, 18, 2 ] + NUM_HEADS: [ 4, 8, 16, 32 ] + WINDOW_SIZE: 12 +TRAIN: + # Want a global batch size of 2048 because SwinV2 was trained on 16 V100s with batch size 128 (I think) + # But we are going to use a global batch size of 1024 because it's faster (throughput). + GLOBAL_BATCH_SIZE: 1024 + + # We are using limited epochs based on pre-training configs for imagenet22k + # Then we will pre-train on 256x256 for 30 epochs + EPOCHS: 90 + WARMUP_EPOCHS: 5 + WEIGHT_DECAY: 0.1 + HIERARCHICAL_COEFFS: [ 8, 5.65, 4, 2.82, 2, 1.41, 1 ] + +EXPERIMENT: + NAME: outrageous-orange-192 + WANDB_ID: y6zzxboz + +HIERARCHICAL: true +SAVE_FREQ: 4 diff --git a/configs/swinv2/swinv2_base_patch4_window12_192_inat21_hierarchical_lr1.25.yaml b/configs/swinv2/swinv2_base_patch4_window12_192_inat21_hierarchical_lr1.25.yaml new file mode 100644 index 00000000..cc23379a --- /dev/null +++ b/configs/swinv2/swinv2_base_patch4_window12_192_inat21_hierarchical_lr1.25.yaml @@ -0,0 +1,33 @@ +DATA: + DATASET: inat21 + IMG_SIZE: 192 + DATA_PATH: /mnt/10tb/data/inat21/resize-192 + NUM_WORKERS: 32 + BATCH_SIZE: 64 +MODEL: + TYPE: swinv2 + NAME: swinv2_base_patch4_window12_192_inat21_hierarchical_lr1.25 + DROP_PATH_RATE: 0.2 + SWINV2: + EMBED_DIM: 128 + DEPTHS: [ 2, 2, 18, 2 ] + NUM_HEADS: [ 4, 8, 16, 32 ] + WINDOW_SIZE: 12 +TRAIN: + # Want a global batch size of 2048 because SwinV2 was trained on 16 V100s with batch size 128 (I think) + # But we are going to use a global batch size of 1024 because it's faster (throughput). 
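+  # Worked out: 64 images per GPU (DATA.BATCH_SIZE) x 2 accumulation steps x 8 GPUs (assumed) = 1024.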
+ ACCUMULATION_STEPS: 2 + + # We are using limited epochs based on pre-training configs for imagenet22k + # Then we will pre-train on 256x256 for 30 epochs + EPOCHS: 90 + WARMUP_EPOCHS: 5 + WEIGHT_DECAY: 0.1 + + # Use 1/4 of original learning rates + BASE_LR: 1.25e-4 + WARMUP_LR: 1.25e-7 + MIN_LR: 1.25e-6 + HIERARCHICAL_COEFFS: [ 8, 5.65, 4, 2.82, 2, 1.41, 1 ] + +HIERARCHICAL: true diff --git a/configs/swinv2/swinv2_base_patch4_window12_192_inat21_hierarchical_lr2.5.yaml b/configs/swinv2/swinv2_base_patch4_window12_192_inat21_hierarchical_lr2.5.yaml new file mode 100644 index 00000000..5bc4336f --- /dev/null +++ b/configs/swinv2/swinv2_base_patch4_window12_192_inat21_hierarchical_lr2.5.yaml @@ -0,0 +1,33 @@ +DATA: + DATASET: inat21 + IMG_SIZE: 192 + DATA_PATH: /mnt/10tb/data/inat21/resize-192 + NUM_WORKERS: 32 + BATCH_SIZE: 64 +MODEL: + TYPE: swinv2 + NAME: sweet-strawberry-192 + DROP_PATH_RATE: 0.2 + SWINV2: + EMBED_DIM: 128 + DEPTHS: [ 2, 2, 18, 2 ] + NUM_HEADS: [ 4, 8, 16, 32 ] + WINDOW_SIZE: 12 +TRAIN: + # Want a global batch size of 2048 because SwinV2 was trained on 16 V100s with batch size 128 (I think) + # But we are going to use a global batch size of 1024 because it's faster (throughput). + ACCUMULATION_STEPS: 2 + + # We are using limited epochs based on pre-training configs for imagenet22k + # Then we will pre-train on 256x256 for 30 epochs + EPOCHS: 90 + WARMUP_EPOCHS: 5 + WEIGHT_DECAY: 0.1 + + # Use 1/2 of original learning rates + BASE_LR: 2.5e-4 + WARMUP_LR: 2.5e-7 + MIN_LR: 2.5e-6 + HIERARCHICAL_COEFFS: [ 8, 5.65, 4, 2.82, 2, 1.41, 1 ] + +HIERARCHICAL: true diff --git a/configs/swinv2/swinv2_base_patch4_window12_192_inat21_hierarchical_lr5.yaml b/configs/swinv2/swinv2_base_patch4_window12_192_inat21_hierarchical_lr5.yaml new file mode 100644 index 00000000..2a2d422f --- /dev/null +++ b/configs/swinv2/swinv2_base_patch4_window12_192_inat21_hierarchical_lr5.yaml @@ -0,0 +1,29 @@ +DATA: + DATASET: inat21 + IMG_SIZE: 192 + DATA_PATH: /research/nfs_su_809/cv_datasets/inat21/train_val_192 + NUM_WORKERS: 32 + BATCH_SIZE: 128 +MODEL: + TYPE: swinv2 + NAME: outrageous-orange-192 + DROP_PATH_RATE: 0.2 + SWINV2: + EMBED_DIM: 128 + DEPTHS: [ 2, 2, 18, 2 ] + NUM_HEADS: [ 4, 8, 16, 32 ] + WINDOW_SIZE: 12 +TRAIN: + # Want a global batch size of 2048 because SwinV2 was trained on 16 V100s with batch size 128 (I think) + # But we are going to use a global batch size of 1024 because it's faster (throughput). 
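+  # With BATCH_SIZE: 128 per device and ACCUMULATION_STEPS: 2 below, this
+  # works out to the 1024 global batch assuming a 4-GPU run (128 * 2 * 4 = 1024).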
+ ACCUMULATION_STEPS: 2 + + # We are using limited epochs based on pre-training configs for imagenet22k + # Then we will pre-train on 256x256 for 30 epochs + EPOCHS: 90 + WARMUP_EPOCHS: 5 + WEIGHT_DECAY: 0.1 + HIERARCHICAL_COEFFS: [ 8, 5.65, 4, 2.82, 2, 1.41, 1 ] + +SAVE_FREQ: 4 +HIERARCHICAL: true diff --git a/configs/swinv2/swinv2_base_patch4_window12to16_192to256_inat21_hierarchical_lr2.5_ft.yaml b/configs/swinv2/swinv2_base_patch4_window12to16_192to256_inat21_hierarchical_lr2.5_ft.yaml new file mode 100644 index 00000000..c3f6b5df --- /dev/null +++ b/configs/swinv2/swinv2_base_patch4_window12to16_192to256_inat21_hierarchical_lr2.5_ft.yaml @@ -0,0 +1,34 @@ +DATA: + DATASET: inat21 + NUM_WORKERS: 32 + BATCH_SIZE: 16 + IMG_SIZE: 256 + DATA_PATH: /mnt/10tb/data/inat21/resize-256 +MODEL: + TYPE: swinv2 + PRETRAINED: /mnt/10tb/models/swinv2_base_patch4_window12_192_inat21_hierarchical_lr2.5_v0_epoch_89.pth + NAME: sweet-strawberry-256 + DROP_PATH_RATE: 0.2 + SWINV2: + EMBED_DIM: 128 + DEPTHS: [ 2, 2, 18, 2 ] + NUM_HEADS: [ 4, 8, 16, 32 ] + WINDOW_SIZE: 16 + PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ] +TRAIN: + # Global batch size of 1024 + ACCUMULATION_STEPS: 8 + EPOCHS: 30 + WARMUP_EPOCHS: 5 + + # Should weight decay be this low? + # I don't think so, but I am sticking with the default for now. + WEIGHT_DECAY: 1.0e-8 + + BASE_LR: 2.0e-05 + WARMUP_LR: 2.0e-08 + MIN_LR: 2.0e-07 + + HIERARCHICAL_COEFFS: [ 8, 5.65, 4, 2.82, 2, 1.41, 1 ] + +HIERARCHICAL: true diff --git a/data/__init__.py b/data/__init__.py index 5baad7ed..a8cfa99a 100644 --- a/data/__init__.py +++ b/data/__init__.py @@ -1,6 +1,6 @@ from .build import build_loader as _build_loader -from .data_simmim_pt import build_loader_simmim from .data_simmim_ft import build_loader_finetune +from .data_simmim_pt import build_loader_simmim def build_loader(config, simmim=False, is_pretrain=False): diff --git a/data/build.py b/data/build.py index 5799f253..601ab602 100644 --- a/data/build.py +++ b/data/build.py @@ -6,52 +6,71 @@ # -------------------------------------------------------- import os -import torch +import random + import numpy as np +import torch import torch.distributed as dist +from timm.data import Mixup, create_transform +from torch.utils.data import Subset from torchvision import datasets, transforms -from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.data import Mixup -from timm.data import create_transform from .cached_image_folder import CachedImageFolder +from .constants import data_mean_std +from .hierarchical import HierarchicalImageFolder, HierarchicalMixup from .imagenet22k_dataset import IN22KDATASET from .samplers import SubsetRandomSampler try: from torchvision.transforms import InterpolationMode - def _pil_interp(method): - if method == 'bicubic': + if method == "bicubic": return InterpolationMode.BICUBIC - elif method == 'lanczos': + elif method == "lanczos": return InterpolationMode.LANCZOS - elif method == 'hamming': + elif method == "hamming": return InterpolationMode.HAMMING else: # default bilinear, do we want to allow nearest? 
return InterpolationMode.BILINEAR - import timm.data.transforms as timm_transforms timm_transforms._pil_interp = _pil_interp -except: +except ImportError: from timm.data.transforms import _pil_interp def build_loader(config): config.defrost() - dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config) + dataset_train, config.MODEL.NUM_CLASSES = build_dataset( + is_train=True, config=config + ) config.freeze() - print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build train dataset") + print( + f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build train dataset" + ) dataset_val, _ = build_dataset(is_train=False, config=config) - print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build val dataset") + print( + f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build val dataset" + ) + + # Check if we are overfitting some subset of the training data for debugging + if config.TRAIN.OVERFIT_BATCHES > 0: + n_examples = config.TRAIN.OVERFIT_BATCHES * config.TRAIN.DEVICE_BATCH_SIZE + indices = random.sample(range(len(dataset_train)), n_examples) + dataset_train = Subset(dataset_train, indices) + + # Check if training is for low data regieme; select subset of data (newly added script) + if config.TRAIN.DATA_PERCENTAGE < 1: + n_examples = config.TRAIN.DATA_PERCENTAGE * len(dataset_train) #config.TRAIN.DEVICE_BATCH_SIZE + indices = random.sample(range(len(dataset_train)), int(n_examples)) + dataset_train = Subset(dataset_train, indices) num_tasks = dist.get_world_size() global_rank = dist.get_rank() - if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part': + if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == "part": indices = np.arange(dist.get_rank(), len(dataset_train), dist.get_world_size()) sampler_train = SubsetRandomSampler(indices) else: @@ -67,57 +86,102 @@ def build_loader(config): ) data_loader_train = torch.utils.data.DataLoader( - dataset_train, sampler=sampler_train, - batch_size=config.DATA.BATCH_SIZE, + dataset_train, + sampler=sampler_train, + batch_size=config.TRAIN.DEVICE_BATCH_SIZE, num_workers=config.DATA.NUM_WORKERS, pin_memory=config.DATA.PIN_MEMORY, drop_last=True, ) data_loader_val = torch.utils.data.DataLoader( - dataset_val, sampler=sampler_val, - batch_size=config.DATA.BATCH_SIZE, + dataset_val, + sampler=sampler_val, + batch_size=config.TRAIN.DEVICE_BATCH_SIZE, shuffle=False, num_workers=config.DATA.NUM_WORKERS, pin_memory=config.DATA.PIN_MEMORY, - drop_last=False + drop_last=False, ) # setup mixup / cutmix mixup_fn = None - mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. 
or config.AUG.CUTMIX_MINMAX is not None + mixup_active = ( + config.AUG.MIXUP > 0 + or config.AUG.CUTMIX > 0.0 + or config.AUG.CUTMIX_MINMAX is not None + ) if mixup_active: - mixup_fn = Mixup( - mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX, - prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE, - label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES) + mixup_args = dict( + mixup_alpha=config.AUG.MIXUP, + cutmix_alpha=config.AUG.CUTMIX, + cutmix_minmax=config.AUG.CUTMIX_MINMAX, + prob=config.AUG.MIXUP_PROB, + switch_prob=config.AUG.MIXUP_SWITCH_PROB, + mode=config.AUG.MIXUP_MODE, + label_smoothing=config.MODEL.LABEL_SMOOTHING, + num_classes=config.MODEL.NUM_CLASSES, + ) + if config.HIERARCHICAL: + mixup_fn = HierarchicalMixup(**mixup_args) + else: + mixup_fn = Mixup(**mixup_args) return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn def build_dataset(is_train, config): transform = build_transform(is_train, config) - if config.DATA.DATASET == 'imagenet': - prefix = 'train' if is_train else 'val' + if config.DATA.DATASET == "imagenet": + prefix = "train" if is_train else "val" if config.DATA.ZIP_MODE: ann_file = prefix + "_map.txt" prefix = prefix + ".zip@/" - dataset = CachedImageFolder(config.DATA.DATA_PATH, ann_file, prefix, transform, - cache_mode=config.DATA.CACHE_MODE if is_train else 'part') + dataset = CachedImageFolder( + config.DATA.DATA_PATH, + ann_file, + prefix, + transform, + cache_mode=config.DATA.CACHE_MODE if is_train else "part", + ) else: root = os.path.join(config.DATA.DATA_PATH, prefix) dataset = datasets.ImageFolder(root, transform=transform) nb_classes = 1000 - elif config.DATA.DATASET == 'imagenet22K': - prefix = 'ILSVRC2011fall_whole' + elif config.DATA.DATASET == "imagenet22K": + prefix = "ILSVRC2011fall_whole" if is_train: ann_file = prefix + "_map_train.txt" else: ann_file = prefix + "_map_val.txt" dataset = IN22KDATASET(config.DATA.DATA_PATH, ann_file, transform) nb_classes = 21841 + + elif config.DATA.DATASET == 'nabird': + prefix = 'train' if is_train else 'val' + root = os.path.join(config.DATA.DATA_PATH, prefix) + dataset = datasets.ImageFolder(root, transform=transform) + nb_classes = 555 + + elif config.DATA.DATASET == 'ip102': + prefix = 'train' if is_train else 'val' + root = os.path.join(config.DATA.DATA_PATH, prefix) + dataset = datasets.ImageFolder(root, transform=transform) + nb_classes = 102 + + elif config.DATA.DATASET == "inat21": + if config.DATA.ZIP_MODE: + raise NotImplementedError("We do not support zipped inat21") + prefix = "train" if is_train else "val" + root = os.path.join(config.DATA.DATA_PATH, prefix) + if config.HIERARCHICAL: + dataset = HierarchicalImageFolder(root, transform=transform) + nb_classes = dataset.num_classes + else: + dataset = datasets.ImageFolder(root, transform=transform) + nb_classes = 10_000 else: - raise NotImplementedError("We only support ImageNet Now.") + raise NotImplementedError("We only support ImageNet now.") return dataset, nb_classes @@ -129,8 +193,12 @@ def build_transform(is_train, config): transform = create_transform( input_size=config.DATA.IMG_SIZE, is_training=True, - color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None, - auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None, + color_jitter=config.AUG.COLOR_JITTER + if config.AUG.COLOR_JITTER > 0 + else None, + 
auto_augment=config.AUG.AUTO_AUGMENT + if config.AUG.AUTO_AUGMENT != "none" + else None, re_prob=config.AUG.REPROB, re_mode=config.AUG.REMODE, re_count=config.AUG.RECOUNT, @@ -139,7 +207,9 @@ def build_transform(is_train, config): if not resize_im: # replace RandomResizedCropAndInterpolation with # RandomCrop - transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4) + transform.transforms[0] = transforms.RandomCrop( + config.DATA.IMG_SIZE, padding=4 + ) return transform t = [] @@ -147,16 +217,29 @@ def build_transform(is_train, config): if config.TEST.CROP: size = int((256 / 224) * config.DATA.IMG_SIZE) t.append( - transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)), + transforms.Resize( + size, interpolation=_pil_interp(config.DATA.INTERPOLATION) + ), # to maintain same ratio w.r.t. 224 images ) t.append(transforms.CenterCrop(config.DATA.IMG_SIZE)) else: t.append( - transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE), - interpolation=_pil_interp(config.DATA.INTERPOLATION)) + transforms.Resize( + (config.DATA.IMG_SIZE, config.DATA.IMG_SIZE), + interpolation=_pil_interp(config.DATA.INTERPOLATION), + ) ) + if config.DATA.DATA_PATH in data_mean_std: + mean, std = data_mean_std[config.DATA.DATA_PATH] + elif config.DATA.DATASET in data_mean_std: + mean, std = data_mean_std[config.DATA.DATASET] + else: + raise RuntimeError( + f"Can't find mean/std for {config.DATA.DATASET} at {config.DATA.DATA_PATH}. Please add it to data/constants.py (try using python -m data.inat normalize for iNat)." + ) + t.append(transforms.ToTensor()) - t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) + t.append(transforms.Normalize(mean, std)) return transforms.Compose(t) diff --git a/data/cached_image_folder.py b/data/cached_image_folder.py index 7e1883b1..732d3236 100644 --- a/data/cached_image_folder.py +++ b/data/cached_image_folder.py @@ -8,11 +8,12 @@ import io import os import time + import torch.distributed as dist import torch.utils.data as data from PIL import Image -from .zipreader import is_zip_path, ZipReader +from .zipreader import ZipReader, is_zip_path def has_file_allowed_extension(filename, extensions): @@ -56,7 +57,7 @@ def make_dataset_with_ann(ann_file, img_prefix, extensions): with open(ann_file, "r") as f: contents = f.readlines() for line_str in contents: - path_contents = [c for c in line_str.split('\t')] + path_contents = [c for c in line_str.split("\t")] im_file_name = path_contents[0] class_index = int(path_contents[1]) @@ -89,21 +90,37 @@ class DatasetFolder(data.Dataset): samples (list): List of (sample path, class_index) tuples """ - def __init__(self, root, loader, extensions, ann_file='', img_prefix='', transform=None, target_transform=None, - cache_mode="no"): + def __init__( + self, + root, + loader, + extensions, + ann_file="", + img_prefix="", + transform=None, + target_transform=None, + cache_mode="no", + ): # image folder mode - if ann_file == '': + if ann_file == "": _, class_to_idx = find_classes(root) samples = make_dataset(root, class_to_idx, extensions) # zip mode else: - samples = make_dataset_with_ann(os.path.join(root, ann_file), - os.path.join(root, img_prefix), - extensions) + samples = make_dataset_with_ann( + os.path.join(root, ann_file), os.path.join(root, img_prefix), extensions + ) if len(samples) == 0: - raise (RuntimeError("Found 0 files in subfolders of: " + root + "\n" + - "Supported extensions are: " + ",".join(extensions))) + raise ( + RuntimeError( + "Found 0 files in 
subfolders of: " + + root + + "\n" + + "Supported extensions are: " + + ",".join(extensions) + ) + ) self.root = root self.loader = loader @@ -131,7 +148,9 @@ def init_cache(self): for index in range(n_sample): if index % (n_sample // 10) == 0: t = time.time() - start_time - print(f'global_rank {dist.get_rank()} cached {index}/{n_sample} takes {t:.2f}s per block') + print( + f"global_rank {dist.get_rank()} cached {index}/{n_sample} takes {t:.2f}s per block" + ) start_time = time.time() path, target = self.samples[index] if self.cache_mode == "full": @@ -162,17 +181,21 @@ def __len__(self): return len(self.samples) def __repr__(self): - fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' - fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) - fmt_str += ' Root Location: {}\n'.format(self.root) - tmp = ' Transforms (if any): ' - fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) - tmp = ' Target Transforms (if any): ' - fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) + fmt_str = "Dataset " + self.__class__.__name__ + "\n" + fmt_str += " Number of datapoints: {}\n".format(self.__len__()) + fmt_str += " Root Location: {}\n".format(self.root) + tmp = " Transforms (if any): " + fmt_str += "{0}{1}\n".format( + tmp, self.transform.__repr__().replace("\n", "\n" + " " * len(tmp)) + ) + tmp = " Target Transforms (if any): " + fmt_str += "{0}{1}".format( + tmp, self.target_transform.__repr__().replace("\n", "\n" + " " * len(tmp)) + ) return fmt_str -IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif'] +IMG_EXTENSIONS = [".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif"] def pil_loader(path): @@ -183,14 +206,15 @@ def pil_loader(path): data = ZipReader.read(path) img = Image.open(io.BytesIO(data)) else: - with open(path, 'rb') as f: + with open(path, "rb") as f: img = Image.open(f) - return img.convert('RGB') - return img.convert('RGB') + return img.convert("RGB") + return img.convert("RGB") def accimage_loader(path): import accimage + try: return accimage.Image(path) except IOError: @@ -200,7 +224,8 @@ def accimage_loader(path): def default_img_loader(path): from torchvision import get_image_backend - if get_image_backend() == 'accimage': + + if get_image_backend() == "accimage": return accimage_loader(path) else: return pil_loader(path) @@ -225,12 +250,26 @@ class CachedImageFolder(DatasetFolder): imgs (list): List of (image path, class_index) tuples """ - def __init__(self, root, ann_file='', img_prefix='', transform=None, target_transform=None, - loader=default_img_loader, cache_mode="no"): - super(CachedImageFolder, self).__init__(root, loader, IMG_EXTENSIONS, - ann_file=ann_file, img_prefix=img_prefix, - transform=transform, target_transform=target_transform, - cache_mode=cache_mode) + def __init__( + self, + root, + ann_file="", + img_prefix="", + transform=None, + target_transform=None, + loader=default_img_loader, + cache_mode="no", + ): + super(CachedImageFolder, self).__init__( + root, + loader, + IMG_EXTENSIONS, + ann_file=ann_file, + img_prefix=img_prefix, + transform=transform, + target_transform=target_transform, + cache_mode=cache_mode, + ) self.imgs = self.samples def __getitem__(self, index): diff --git a/data/constants.py b/data/constants.py new file mode 100644 index 00000000..3cfd39b1 --- /dev/null +++ b/data/constants.py @@ -0,0 +1,42 @@ +import torch +from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD + 
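+# Maps a DATA.DATA_PATH (or, as a fallback, a DATA.DATASET name) to per-channel
+# RGB (mean, std) tensors consumed by transforms.Normalize in data/build.py.
+# New entries can be generated with `python -m data.inat normalize <directory>`,
+# which prints a snippet in this format.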
+data_mean_std = { + "/mnt/10tb/data/inat21/resize-192": ( + torch.tensor([0.4632684290409088, 0.48004600405693054, 0.37628623843193054]), + torch.tensor([0.23754851520061493, 0.22912880778312683, 0.24746596813201904]), + ), + "/local/scratch/cv_datasets/inat21/resize-192": ( + torch.tensor([0.23754851520061493, 0.22912880778312683, 0.24746596813201904]), + torch.tensor([0.4632684290409088, 0.48004600405693054, 0.37628623843193054]), + ), + "/mnt/10tb/data/inat21/resize-224": ( + torch.tensor([0.23762744665145874, 0.2292044311761856, 0.24757201969623566]), + torch.tensor([0.4632636606693268, 0.48004215955734253, 0.37622377276420593]), + ), + "/mnt/10tb/data/inat21/resize-256": ( + torch.tensor([0.23768986761569977, 0.22925858199596405, 0.2476460039615631]), + torch.tensor([0.4632672071456909, 0.480050653219223, 0.37618669867515564]), + ), + "/local/scratch/cv_datasets/inat21/resize-256": ( + torch.tensor([0.23768986761569977, 0.22925858199596405, 0.2476460039615631]), + torch.tensor([0.4632672071456909, 0.480050653219223, 0.37618669867515564]), + ), + "/home/ubuntu/AWS_Server/swin-transformer/datasets/nabirds/image_new/full_train_val": ( + torch.tensor([0.49044106, 0.5076765, 0.46390218]), + torch.tensor([0.16689847, 0.1688618, 0.18529404]), + ), + "/home/ubuntu/AWS_Server/swin-transformer/datasets/nabirds/image_new/10P_train_val": ( + torch.tensor([0.48588862, 0.50227299, 0.45998148]), + torch.tensor([0.16756063, 0.16897439, 0.18490989]), + ), + "/home/ubuntu/AWS_Server/swin-transformer/datasets/nabirds/image_new/full_train_test": ( + torch.tensor([0.49103116, 0.5080927, 0.46408487]), + torch.tensor([0.16669449, 0.16859235, 0.18495317]), + ), + "/home/ubuntu/AWS_Server/swin-transformer/datasets/ip102": ( + torch.tensor([0.51354748, 0.54016679, 0.38778601]), + torch.tensor([0.19195388, 0.19070604, 0.19121135]), + ), + "imagenet": (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD), +} diff --git a/data/data_simmim_ft.py b/data/data_simmim_ft.py index 1d44ae7d..ffd4137c 100644 --- a/data/data_simmim_ft.py +++ b/data/data_simmim_ft.py @@ -6,18 +6,20 @@ # -------------------------------------------------------- import os + import torch.distributed as dist -from torch.utils.data import DataLoader, DistributedSampler -from torchvision import datasets, transforms +from timm.data import Mixup, create_transform from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.data import Mixup -from timm.data import create_transform from timm.data.transforms import _pil_interp +from torch.utils.data import DataLoader, DistributedSampler +from torchvision import datasets, transforms def build_loader_finetune(config): config.defrost() - dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config) + dataset_train, config.MODEL.NUM_CLASSES = build_dataset( + is_train=True, config=config + ) config.freeze() dataset_val, _ = build_dataset(is_train=False, config=config) @@ -31,7 +33,8 @@ def build_loader_finetune(config): ) data_loader_train = DataLoader( - dataset_train, sampler=sampler_train, + dataset_train, + sampler=sampler_train, batch_size=config.DATA.BATCH_SIZE, num_workers=config.DATA.NUM_WORKERS, pin_memory=config.DATA.PIN_MEMORY, @@ -39,7 +42,8 @@ def build_loader_finetune(config): ) data_loader_val = DataLoader( - dataset_val, sampler=sampler_val, + dataset_val, + sampler=sampler_val, batch_size=config.DATA.BATCH_SIZE, num_workers=config.DATA.NUM_WORKERS, pin_memory=config.DATA.PIN_MEMORY, @@ -48,21 +52,31 @@ def build_loader_finetune(config): # setup 
mixup / cutmix mixup_fn = None - mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None + mixup_active = ( + config.AUG.MIXUP > 0 + or config.AUG.CUTMIX > 0.0 + or config.AUG.CUTMIX_MINMAX is not None + ) if mixup_active: mixup_fn = Mixup( - mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX, - prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE, - label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES) + mixup_alpha=config.AUG.MIXUP, + cutmix_alpha=config.AUG.CUTMIX, + cutmix_minmax=config.AUG.CUTMIX_MINMAX, + prob=config.AUG.MIXUP_PROB, + switch_prob=config.AUG.MIXUP_SWITCH_PROB, + mode=config.AUG.MIXUP_MODE, + label_smoothing=config.MODEL.LABEL_SMOOTHING, + num_classes=config.MODEL.NUM_CLASSES, + ) return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn def build_dataset(is_train, config): transform = build_transform(is_train, config) - - if config.DATA.DATASET == 'imagenet': - prefix = 'train' if is_train else 'val' + + if config.DATA.DATASET == "imagenet": + prefix = "train" if is_train else "val" root = os.path.join(config.DATA.DATA_PATH, prefix) dataset = datasets.ImageFolder(root, transform=transform) nb_classes = 1000 @@ -79,8 +93,12 @@ def build_transform(is_train, config): transform = create_transform( input_size=config.DATA.IMG_SIZE, is_training=True, - color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None, - auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None, + color_jitter=config.AUG.COLOR_JITTER + if config.AUG.COLOR_JITTER > 0 + else None, + auto_augment=config.AUG.AUTO_AUGMENT + if config.AUG.AUTO_AUGMENT != "none" + else None, re_prob=config.AUG.REPROB, re_mode=config.AUG.REMODE, re_count=config.AUG.RECOUNT, @@ -89,7 +107,9 @@ def build_transform(is_train, config): if not resize_im: # replace RandomResizedCropAndInterpolation with # RandomCrop - transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4) + transform.transforms[0] = transforms.RandomCrop( + config.DATA.IMG_SIZE, padding=4 + ) return transform t = [] @@ -97,14 +117,18 @@ def build_transform(is_train, config): if config.TEST.CROP: size = int((256 / 224) * config.DATA.IMG_SIZE) t.append( - transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)), + transforms.Resize( + size, interpolation=_pil_interp(config.DATA.INTERPOLATION) + ), # to maintain same ratio w.r.t. 
224 images ) t.append(transforms.CenterCrop(config.DATA.IMG_SIZE)) else: t.append( - transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE), - interpolation=_pil_interp(config.DATA.INTERPOLATION)) + transforms.Resize( + (config.DATA.IMG_SIZE, config.DATA.IMG_SIZE), + interpolation=_pil_interp(config.DATA.INTERPOLATION), + ) ) t.append(transforms.ToTensor()) diff --git a/data/data_simmim_pt.py b/data/data_simmim_pt.py index 0f2503a7..9826ce90 100644 --- a/data/data_simmim_pt.py +++ b/data/data_simmim_pt.py @@ -7,70 +7,81 @@ import math import random -import numpy as np +import numpy as np import torch import torch.distributed as dist import torchvision.transforms as T +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from torch.utils.data import DataLoader, DistributedSampler from torch.utils.data._utils.collate import default_collate from torchvision.datasets import ImageFolder -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD class MaskGenerator: - def __init__(self, input_size=192, mask_patch_size=32, model_patch_size=4, mask_ratio=0.6): + def __init__( + self, input_size=192, mask_patch_size=32, model_patch_size=4, mask_ratio=0.6 + ): self.input_size = input_size self.mask_patch_size = mask_patch_size self.model_patch_size = model_patch_size self.mask_ratio = mask_ratio - + assert self.input_size % self.mask_patch_size == 0 assert self.mask_patch_size % self.model_patch_size == 0 - + self.rand_size = self.input_size // self.mask_patch_size self.scale = self.mask_patch_size // self.model_patch_size - - self.token_count = self.rand_size ** 2 + + self.token_count = self.rand_size**2 self.mask_count = int(np.ceil(self.token_count * self.mask_ratio)) - + def __call__(self): - mask_idx = np.random.permutation(self.token_count)[:self.mask_count] + mask_idx = np.random.permutation(self.token_count)[: self.mask_count] mask = np.zeros(self.token_count, dtype=int) mask[mask_idx] = 1 - + mask = mask.reshape((self.rand_size, self.rand_size)) mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1) - + return mask class SimMIMTransform: def __init__(self, config): - self.transform_img = T.Compose([ - T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), - T.RandomResizedCrop(config.DATA.IMG_SIZE, scale=(0.67, 1.), ratio=(3. / 4., 4. 
/ 3.)), - T.RandomHorizontalFlip(), - T.ToTensor(), - T.Normalize(mean=torch.tensor(IMAGENET_DEFAULT_MEAN),std=torch.tensor(IMAGENET_DEFAULT_STD)), - ]) - - if config.MODEL.TYPE in ['swin', 'swinv2']: - model_patch_size=config.MODEL.SWIN.PATCH_SIZE + self.transform_img = T.Compose( + [ + T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), + T.RandomResizedCrop( + config.DATA.IMG_SIZE, + scale=(0.67, 1.0), + ratio=(3.0 / 4.0, 4.0 / 3.0), + ), + T.RandomHorizontalFlip(), + T.ToTensor(), + T.Normalize( + mean=torch.tensor(IMAGENET_DEFAULT_MEAN), + std=torch.tensor(IMAGENET_DEFAULT_STD), + ), + ] + ) + + if config.MODEL.TYPE in ["swin", "swinv2"]: + model_patch_size = config.MODEL.SWIN.PATCH_SIZE else: raise NotImplementedError - + self.mask_generator = MaskGenerator( input_size=config.DATA.IMG_SIZE, mask_patch_size=config.DATA.MASK_PATCH_SIZE, model_patch_size=model_patch_size, mask_ratio=config.DATA.MASK_RATIO, ) - + def __call__(self, img): img = self.transform_img(img) mask = self.mask_generator() - + return img, mask @@ -84,7 +95,9 @@ def collate_fn(batch): if batch[0][0][item_idx] is None: ret.append(None) else: - ret.append(default_collate([batch[i][0][item_idx] for i in range(batch_num)])) + ret.append( + default_collate([batch[i][0][item_idx] for i in range(batch_num)]) + ) ret.append(default_collate([batch[i][1] for i in range(batch_num)])) return ret @@ -92,8 +105,18 @@ def collate_fn(batch): def build_loader_simmim(config): transform = SimMIMTransform(config) dataset = ImageFolder(config.DATA.DATA_PATH, transform) - - sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True) - dataloader = DataLoader(dataset, config.DATA.BATCH_SIZE, sampler=sampler, num_workers=config.DATA.NUM_WORKERS, pin_memory=True, drop_last=True, collate_fn=collate_fn) - - return dataloader \ No newline at end of file + + sampler = DistributedSampler( + dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True + ) + dataloader = DataLoader( + dataset, + config.DATA.BATCH_SIZE, + sampler=sampler, + num_workers=config.DATA.NUM_WORKERS, + pin_memory=True, + drop_last=True, + collate_fn=collate_fn, + ) + + return dataloader diff --git a/data/hierarchical.py b/data/hierarchical.py new file mode 100644 index 00000000..5bcd5057 --- /dev/null +++ b/data/hierarchical.py @@ -0,0 +1,94 @@ +import os + +import torch +from timm.data import Mixup, mixup +from torchvision.datasets import ImageFolder + + +class HierarchicalImageFolder(ImageFolder): + """ + Parses an image folder where the hierarchy is represented as follows: + + 00000_top_middle_..._bottom + 00001_top_middle_..._other + ... 
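+
+    Each directory name is split on "_" into taxonomic tiers (see
+    make_hierarchical below); every tier gets its own index space, so each
+    target is a tensor of per-tier class indices and `num_classes` becomes
+    a tuple with one entry per tier.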
+ """ + + num_classes = None + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def find_classes(self, directory): + classes = sorted( + entry.name for entry in os.scandir(directory) if entry.is_dir() + ) + + tier_lookup = {} + class_to_idxs = {} + + for cls in classes: + tiers = make_hierarchical(cls) + + for tier, value in enumerate(tiers): + if tier not in tier_lookup: + tier_lookup[tier] = {} + + if value not in tier_lookup[tier]: + tier_lookup[tier][value] = len(tier_lookup[tier]) + + class_to_idxs[cls] = torch.tensor( + [tier_lookup[tier][value] for tier, value in enumerate(tiers)] + ) + + # Set self.num_classes + self.num_classes = tuple(len(tier) for tier in tier_lookup.values()) + + return classes, class_to_idxs + + +def make_hierarchical(name): + """ + Sometimes the tree is not really a tree; that is, sometimes there are + repeated orders, for example. + + Arguments: + name (str): the complete taxonomic name, separated by '_' + """ + # index is a number + # top is kingdom + index, top, *tiers = name.split("_") + + cleaned = [top] + + complete = top + for tier in tiers: + complete += f"-{tier}" + cleaned.append(complete) + + return cleaned + + +class HierarchicalMixup(Mixup): + def __call__(self, inputs, targets): + assert len(inputs) % 2 == 0, "Batch size should be even when using this" + if self.mode == "elem": + lam = self._mix_elem(inputs) + elif self.mode == "pair": + lam = self._mix_pair(inputs) + else: + lam = self._mix_batch(inputs) + + batch_size, *_ = inputs.shape + assert targets.shape == ( + batch_size, + len(self.num_classes), + ), f"{targets.shape} != {batch_size, len(self.num_classes)}" + + targets = [ + mixup.mixup_target( + target, num_classes, lam, self.label_smoothing, inputs.device + ) + for target, num_classes in zip(targets.T, self.num_classes) + ] + return inputs, targets diff --git a/data/imagenet22k_dataset.py b/data/imagenet22k_dataset.py index 5758060b..dc3210fe 100644 --- a/data/imagenet22k_dataset.py +++ b/data/imagenet22k_dataset.py @@ -1,16 +1,16 @@ -import os import json -import torch.utils.data as data +import os +import warnings + import numpy as np +import torch.utils.data as data from PIL import Image -import warnings - warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning) class IN22KDATASET(data.Dataset): - def __init__(self, root, ann_file='', transform=None, target_transform=None): + def __init__(self, root, ann_file="", transform=None, target_transform=None): super(IN22KDATASET, self).__init__() self.data_path = root @@ -40,7 +40,7 @@ def __getitem__(self, index): idb = self.database[index] # images - images = self._load_image(self.data_path + '/' + idb[0]).convert('RGB') + images = self._load_image(self.data_path + "/" + idb[0]).convert("RGB") if self.transform is not None: images = self.transform(images) diff --git a/data/inat/__init__.py b/data/inat/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/data/inat/__main__.py b/data/inat/__main__.py new file mode 100644 index 00000000..c1c0883c --- /dev/null +++ b/data/inat/__main__.py @@ -0,0 +1,70 @@ +import argparse + +from .. import hierarchical +from . 
import datasets + + +def preprocess_cli(args): + datasets.preprocess_dataset(args.root, args.stage, args.strategy, args.size) + + +def normalize_cli(args): + mean, std = datasets.load_statistics(args.directory) + print("Add this to a constants.py file:") + print( + f""" +"{args.directory}": ( + torch.tensor({mean.tolist()}), + torch.tensor({std.tolist()}), +),""" + ) + + +def num_classes_cli(args): + dataset = hierarchical.HierarchicalImageFolder( + args.directory, + ) + + print("Add this to your config .yml file to the model section:") + print(f"num_classes: {list(dataset.num_classes)}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers(help="Available commands.") + + # Preprocess + preprocess_parser = subparsers.add_parser("preprocess", help="Preprocess data.") + preprocess_parser.add_argument( + "root", + help="Root data folder. Should contain folders compressed/ and raw/train, raw/val, etc.", + ) + preprocess_parser.add_argument( + "stage", choices=datasets.Inat21.stages, help="Data stage to preprocess" + ) + preprocess_parser.add_argument( + "strategy", choices=datasets.Inat21.strategies, help="Preprocessing strategy" + ) + preprocess_parser.add_argument("size", type=int, help="Image size in pixels") + preprocess_parser.set_defaults(func=preprocess_cli) + + # Normalize + normalize_parser = subparsers.add_parser( + "normalize", help="Measure mean and std of dataset." + ) + normalize_parser.add_argument("directory", help="Data folder") + normalize_parser.set_defaults(func=normalize_cli) + + # Number of classes + num_classes_parser = subparsers.add_parser( + "num-classes", help="Measure number of classes in dataset." + ) + num_classes_parser.add_argument("directory", help="Data folder") + num_classes_parser.set_defaults(func=num_classes_cli) + + args = parser.parse_args() + + if hasattr(args, "func"): + args.func(args) + else: + parser.print_usage() diff --git a/data/inat/datasets.py b/data/inat/datasets.py new file mode 100644 index 00000000..b0676896 --- /dev/null +++ b/data/inat/datasets.py @@ -0,0 +1,194 @@ +import concurrent.futures +import os +import pathlib +import warnings + +import cv2 +import einops +import timm.data +import torch +import torchvision +from tqdm.auto import tqdm + + +def load_statistics(directory): + """ + Need to calculate mean and std for the individual channels so we can normalize the images. 
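+
+    Accumulates per-channel sums and sums of squares in a single pass over
+    the dataset, then recovers the standard deviation via
+    var = E[x^2] - (E[x])^2.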
+ """ + dataset = timm.data.ImageDataset( + root=directory, transform=torchvision.transforms.ToTensor() + ) + channels, height, width = dataset[0][0].shape + + total = torch.zeros((channels,)) + total_squared = torch.zeros((channels,)) + + dataloader = torch.utils.data.DataLoader(dataset, batch_size=256, num_workers=32) + + for batch, _ in tqdm(dataloader): + total += einops.reduce(batch, "batch channel height width -> channel", "sum") + total_squared += einops.reduce( + torch.mul(batch, batch), "batch channel height width -> channel", "sum" + ) + + divisor = len(dataset) * width * height + + mean = total / divisor + var = total_squared / divisor - torch.mul(mean, mean) + std = torch.sqrt(var) + + return mean, std + + +class Inat21: + val_tar_gz_hash = "f6f6e0e242e3d4c9569ba56400938afc" + train_tar_gz_hash = "e0526d53c7f7b2e3167b2b43bb2690ed" + train_mini_tar_gz_hash = "db6ed8330e634445efc8fec83ae81442" + strategies = ("resize", "pad") + stages = ("train", "val", "train_mini") + + def __init__(self, root: str, stage: str, strategy: str, size: int): + self.root = root + self.stage = stage + self._check_stage(stage) + + self.strategy = strategy + self._check_strategy(strategy) + + self.size = size + self._check_size(size) + + @property + def directory(self) -> str: + return os.path.join(self.root, f"{self.strategy}-{self.size}", self.stage) + + @property + def ffcv(self) -> str: + return os.path.join( + self.root, f"{self.stage}-{self.strategy}-{self.size}.beton" + ) + + @property + def tar_file(self): + return os.path.join(self.root, "compressed", f"{self.stage}.tar.gz") + + @property + def raw_dir(self): + return os.path.join(self.root, "raw", self.stage) + + def _check_stage(self, stage): + if stage not in self.stages: + raise ValueError(f"Stage '{stage}' must be one of {self.stages}") + + def _check_strategy(self, strategy): + if strategy not in self.strategies: + raise ValueError(f"Strategy '{strategy}' must be one of {self.strategies}") + + def _check_size(self, size): + if not isinstance(size, int): + raise ValueError(f"Size {size} must be int; not {type(int)}") + + def check(self): + # If /.finished doesn't exist, we need to preprocess. + if not os.path.isfile(finished_file_path(self.directory)): + warnings.warn( + f"Data not processed in {self.directory}! " + "You should run:\n\n" + f"\tpython -m src.data preprocess {self.root} {self.stage} {self.strategy} {self.size}" + "\n\nAnd then run this script again!" 
+ ) + raise RuntimeError(f"Data {self.directory} not pre-processed!") + + +def finished_file_path(directory): + return os.path.join(directory, ".finished") + + +def preprocess_class(cls_dir, output_dir, strategy, size): + cls = os.path.basename(cls_dir) + output_dir = os.path.join(output_dir, cls) + os.makedirs(output_dir, exist_ok=True) + + with os.scandir(cls_dir) as entries: + for entry in entries: + if not entry.is_file(): + continue + + im = cv2.imread(entry.path) + im = preprocess_image(im, strategy, size) + output_path = os.path.join(output_dir, entry.name) + if not cv2.imwrite(output_path, im): + raise RuntimeError(output_path) + + +def preprocess_image(im, strategy, size): + if strategy == "resize": + return cv2.resize(im, (size, size), interpolation=cv2.INTER_LINEAR) + elif strategy == "pad": + # https://stackoverflow.com/questions/43391205/add-padding-to-images-to-get-them-into-the-same-shape + raise NotImplementedError() + else: + raise NotImplementedError() + + +def parent_of(path: str): + return pathlib.Path(path).parents[0] + + +def preprocess_dataset(root: str, stage: str, strategy: str, size: int) -> None: + inat = Inat21(root, stage, strategy, size) + + err_msg = ( + f"Can't prepare data for stage {stage}, strategy {strategy} and size {size}." + ) + + # 1. If the directory does not exist, ask the user to fix that for us. + if not os.path.isdir(inat.raw_dir): + # Check that the tar exists + if not os.path.isfile(inat.tar_file): + warn_msg = f"Please download the appropriate .tar.gz file to {root} for stage {stage}." + if "raw" in root: + warn_msg += f"\n\nYour root path should contain a 'raw' directory; did you mean to use {parent_of(root)}?\n" + elif "compressed" in root: + warn_msg += f"\n\nYour root path should contain a 'compressed' directory; did you mean to use {parent_of(root)}?\n" + + warnings.warn(warn_msg) + + raise RuntimeError(err_msg) + else: + warnings.warn( + f"Please untar {inat.tar_file} in {root}. Probably need to run 'cd {root}; tar -xvf {inat.tar_file}" + ) + raise RuntimeError(err_msg) + + # 2. Now that we know the raw directory exists, we need to convert it + # to a processed directory + out_path = os.path.join(root, inat.directory) + + # 3. Make sure the directory exists + if not os.path.isdir(out_path): + os.makedirs(out_path) + + # 4. For all raw files, process and save to directory. We do this with + # a process pool because it is both I/O (read/write) and CPU (image processing) + # bound. + with os.scandir(inat.raw_dir) as entries: + directories = [entry.path for entry in entries if entry.is_dir()] + # print(directories) + print(f"Found {len(directories)} directories to preprocess.") + + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [ + executor.submit(preprocess_class, directory, out_path, strategy, size) + for directory in tqdm(directories) + ] + + print(f"Submitted {len(futures)} jobs to executor.") + + for future in tqdm( + concurrent.futures.as_completed(futures), total=len(futures) + ): + future.result() + + # 5. 
Save a sentinel file called .finished + open(finished_file_path(out_path), "w").close() diff --git a/data/zipreader.py b/data/zipreader.py index 060bc46a..1babf75e 100644 --- a/data/zipreader.py +++ b/data/zipreader.py @@ -5,23 +5,24 @@ # Written by Ze Liu # -------------------------------------------------------- +import io import os import zipfile -import io + import numpy as np -from PIL import Image -from PIL import ImageFile +from PIL import Image, ImageFile ImageFile.LOAD_TRUNCATED_IMAGES = True def is_zip_path(img_or_path): """judge if this is a zip path""" - return '.zip@' in img_or_path + return ".zip@" in img_or_path class ZipReader(object): """A class to read zipped files""" + zip_bank = dict() def __init__(self): @@ -31,18 +32,20 @@ def __init__(self): def get_zipfile(path): zip_bank = ZipReader.zip_bank if path not in zip_bank: - zfile = zipfile.ZipFile(path, 'r') + zfile = zipfile.ZipFile(path, "r") zip_bank[path] = zfile return zip_bank[path] @staticmethod def split_zip_style_path(path): - pos_at = path.index('@') - assert pos_at != -1, "character '@' is not found from the given path '%s'" % path - - zip_path = path[0: pos_at] - folder_path = path[pos_at + 1:] - folder_path = str.strip(folder_path, '/') + pos_at = path.index("@") + assert pos_at != -1, ( + "character '@' is not found from the given path '%s'" % path + ) + + zip_path = path[0:pos_at] + folder_path = path[pos_at + 1 :] + folder_path = str.strip(folder_path, "/") return zip_path, folder_path @staticmethod @@ -52,33 +55,37 @@ def list_folder(path): zfile = ZipReader.get_zipfile(zip_path) folder_list = [] for file_foler_name in zfile.namelist(): - file_foler_name = str.strip(file_foler_name, '/') - if file_foler_name.startswith(folder_path) and \ - len(os.path.splitext(file_foler_name)[-1]) == 0 and \ - file_foler_name != folder_path: + file_foler_name = str.strip(file_foler_name, "/") + if ( + file_foler_name.startswith(folder_path) + and len(os.path.splitext(file_foler_name)[-1]) == 0 + and file_foler_name != folder_path + ): if len(folder_path) == 0: folder_list.append(file_foler_name) else: - folder_list.append(file_foler_name[len(folder_path) + 1:]) + folder_list.append(file_foler_name[len(folder_path) + 1 :]) return folder_list @staticmethod def list_files(path, extension=None): if extension is None: - extension = ['.*'] + extension = [".*"] zip_path, folder_path = ZipReader.split_zip_style_path(path) zfile = ZipReader.get_zipfile(zip_path) file_lists = [] for file_foler_name in zfile.namelist(): - file_foler_name = str.strip(file_foler_name, '/') - if file_foler_name.startswith(folder_path) and \ - str.lower(os.path.splitext(file_foler_name)[-1]) in extension: + file_foler_name = str.strip(file_foler_name, "/") + if ( + file_foler_name.startswith(folder_path) + and str.lower(os.path.splitext(file_foler_name)[-1]) in extension + ): if len(folder_path) == 0: file_lists.append(file_foler_name) else: - file_lists.append(file_foler_name[len(folder_path) + 1:]) + file_lists.append(file_foler_name[len(folder_path) + 1 :]) return file_lists diff --git a/docs/cli-design.md b/docs/cli-design.md new file mode 100644 index 00000000..4520f631 --- /dev/null +++ b/docs/cli-design.md @@ -0,0 +1,15 @@ +# CLI Design + +Options that are different per run: + +1. Config file +2. Device batch size (depends on image size and local GPUs) +3. Whether to debug or not (1 process or default number of processes) +4. Master port (with a default of 12345) +5. Data path +6. 
Number of processes (needs to match number of GPUs) + +Options that are different per machine: + +1. Output directory +2. Virtual environment location diff --git a/docs/experiments/bold-banana.md b/docs/experiments/bold-banana.md new file mode 100644 index 00000000..a0c8c8ce --- /dev/null +++ b/docs/experiments/bold-banana.md @@ -0,0 +1,20 @@ +# Bold Banana + +This experiment is the default swin-v2-base applied to iNat 21 192x192. +We train for 90 epochs at 192x192, then tune for 30 epochs at 256x256. + +```yaml +configs: +- configs/hierarchical-vision-project/bold-banana-192.yaml +codename: bold-banana +``` + +## Log + +I initialized training on strawberry0 on 4x A6000 servers. + +I decided to use the A6000 servers for 256x256 tuning, so I am moving the latest checkpoint to S3, then cloning it back to an 8x V100 server to finish training. +I am storing the 40th checkpoint on S3 as funky-banana-192-epoch40.pth. +It is now running as funky-banana-192 on 8x V100. + +It had a bad mean/std, so I am re-christening it as bold-banana on 4x A6000 diff --git a/docs/experiments/fuzzy-fig.md b/docs/experiments/fuzzy-fig.md new file mode 100644 index 00000000..476d2810 --- /dev/null +++ b/docs/experiments/fuzzy-fig.md @@ -0,0 +1,14 @@ +# Fuzzy Fig + +This experiment is the default swin-v2-base applied to iNat 21 192x192 with 1/4 the default learning rate (1.25e-4). +We train for 90 epochs at 192x192, then tune for 30 epochs at 256x256. + +```yaml +configs: +- configs/hierarchical-vision-project/fuzzy-fig-192.yaml +codename: fuzzy-fig +``` + +## Log + +Training is running on 8x V100 as fuzzy-fig-192. diff --git a/docs/experiments/groovy-grape.md b/docs/experiments/groovy-grape.md new file mode 100644 index 00000000..c794d8a7 --- /dev/null +++ b/docs/experiments/groovy-grape.md @@ -0,0 +1,19 @@ +# Groovy Grape + +This experiment trains swin-v2-base from scratch on iNat21 using the multitask objective using 1/4 of the default learning rate, which works out to 1.25e-4. +It also does 90 epochs at 192. + +```yaml +configs: +- configs/hierarchical-vision-project/groovy-grape-192.yaml +codename: groovy-grape +``` + +## Log + +This model trained the first 90 on 8x V100, and did 8/30 epochs at 256 on 8x V100. +I am storing the 8th checkpoint on S3 as groovy-grape-256-epoch8.pth. +Now it is stored at /local/scratch/stevens.994/hierarchical-vision/groovy-grape-256/v0 +It was originally haunted-broomstick on wandb, but is now groovy-grape-256. + +I messed it up by switching std/mean, so I am restarting this run on 8xV100. diff --git a/docs/experiments/index.md b/docs/experiments/index.md new file mode 100644 index 00000000..e79da1f6 --- /dev/null +++ b/docs/experiments/index.md @@ -0,0 +1,13 @@ +# Experiment Index + +This file has a lightweight index of all the experiments I'm running so you can go to the file if necessary. 
+ +[bold-banana](bold-banana.md): Default swin-v2-base on iNat21 + +[fuzzy-fig](fuzzy-fig.md): Default swin-v2-base on iNat21 with 1/4 LR + +[outrageous-orange](outrageous-orange.md): Hierarchical swin-v2-base on iNat21 + +[sweet-strawberry](sweet-strawberry.md): Hierarchical swin-v2-base on iNat21 with 1/2 LR + +[groovy-grape](groovy-grape.md): Hierarchical swin-v2-base on iNat21 with 1/4 LR diff --git a/docs/experiments/outrageous-orange.md b/docs/experiments/outrageous-orange.md new file mode 100644 index 00000000..89fcc20c --- /dev/null +++ b/docs/experiments/outrageous-orange.md @@ -0,0 +1,24 @@ +# Outrageous Orange + +This experiment is the swin-v2-base with hierarchical multitask objective applied to iNat 21 192x192. +We train for 90 epochs at 192x192. + +```yaml +configs: +- configs/hierarchical-vision-project/outrageous-orange-192.yaml +codename: outrageous-orange +``` + +## Log + +I initialized training on strawberry0 on 4x A6000 servers. + +I decided to use the A6000 servers for 256x256 tuning, so I am moving the latest checkpoint to S3, then cloning it back to an 8x V100 server to finish training. +I am storing the 36th checkpoint on S3 as `outrageous-orange-192-epoch36.pth`. +It is now running as `outrageous-orange-192` on 8x V100. + +I am going to stop outrageous-orange-192 (and never run outrageous-orange-256) because it is underperforming the other hierarchical runs and I don't want to waste compute on an obviously bad run. I'll still upload the checkpoints to S3 to save them. +The latest checkpoint is at `s3://imageomics-models/outrageous-orange-192-epoch64.pth` + +This is a bad run (std/mean switching issue). +I am going to restart it on 4x A60000 because I expect it to underperform the other run. diff --git a/docs/experiments/sweet-strawberry.md b/docs/experiments/sweet-strawberry.md new file mode 100644 index 00000000..7fda07e4 --- /dev/null +++ b/docs/experiments/sweet-strawberry.md @@ -0,0 +1,20 @@ +# Sweet Strawberry + +This experiment trains swin-v2-base from scratch on iNat21 using the multitask objective using 1/2 of the default learning rate, which works out to 2.5e-4. +It also does 90 epochs at 192. + +```yaml +configs: +- configs/hierarchical-vision-project/sweet-strawberry-192.yaml +codename: sweet-strawberry +``` + +## Log + +This model trained the first 90 on 8x V100, and did 15/30 epochs at 256 on 8x V100. +I am storing the 15th checkpoint on S3 as sweet-strawberry-256-epoch15.pth. +Now it is stored at /local/scratch/stevens.994/hierarchical-vision/sweet-strawberry-256/v0 +It was originally unearthly-moon on wandb, but is now sweet-strawberry-256. +It is running on 4x A6000. + +I want to check that I didn't mix up sweet-strawberry-192 and groovy-grape-192, so I am going to check their validation top 1 accuracy on iNat21 192x192 using their final checkpoints. diff --git a/generate_configs/generate.py b/generate_configs/generate.py new file mode 100644 index 00000000..95c047b7 --- /dev/null +++ b/generate_configs/generate.py @@ -0,0 +1,128 @@ +import argparse +import os +import pathlib +import yaml +import tomli +import tomli_w +from tqdm.auto import tqdm +# import sys +# sys.path.append("..") + +# from .. import config, logging, templating, util + +import logger, config, utils +logger = logger.init("experiments.generate") + + +from . import templating #, utils + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="General .toml files from template .toml files. 
I kept all my templates in experiments/templates and my generated experiment configs in experiments/generated, which I then removed from version control.", + ) + parser.add_argument( + "--strategy", + type=str, + help="Strategy to use to combine multiple lists in a template.", + default="grid", + choices=["grid", "paired", "random"], + ) + parser.add_argument( + "--count", + type=int, + help="Number of configs to generate when using --strategy random. Required.", + default=-1, + ) + parser.add_argument( + "--no-expand", + type=str, + nargs="+", + default=[], + help=".-separated fields to not expand", + ) + parser.add_argument( + "--prefix", + type=str, + default="generated-", + help="Prefix to add to generated templates", + ) + parser.add_argument( + "templates", + nargs="+", + type=str, + help="Template .toml files or directories containing template .toml files.", + ) + parser.add_argument( + "output", + type=str, + help="Output directory to write the generated .toml files to.", + ) + return parser.parse_args() + + +def generate(args: argparse.Namespace) -> None: + + strategy = templating.Strategy.new(args.strategy) + + count = args.count + + if strategy is templating.Strategy.random: + assert count > 0, "Need to include --count!" + + for template_toml in utils.files_with_extension(args.templates, ".yaml"): + + with open(template_toml, "rb") as template_file: + try: + template_dict = yaml.safe_load(template_file) + except tomli.TOMLDecodeError as err: #ToDo have to replace this for yaml file + logger.warning( + "Error parsing template file. [file: %s, err: %s]", + template_toml, + err, + ) + continue + + + template_name = pathlib.Path(template_toml).stem + + + logger.info("Opened template file. [file: %s]", template_toml) + # print ("\n\n") + # print (template_dict) + + experiment_dicts = templating.generate( + template_dict, strategy, count=count, no_expand=set(args.no_expand) + ) + + + logger.info( + "Loaded experiment dictionaries. [count: %s]", len(experiment_dicts) + ) + + for i, experiment_dict in enumerate(tqdm(experiment_dicts)): + outputfile=experiment_dict["EXPERIMENT"]["NAME"] + experiment_dict["EXPERIMENT"]["NAME"]=os.path.join(outputfile, str(i)) + filename = f"{args.prefix}{template_name}-{i}.yaml" + filepath = os.path.join(args.output, filename) + with open(filepath, "w") as file: + yaml.dump(experiment_dict, file) + + + + #exit() + # Verifies that the configs are correctly loaded. + #list(config.load_configs(filepath)) + + + +def main() -> None: + args = parse_args() + + # print (args.no_expand) + + generate(args) + + +if __name__ == "__main__": + main() diff --git a/generate_configs/templating.py b/generate_configs/templating.py new file mode 100644 index 00000000..7d4836e8 --- /dev/null +++ b/generate_configs/templating.py @@ -0,0 +1,369 @@ +import copy +import dataclasses +import enum +import random +import re +import typing +from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union #Protocol +from typing_extensions import Protocol +import orjson +import preface +import scipy.stats + +import logger, config + +# from . 
import config, logging + +Primitive = Union[str, int, float, bool] +StringDict = Dict[str, object] + +logger = logger.init(__name__) + + +TEMPLATE_PATTERN = re.compile(r"\((\d+)\.\.\.(\d+)\)") + +CONTINUOUS_DISTRIBUTION_SUFFIX = "__random-sample-distribution" + + +class Strategy(preface.SumType): + grid = enum.auto() + paired = enum.auto() + random = enum.auto() + + +@dataclasses.dataclass(frozen=True) +class DiscreteDistribution: + values: Sequence[Primitive] + + +class HasRandomVariates(Protocol): + def rvs(self, **kwargs: Dict[str, Any]) -> float: + ... + + +# def parse_dist( +# raw_dist: object, raw_params: object +# ) -> Tuple[HasRandomVariates, Dict[str, Any]]: +# try: +# # We type ignore because we catch any type errors +# parameter_list = list(raw_params) # type: ignore +# except TypeError: +# raise ValueError(f"{raw_params} should be a sequence!") + +# dist_choices = typing.get_args(config.Distribution) +# if raw_dist not in dist_choices: +# raise ValueError(f"{raw_dist} must be one of {', '.join(dist_choices)}") +# raw_dist = typing.cast(config.Distribution, raw_dist) + +# params = {} + +# if raw_dist == "uniform": +# dist = scipy.stats.uniform +# assert len(parameter_list) == 2 + +# low, high = parameter_list +# assert low < high +# params = {"loc": low, "scale": high - low} +# elif raw_dist == "normal": +# dist = scipy.stats.norm +# assert len(parameter_list) == 2 + +# mean, std = parameter_list +# params = {"loc": mean, "scale": std} +# elif raw_dist == "loguniform": +# dist = scipy.stats.loguniform +# assert len(parameter_list) == 2 + +# low, high = parameter_list +# assert low < high +# params = {"a": low, "b": high} +# else: +# preface.never(raw_dist) + +# return dist, params + + +@dataclasses.dataclass(frozen=True) +class ContinuousDistribution: + fn: HasRandomVariates + params: Dict[str, Any] + + @classmethod + def parse(cls, value: object) -> Optional["ContinuousDistribution"]: + assert isinstance(value, dict) + + assert "dist" in value + assert "params" in value + + dist, params = parse_dist(value["dist"], value["params"]) + + return cls(dist, params) + + +@dataclasses.dataclass(frozen=True) +class Hole: + path: str + distribution: Union[DiscreteDistribution, ContinuousDistribution] + + def __post_init__(self) -> None: + assert isinstance(self.distribution, DiscreteDistribution) or isinstance( + self.distribution, ContinuousDistribution + ), f"self.distribution ({self.distribution}) is {type(self.distribution)}!" + + +def makehole(key: str, value: object) -> Optional[Hole]: + if isinstance(value, list): + values: List[Union[str, int]] = [] + + for item in value: + if isinstance(item, str): + values += expand_numbers(item) + else: + values.append(item) + + return Hole(key, DiscreteDistribution(values)) + + if isinstance(value, str): + numbers = expand_numbers(value) + + if len(numbers) > 1: + return Hole(key, DiscreteDistribution(numbers)) + + if key.endswith(CONTINUOUS_DISTRIBUTION_SUFFIX): + key = key.removesuffix(CONTINUOUS_DISTRIBUTION_SUFFIX) + + dist = ContinuousDistribution.parse(value) + assert dist + + return Hole(key, dist) + + return None + + +def find_holes(template: StringDict) -> List[Hole]: + """ + Arguments: + template (StringDict): Template with potential holes + no_expand (Set[str]): Fields to not treat as holes, even if we would otherwise. + """ + holes = [] + + # Make it a list so we can modify template during iteration. 
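+    # A list value (or a "(a...b)" range string, or a key ending in
+    # "__random-sample-distribution") becomes a Hole; nested dicts are
+    # searched recursively and their hole paths are joined with ".".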
+ for key, value in list(template.items()): + # Have to check makehole first because value might be a dict, but + # if key ends with CONTINUOUS_DISTRIBUTION_SUFFIX, then we want to + # parse that dict as a continuous distribution + hole = makehole(key, value) + if hole: + holes.append(hole) + template.pop(key) + elif isinstance(value, dict): + holes.extend( + Hole(f"{key}.{hole.path}", hole.distribution) + for hole in find_holes(value) + ) + + return holes + + +def sort_by_json(dicts: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + return list( + sorted(dicts, key=lambda d: orjson.dumps(d, option=orjson.OPT_SORT_KEYS)) + ) + + +# region FILLING + + +def grid_fill(filled: StringDict, holes: List[Hole]) -> List[StringDict]: + if not holes: + return [filled] + + experiments = [] + first, rest = holes[0], holes[1:] + + if not isinstance(first.distribution, DiscreteDistribution): + raise RuntimeError( + f"Must sample from DiscreteDistribution with strategy grid, not {type(first.distribution)}!" + ) + + for value in first.distribution.values: + + experiment = copy.deepcopy(filled) + + preface.dict.set(experiment, first.path, value) + experiments.extend(grid_fill(experiment, rest)) + + + return sort_by_json(experiments) + + +def paired_fill(holes: List[Hole]) -> List[StringDict]: + experiments = [] + + assert all(isinstance(hole.distribution, DiscreteDistribution) for hole in holes) + + # We can type ignore because we assert that all distributions are discrete + shortest_hole = min(holes, key=lambda h: len(h.distribution.values)) # type: ignore + + assert isinstance(shortest_hole.distribution, DiscreteDistribution) + + for i in range(len(shortest_hole.distribution.values)): + experiment: StringDict = {} + for hole in holes: + assert isinstance(hole.distribution, DiscreteDistribution) + preface.dict.set(experiment, hole.path, hole.distribution.values[i]) + + experiments.append(experiment) + + return sort_by_json(experiments) + + +def random_fill(holes: List[Hole], count: int) -> List[StringDict]: + experiments = [] + + for _ in range(count): + experiment: StringDict = {} + for hole in holes: + if isinstance(hole.distribution, DiscreteDistribution): + preface.dict.set( + experiment, hole.path, random.choice(hole.distribution.values) + ) + elif isinstance(hole.distribution, ContinuousDistribution): + preface.dict.set( + experiment, + hole.path, + float(hole.distribution.fn.rvs(**hole.distribution.params)), + ) + else: + preface.never(hole.distribution) + + experiments.append(experiment) + + return experiments + + +# endregion + + +def generate( + template: StringDict, + strategy: Strategy, + count: int = 0, + *, + no_expand: Optional[Set[str]] = None, +) -> List[StringDict]: + """ + Turns a template (a dictionary with lists as values) into a list of experiments (dictionaries with no lists). + + If strategy is Strategy.Grid, returns an experiment for each possible combination of each value in each list. Strategy.Paired returns an experiment for sequential pair of values in each list. + + An example makes this clearer. If the template had 3 lists with lengths 5, 4, 10, respectively: + + Grid would return 5 x 4 x 10 = 200 experiments. + + Paired would return min(5, 4, 10) = 4 experiments + + Random would return experiments + + Experiments are returned sorted by the JSON value. 
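+
+    As a concrete sketch (the keys are hypothetical), the template
+    {"TRAIN": {"EPOCHS": [100, 200], "BASE_LR": [0.1, 0.01]}} expands under
+    grid into 4 experiments (every EPOCHS x BASE_LR combination), under
+    paired into 2 experiments ((100, 0.1) and (200, 0.01)), and under random
+    into `count` experiments whose hole values are sampled independently.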
+ """ + ignored = {} + if no_expand is not None: + # print (no_expand) + # print (template) + + # for field in no_expand: + # print (field) + # exit() + + ignored = {field: preface.dict.pop(template, field) for field in no_expand} + # print (template) + # exit() + + template = copy.deepcopy(template) + # print ("\n\n\n") + # print (template) + # print ("\n\n\n") + holes = find_holes(template) + # print (holes) + + if not holes: + # We can return this directly because there are no holes. + return [template] + + + logger.info("Found all holes. [count: %d]", len(holes)) + + experiments: List[StringDict] = [] + + if strategy is Strategy.grid: + filled = grid_fill({}, holes) + elif strategy is Strategy.paired: + filled = paired_fill(holes) + elif strategy is Strategy.random: + filled = random_fill(holes, count) + else: + preface.never(strategy) + + logger.info("Filled all holes. [count: %d]", len(filled)) + + without_holes: StringDict = {} + for key, value in preface.dict.flattened(template).items(): + if makehole(key, value): + continue + + without_holes[key] = value + + for key, value in ignored.items(): + without_holes[key] = value + + experiments = [preface.dict.merge(exp, without_holes) for exp in filled] + + logger.info("Merged all experiment configs. [count: %d]", len(experiments)) + + return sort_by_json(experiments) + + +def expand_numbers(obj: str) -> Union[List[str], List[int]]: + """ + Given a string that potentially has the digits expander: + + "(0...34)" + + Returns a list of strings with each value in the range (inclusive, exclusive) + + ["0", "1", ..., "33"] + """ + splits = re.split(TEMPLATE_PATTERN, obj, maxsplit=1) + + # One split means the pattern isn't present. + if len(splits) == 1: + return [obj] + + if len(splits) != 4: + raise ValueError("Can't parse strings with more than one ( ... )") + + front, start_s, end_s, back = splits + + front_list = expand_numbers(front) + back_list = expand_numbers(back) + + start = int(start_s) + end = int(end_s) + + if start < end: + spread = range(start, end) + else: + spread = range(start, end, -1) + + expanded = [] + for f in front_list: + for i in spread: + for b in back_list: + expanded.append(f"{f}{i}{b}") + + try: + return [int(i) for i in expanded] + except ValueError: + return expanded diff --git a/hierarchical.py b/hierarchical.py new file mode 100644 index 00000000..bb549307 --- /dev/null +++ b/hierarchical.py @@ -0,0 +1,79 @@ +import einops +import torch + + +def accuracy(output, target, topk=(1,), hierarchy_level=-1): + """ + Computes the accuracy over the k top predictions for the specified values of k + + Copied from rwightman/pytorch-image-models/timm/utils/metrics.py and modified + to work with hierarchical outputs as well. + + When the output is hierarchical, only returns the accuracy for `hierarchy_level` + (default -1, which is the fine-grained level). 
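+
+    Expected shapes (matching the code below): `output` is either a single
+    (batch, num_classes) tensor or a list of one tensor per hierarchy level,
+    each of shape (batch, num_classes_at_that_level); `target` is either
+    (batch,) or (batch, num_levels), in which case only the last column
+    (the finest level) is used.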
+ """ + output_levels = 1 + if isinstance(output, list): + output_levels = len(output) + output = output[-1] + print (output_levels) + # print (output) + + batch_size = output.size(0) + + # Target might have multiple levels because of the hierarchy + if target.squeeze().ndim == 2: + assert target.squeeze().shape == (batch_size, output_levels) + target = target[:, -1] + + maxk = min(max(topk), output.size(1)) + _, pred = output.topk(maxk, dim=1, largest=True, sorted=True) + pred = pred.t() + correct = pred.eq(target.reshape(1, -1).expand_as(pred)) + return [ + correct[: min(k, maxk)].reshape(-1).float().sum(0) * 100.0 / batch_size + for k in topk + ] + + +class FineGrainedCrossEntropyLoss(torch.nn.CrossEntropyLoss): + """ + A cross-entropy used with hierarchical inputs and targets and only + looks at the finest-grained tier (the last level). + """ + + def forward(self, inputs, targets): + fine_grained_inputs = inputs[-1] + fine_grained_targets = targets[:, -1] + return super().forward(fine_grained_inputs, fine_grained_targets) + + +class HierarchicalCrossEntropyLoss(torch.nn.CrossEntropyLoss): + def __init__(self, *args, coeffs=(1.0,), **kwargs): + super().__init__(*args, **kwargs) + + if isinstance(coeffs, torch.Tensor): + coeffs = coeffs.clone().detach().type(torch.float) + else: + coeffs = torch.tensor(coeffs, dtype=torch.float) + + self.register_buffer("coeffs", coeffs) + + def forward(self, inputs, targets): + if not isinstance(targets, list): + targets = einops.rearrange(targets, "batch tiers -> tiers batch") + + assert ( + len(inputs) == len(targets) == len(self.coeffs) + ), f"{len(inputs)} != {len(targets)} != {len(self.coeffs)}" + + losses = torch.stack( + [ + # Need to specify arguments to super() because of some a bug + # with super() in list comprehensions/generators (unclear) + super(HierarchicalCrossEntropyLoss, self).forward(input, target) + for input, target in zip(inputs, targets) + ] + ) + + return torch.dot(self.coeffs, losses) diff --git a/kernels/window_process/setup.py b/kernels/window_process/setup.py index c78526d0..c6cb5ea6 100644 --- a/kernels/window_process/setup.py +++ b/kernels/window_process/setup.py @@ -1,12 +1,16 @@ from setuptools import setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension - -setup(name='swin_window_process', +setup( + name="swin_window_process", ext_modules=[ - CUDAExtension('swin_window_process', [ - 'swin_window_process.cpp', - 'swin_window_process_kernel.cu', - ]) + CUDAExtension( + "swin_window_process", + [ + "swin_window_process.cpp", + "swin_window_process_kernel.cu", + ], + ) ], - cmdclass={'build_ext': BuildExtension}) \ No newline at end of file + cmdclass={"build_ext": BuildExtension}, +) diff --git a/kernels/window_process/unit_test.py b/kernels/window_process/unit_test.py index 65dee566..743d7691 100644 --- a/kernels/window_process/unit_test.py +++ b/kernels/window_process/unit_test.py @@ -4,22 +4,25 @@ # Licensed under The MIT License [see LICENSE for details] # -------------------------------------------------------- -import torch -import swin_window_process import random import time import unittest +import swin_window_process +import torch + class WindowProcess(torch.autograd.Function): @staticmethod def forward(ctx, input, B, H, W, C, shift_size, window_size): - output = swin_window_process.roll_and_window_partition_forward(input, B, H, W, C, shift_size, window_size) + output = swin_window_process.roll_and_window_partition_forward( + input, B, H, W, C, shift_size, window_size + ) ctx.B = B ctx.H = H - 
ctx.W = W - ctx.C = C + ctx.W = W + ctx.C = C ctx.shift_size = shift_size ctx.window_size = window_size return output @@ -28,24 +31,28 @@ def forward(ctx, input, B, H, W, C, shift_size, window_size): def backward(ctx, grad_in): B = ctx.B H = ctx.H - W = ctx.W - C = ctx.C + W = ctx.W + C = ctx.C shift_size = ctx.shift_size window_size = ctx.window_size - grad_out = swin_window_process.roll_and_window_partition_backward(grad_in, B, H, W, C, shift_size, window_size) + grad_out = swin_window_process.roll_and_window_partition_backward( + grad_in, B, H, W, C, shift_size, window_size + ) return grad_out, None, None, None, None, None, None, None class WindowProcessReverse(torch.autograd.Function): @staticmethod def forward(ctx, input, B, H, W, C, shift_size, window_size): - output = swin_window_process.window_merge_and_roll_forward(input, B, H, W, C, shift_size, window_size) + output = swin_window_process.window_merge_and_roll_forward( + input, B, H, W, C, shift_size, window_size + ) ctx.B = B ctx.H = H - ctx.W = W - ctx.C = C + ctx.W = W + ctx.C = C ctx.shift_size = shift_size ctx.window_size = window_size @@ -55,12 +62,14 @@ def forward(ctx, input, B, H, W, C, shift_size, window_size): def backward(ctx, grad_in): B = ctx.B H = ctx.H - W = ctx.W - C = ctx.C + W = ctx.W + C = ctx.C shift_size = ctx.shift_size window_size = ctx.window_size - grad_out = swin_window_process.window_merge_and_roll_backward(grad_in, B, H, W, C, shift_size, window_size) + grad_out = swin_window_process.window_merge_and_roll_backward( + grad_in, B, H, W, C, shift_size, window_size + ) return grad_out, None, None, None, None, None, None, None @@ -74,9 +83,12 @@ def window_partition(x, window_size): """ B, H, W, C = x.shape x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + windows = ( + x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + ) return windows + def window_reverse(windows, window_size, H, W): """ Args: @@ -88,7 +100,9 @@ def window_reverse(windows, window_size, H, W): x: (B, H, W, C) """ B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = windows.view( + B, H // window_size, W // window_size, window_size, window_size, -1 + ) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) return x @@ -119,6 +133,7 @@ def copy_one_tensor(input, requires_grad=True): input1 = input.clone().detach().requires_grad_(requires_grad).cuda() return input1 + class Test_WindowProcess(unittest.TestCase): def setUp(self): self.B = 192 @@ -129,10 +144,12 @@ def setUp(self): self.window_size = 7 self.nH = self.H // self.window_size self.nW = self.W // self.window_size - + def test_roll_and_window_partition_forward(self, dtype=torch.float32): - input = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True).cuda() - + input = torch.randn( + (self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True + ).cuda() + input1 = copy_one_tensor(input, True) input2 = copy_one_tensor(input, True) @@ -140,15 +157,28 @@ def test_roll_and_window_partition_forward(self, dtype=torch.float32): # ori expected = pyt_forward(input1, self.shift_size, self.window_size) # fused kernel - fused_output = WindowProcess.apply(input2, self.B, self.H, self.W, self.C, -self.shift_size, self.window_size) - + fused_output = WindowProcess.apply( + input2, + self.B, + self.H, + 
self.W, + self.C, + -self.shift_size, + self.window_size, + ) + self.assertTrue(torch.equal(expected, fused_output)) - #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08)) - + # self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08)) + def test_roll_and_window_partition_backward(self, dtype=torch.float32): - input = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True).cuda() - d_loss_tensor = torch.randn((self.B*self.nW*self.nH, self.window_size, self.window_size, self.C), dtype=dtype).cuda() - + input = torch.randn( + (self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True + ).cuda() + d_loss_tensor = torch.randn( + (self.B * self.nW * self.nH, self.window_size, self.window_size, self.C), + dtype=dtype, + ).cuda() + input1 = copy_one_tensor(input, True) input2 = copy_one_tensor(input, True) @@ -156,64 +186,105 @@ def test_roll_and_window_partition_backward(self, dtype=torch.float32): expected = pyt_forward(input1, self.shift_size, self.window_size) expected.backward(d_loss_tensor) # fused kernel - fused_output = WindowProcess.apply(input2, self.B, self.H, self.W, self.C, -self.shift_size, self.window_size) + fused_output = WindowProcess.apply( + input2, self.B, self.H, self.W, self.C, -self.shift_size, self.window_size + ) fused_output.backward(d_loss_tensor) - + self.assertTrue(torch.equal(expected, fused_output)) - #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08)) + # self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08)) def test_window_merge_and_roll_forward(self, dtype=torch.float32): - input = torch.randn((self.B*self.nH*self.nW, self.window_size, self.window_size, self.C), dtype=dtype, requires_grad=True).cuda() - + input = torch.randn( + (self.B * self.nH * self.nW, self.window_size, self.window_size, self.C), + dtype=dtype, + requires_grad=True, + ).cuda() + input1 = copy_one_tensor(input, True) input2 = copy_one_tensor(input, True) with torch.no_grad(): # ori - expected = reverse_pyt_forward(input1, self.shift_size, self.window_size, self.H, self.W) + expected = reverse_pyt_forward( + input1, self.shift_size, self.window_size, self.H, self.W + ) # fused kernel - fused_output = WindowProcessReverse.apply(input2, self.B, self.H, self.W, self.C, self.shift_size, self.window_size) - + fused_output = WindowProcessReverse.apply( + input2, + self.B, + self.H, + self.W, + self.C, + self.shift_size, + self.window_size, + ) + self.assertTrue(torch.equal(expected, fused_output)) - #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08)) - + # self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08)) def test_window_merge_and_roll_backward(self, dtype=torch.float32): - input = torch.randn((self.B*self.nH*self.nW, self.window_size, self.window_size, self.C), dtype=dtype, requires_grad=True).cuda() - d_loss_tensor = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True).cuda() - + input = torch.randn( + (self.B * self.nH * self.nW, self.window_size, self.window_size, self.C), + dtype=dtype, + requires_grad=True, + ).cuda() + d_loss_tensor = torch.randn( + (self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True + ).cuda() + input1 = copy_one_tensor(input, True) input2 = copy_one_tensor(input, True) # ori - expected = reverse_pyt_forward(input1, self.shift_size, self.window_size, self.H, self.W) + expected = reverse_pyt_forward( + input1, self.shift_size, self.window_size, 
self.H, self.W + ) expected.backward(d_loss_tensor) # fused kernel - fused_output = WindowProcessReverse.apply(input2, self.B, self.H, self.W, self.C, self.shift_size, self.window_size) + fused_output = WindowProcessReverse.apply( + input2, self.B, self.H, self.W, self.C, self.shift_size, self.window_size + ) fused_output.backward(d_loss_tensor) - + self.assertTrue(torch.equal(expected, fused_output)) - #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08)) + # self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08)) def test_forward_backward_speed(self, dtype=torch.float32, times=1000): - input = torch.randn((self.B*self.nH*self.nW, self.window_size, self.window_size, self.C), dtype=dtype, requires_grad=True).cuda() - d_loss_tensor = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True).cuda() - + input = torch.randn( + (self.B * self.nH * self.nW, self.window_size, self.window_size, self.C), + dtype=dtype, + requires_grad=True, + ).cuda() + d_loss_tensor = torch.randn( + (self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True + ).cuda() + input1 = copy_one_tensor(input, True) input2 = copy_one_tensor(input, True) # SwinTransformer official def run_pyt(t=1000): for _ in range(t): - expected = reverse_pyt_forward(input1, self.shift_size, self.window_size, self.H, self.W) + expected = reverse_pyt_forward( + input1, self.shift_size, self.window_size, self.H, self.W + ) expected.backward(d_loss_tensor) # my op def run_fusedop(t=1000): for _ in range(t): - fused_output = WindowProcessReverse.apply(input2, self.B, self.H, self.W, self.C, self.shift_size, self.window_size) + fused_output = WindowProcessReverse.apply( + input2, + self.B, + self.H, + self.W, + self.C, + self.shift_size, + self.window_size, + ) fused_output.backward(d_loss_tensor) - + torch.cuda.synchronize() t1 = time.time() run_pyt(t=times) @@ -224,10 +295,10 @@ def run_fusedop(t=1000): t3 = time.time() self.assertTrue((t3 - t2) < (t2 - t1)) - print('Run {} times'.format(times)) - print('Original time cost: {}'.format(t2 - t1)) - print('Fused op time cost: {}'.format(t3 - t2)) - + print("Run {} times".format(times)) + print("Original time cost: {}".format(t2 - t1)) + print("Fused op time cost: {}".format(t3 - t2)) + def test_roll_and_window_partition_forward_fp16(self, dtype=torch.float16): self.test_roll_and_window_partition_forward(dtype=dtype) @@ -236,7 +307,7 @@ def test_roll_and_window_partition_backward_fp16(self, dtype=torch.float16): def test_window_merge_and_roll_forward_fp16(self, dtype=torch.float16): self.test_window_merge_and_roll_forward(dtype=dtype) - + def test_window_merge_and_roll_backward_fp16(self, dtype=torch.float16): self.test_window_merge_and_roll_backward(dtype=dtype) @@ -244,7 +315,7 @@ def test_forward_backward_speed_fp16(self, dtype=torch.float16, times=1000): self.test_forward_backward_speed(dtype=dtype, times=times) -if __name__ == '__main__': - print('Pass only two tensors are exactly the same (using torch.equal).\n') +if __name__ == "__main__": + print("Pass only two tensors are exactly the same (using torch.equal).\n") torch.manual_seed(0) unittest.main(verbosity=2) diff --git a/kernels/window_process/window_process.py b/kernels/window_process/window_process.py index ee43e9e9..d482d4f7 100644 --- a/kernels/window_process/window_process.py +++ b/kernels/window_process/window_process.py @@ -4,19 +4,21 @@ # Licensed under The MIT License [see LICENSE for details] # 
-------------------------------------------------------- -import torch import swin_window_process +import torch class WindowProcess(torch.autograd.Function): @staticmethod def forward(ctx, input, B, H, W, C, shift_size, window_size): - output = swin_window_process.roll_and_window_partition_forward(input, B, H, W, C, shift_size, window_size) + output = swin_window_process.roll_and_window_partition_forward( + input, B, H, W, C, shift_size, window_size + ) ctx.B = B ctx.H = H - ctx.W = W - ctx.C = C + ctx.W = W + ctx.C = C ctx.shift_size = shift_size ctx.window_size = window_size return output @@ -25,24 +27,28 @@ def forward(ctx, input, B, H, W, C, shift_size, window_size): def backward(ctx, grad_in): B = ctx.B H = ctx.H - W = ctx.W - C = ctx.C + W = ctx.W + C = ctx.C shift_size = ctx.shift_size window_size = ctx.window_size - grad_out = swin_window_process.roll_and_window_partition_backward(grad_in, B, H, W, C, shift_size, window_size) + grad_out = swin_window_process.roll_and_window_partition_backward( + grad_in, B, H, W, C, shift_size, window_size + ) return grad_out, None, None, None, None, None, None, None class WindowProcessReverse(torch.autograd.Function): @staticmethod def forward(ctx, input, B, H, W, C, shift_size, window_size): - output = swin_window_process.window_merge_and_roll_forward(input, B, H, W, C, shift_size, window_size) + output = swin_window_process.window_merge_and_roll_forward( + input, B, H, W, C, shift_size, window_size + ) ctx.B = B ctx.H = H - ctx.W = W - ctx.C = C + ctx.W = W + ctx.C = C ctx.shift_size = shift_size ctx.window_size = window_size @@ -52,12 +58,14 @@ def forward(ctx, input, B, H, W, C, shift_size, window_size): def backward(ctx, grad_in): B = ctx.B H = ctx.H - W = ctx.W - C = ctx.C + W = ctx.W + C = ctx.C shift_size = ctx.shift_size window_size = ctx.window_size - #grad_out = ctx.saved_tensors[0] - #grad_out = torch.zeros((B, H, W, C), dtype=dtype).cuda() - grad_out = swin_window_process.window_merge_and_roll_backward(grad_in, B, H, W, C, shift_size, window_size) + # grad_out = ctx.saved_tensors[0] + # grad_out = torch.zeros((B, H, W, C), dtype=dtype).cuda() + grad_out = swin_window_process.window_merge_and_roll_backward( + grad_in, B, H, W, C, shift_size, window_size + ) return grad_out, None, None, None, None, None, None, None diff --git a/logger.py b/logger.py index a066e55b..14a7d6d1 100644 --- a/logger.py +++ b/logger.py @@ -5,37 +5,149 @@ # Written by Ze Liu # -------------------------------------------------------- +import functools +import logging import os import sys -import logging -import functools + +import torch from termcolor import colored +from torch.utils.tensorboard import SummaryWriter + +import wandb @functools.lru_cache() -def create_logger(output_dir, dist_rank=0, name=''): +def create_logger(output_dir, dist_rank=0, name=""): # create logger logger = logging.getLogger(name) logger.setLevel(logging.DEBUG) logger.propagate = False # create formatter - fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s' - color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \ - colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s' + fmt = "[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s" + color_fmt = ( + colored("[%(asctime)s %(name)s]", "green") + + colored("(%(filename)s %(lineno)d)", "yellow") + + ": %(levelname)s %(message)s" + ) # create console handlers for master process if dist_rank == 0: console_handler = logging.StreamHandler(sys.stdout) 
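+        # Only rank 0 gets a console handler, so multi-GPU runs do not print
+        # duplicate lines to stdout; every rank still writes to its own
+        # log_rank{dist_rank}.txt file handler created below.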
console_handler.setLevel(logging.DEBUG) console_handler.setFormatter( - logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S')) + logging.Formatter(fmt=color_fmt, datefmt="%Y-%m-%d %H:%M:%S") + ) logger.addHandler(console_handler) # create file handlers - file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'), mode='a') + file_handler = logging.FileHandler( + os.path.join(output_dir, f"log_rank{dist_rank}.txt"), mode="a" + ) file_handler.setLevel(logging.DEBUG) - file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S')) + file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt="%Y-%m-%d %H:%M:%S")) logger.addHandler(file_handler) return logger + + +class TensorboardWriter: + writer = None + + def __init__(self, output_dir, dist_rank): + self.output_dir = output_dir + + if dist_rank == 0: + self.writer = SummaryWriter(log_dir=self.output_dir) + + def add_hparams(self, hparams, metrics): + pass + + def log(self, items, step): + if self.writer is None: + return + + # Copied from huggingface/accelerate/src/accelerate/tracking.py + for k, v in items.items(): + if isinstance(v, (int, float)): + self.writer.add_scalar(k, v, global_step=step) + elif isinstance(v, torch.Tensor): + assert v.numel() == 1 + self.writer.add_scalar(k, v.item(), global_step=step) + elif isinstance(v, str): + self.writer.add_text(k, v, global_step=step) + elif isinstance(v, dict): + self.writer.add_scalars(k, v, global_step=step) + else: + print(f"Can't log {v} because it is {type(v)}!") + + +class WandbWriter: + def __init__(self, rank): + self.rank = rank + + def init(self, config): + if self.rank != 0: + return + + kwargs = dict( + config=config, + project="hierarchical-vision", + name=config.EXPERIMENT.NAME, + ) + + if not config.EXPERIMENT.WANDB_ID: + print("Cannot resume wandb run because no id was provided!") + else: + kwargs["id"] = config.EXPERIMENT.WANDB_ID + kwargs["resume"] = "allow" + + wandb.init(**kwargs, mode="disabled") + + # Validation metrics + wandb.define_metric("val/loss", step_metric="epoch", summary="max") + wandb.define_metric("val/acc1", step_metric="epoch", summary="max") + wandb.define_metric("val/acc5", step_metric="epoch", summary="max") + + # Training metrics + wandb.define_metric("train/batch_time", step_metric="step", summary="last") + wandb.define_metric("train/grad_norm", step_metric="step", summary="last") + wandb.define_metric("train/batch_loss", step_metric="step", summary="last") + wandb.define_metric("train/loss_scale", step_metric="step", summary="last") + wandb.define_metric("train/learning_rate", step_metric="step", summary="last") + + wandb.define_metric("train/epoch_time", step_metric="epoch", summary="last") + wandb.define_metric("train/loss", step_metric="epoch", summary="last") + + # Other metrics + wandb.define_metric("memory_mb", summary="max") + + def log(self, dct): + if self.rank != 0: + return + + wandb.log(dct) + + @property + def name(self): + if self.rank != 0: + raise RuntimeError(f"Should not get .name with rank {self.rank}.") + + return wandb.run.name + + +############################# Added Script to generate yaml files ################################# + +def init(name: str, verbose: bool = False, date=True) -> logging.Logger: + if date: + log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s" + else: + log_format = "[%(levelname)s] [%(name)s] %(message)s" + + if not verbose: + logging.basicConfig(level=logging.INFO, format=log_format) + else: + 
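+        # Verbose mode logs DEBUG messages; basicConfig only takes effect if the
+        # root logger has no handlers yet, so call init() before other logging setup.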
logging.basicConfig(level=logging.DEBUG, format=log_format) + + return logging.getLogger(name) diff --git a/lr_scheduler.py b/lr_scheduler.py index a2122e5d..fa0ec4f1 100644 --- a/lr_scheduler.py +++ b/lr_scheduler.py @@ -9,8 +9,8 @@ import torch from timm.scheduler.cosine_lr import CosineLRScheduler -from timm.scheduler.step_lr import StepLRScheduler from timm.scheduler.scheduler import Scheduler +from timm.scheduler.step_lr import StepLRScheduler def build_scheduler(config, optimizer, n_iter_per_epoch): @@ -20,11 +20,13 @@ def build_scheduler(config, optimizer, n_iter_per_epoch): multi_steps = [i * n_iter_per_epoch for i in config.TRAIN.LR_SCHEDULER.MULTISTEPS] lr_scheduler = None - if config.TRAIN.LR_SCHEDULER.NAME == 'cosine': + if config.TRAIN.LR_SCHEDULER.NAME == "cosine": lr_scheduler = CosineLRScheduler( optimizer, - t_initial=(num_steps - warmup_steps) if config.TRAIN.LR_SCHEDULER.WARMUP_PREFIX else num_steps, - t_mul=1., + t_initial=(num_steps - warmup_steps) + if config.TRAIN.LR_SCHEDULER.WARMUP_PREFIX + else num_steps, + cycle_mul=1.0, lr_min=config.TRAIN.MIN_LR, warmup_lr_init=config.TRAIN.WARMUP_LR, warmup_t=warmup_steps, @@ -32,7 +34,7 @@ def build_scheduler(config, optimizer, n_iter_per_epoch): t_in_epochs=False, warmup_prefix=config.TRAIN.LR_SCHEDULER.WARMUP_PREFIX, ) - elif config.TRAIN.LR_SCHEDULER.NAME == 'linear': + elif config.TRAIN.LR_SCHEDULER.NAME == "linear": lr_scheduler = LinearLRScheduler( optimizer, t_initial=num_steps, @@ -41,16 +43,16 @@ def build_scheduler(config, optimizer, n_iter_per_epoch): warmup_t=warmup_steps, t_in_epochs=False, ) - elif config.TRAIN.LR_SCHEDULER.NAME == 'step': + elif config.TRAIN.LR_SCHEDULER.NAME == "step": lr_scheduler = StepLRScheduler( optimizer, decay_t=decay_steps, - decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE, + cycle_decay=config.TRAIN.LR_SCHEDULER.DECAY_RATE, warmup_lr_init=config.TRAIN.WARMUP_LR, warmup_t=warmup_steps, t_in_epochs=False, ) - elif config.TRAIN.LR_SCHEDULER.NAME == 'multistep': + elif config.TRAIN.LR_SCHEDULER.NAME == "multistep": lr_scheduler = MultiStepLRScheduler( optimizer, milestones=multi_steps, @@ -59,28 +61,58 @@ def build_scheduler(config, optimizer, n_iter_per_epoch): warmup_t=warmup_steps, t_in_epochs=False, ) + elif config.TRAIN.LR_SCHEDULER.NAME == "constant": + lr_scheduler = ConstantLRScheduler( + optimizer, + warmup_lr_init=config.TRAIN.WARMUP_LR, + warmup_t=warmup_steps, + t_in_epochs=False, + ) return lr_scheduler -class LinearLRScheduler(Scheduler): - def __init__(self, - optimizer: torch.optim.Optimizer, - t_initial: int, - lr_min_rate: float, - warmup_t=0, - warmup_lr_init=0., - t_in_epochs=True, - noise_range_t=None, - noise_pct=0.67, - noise_std=1.0, - noise_seed=42, - initialize=True, - ) -> None: +class TimmScheduler(Scheduler): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def get_epoch_values(self, epoch: int): + if self.t_in_epochs: + return self._get_lr(epoch) + else: + return None + + def get_update_values(self, num_updates: int): + if not self.t_in_epochs: + return self._get_lr(num_updates) + else: + return None + + +class LinearLRScheduler(TimmScheduler): + def __init__( + self, + optimizer: torch.optim.Optimizer, + t_initial: int, + lr_min_rate: float, + warmup_t=0, + warmup_lr_init=0.0, + t_in_epochs=True, + noise_range_t=None, + noise_pct=0.67, + noise_std=1.0, + noise_seed=42, + initialize=True, + ) -> None: super().__init__( - optimizer, param_group_field="lr", - noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, 
noise_seed=noise_seed, - initialize=initialize) + optimizer, + param_group_field="lr", + noise_range_t=noise_range_t, + noise_pct=noise_pct, + noise_std=noise_std, + noise_seed=noise_seed, + initialize=initialize, + ) self.t_initial = t_initial self.lr_min_rate = lr_min_rate @@ -88,7 +120,9 @@ def __init__(self, self.warmup_lr_init = warmup_lr_init self.t_in_epochs = t_in_epochs if self.warmup_t: - self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] + self.warmup_steps = [ + (v - warmup_lr_init) / self.warmup_t for v in self.base_values + ] super().update_groups(self.warmup_lr_init) else: self.warmup_steps = [1 for _ in self.base_values] @@ -99,54 +133,75 @@ def _get_lr(self, t): else: t = t - self.warmup_t total_t = self.t_initial - self.warmup_t - lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values] + lrs = [ + v - ((v - v * self.lr_min_rate) * (t / total_t)) + for v in self.base_values + ] return lrs - def get_epoch_values(self, epoch: int): - if self.t_in_epochs: - return self._get_lr(epoch) - else: - return None - def get_update_values(self, num_updates: int): - if not self.t_in_epochs: - return self._get_lr(num_updates) - else: - return None - - -class MultiStepLRScheduler(Scheduler): - def __init__(self, optimizer: torch.optim.Optimizer, milestones, gamma=0.1, warmup_t=0, warmup_lr_init=0, t_in_epochs=True) -> None: +class MultiStepLRScheduler(TimmScheduler): + def __init__( + self, + optimizer: torch.optim.Optimizer, + milestones, + gamma=0.1, + warmup_t=0, + warmup_lr_init=0, + t_in_epochs=True, + ) -> None: super().__init__(optimizer, param_group_field="lr") - + self.milestones = milestones self.gamma = gamma self.warmup_t = warmup_t self.warmup_lr_init = warmup_lr_init self.t_in_epochs = t_in_epochs if self.warmup_t: - self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] + self.warmup_steps = [ + (v - warmup_lr_init) / self.warmup_t for v in self.base_values + ] super().update_groups(self.warmup_lr_init) else: self.warmup_steps = [1 for _ in self.base_values] - + assert self.warmup_t <= min(self.milestones) - + def _get_lr(self, t): if t < self.warmup_t: lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] else: - lrs = [v * (self.gamma ** bisect.bisect_right(self.milestones, t)) for v in self.base_values] + lrs = [ + v * (self.gamma ** bisect.bisect_right(self.milestones, t)) + for v in self.base_values + ] return lrs - def get_epoch_values(self, epoch: int): - if self.t_in_epochs: - return self._get_lr(epoch) + +class ConstantLRScheduler(TimmScheduler): + def __init__( + self, + optimizer: torch.optim.Optimizer, + warmup_t=0, + warmup_lr_init=0, + t_in_epochs=True, + ) -> None: + super().__init__(optimizer, param_group_field="lr") + + self.warmup_t = warmup_t + self.warmup_lr_init = warmup_lr_init + self.t_in_epochs = t_in_epochs + if self.warmup_t: + self.warmup_steps = [ + (v - warmup_lr_init) / self.warmup_t for v in self.base_values + ] + super().update_groups(self.warmup_lr_init) else: - return None + self.warmup_steps = [1 for _ in self.base_values] - def get_update_values(self, num_updates: int): - if not self.t_in_epochs: - return self._get_lr(num_updates) + def _get_lr(self, t): + if t < self.warmup_t: + lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] else: - return None + lrs = [v for v in self.base_values] + return lrs diff --git a/main.py b/main.py index 84230ea7..416d7119 100644 --- a/main.py +++ b/main.py @@ -5,84 +5,171 @@ # Written by Ze Liu # 
-------------------------------------------------------- -import os -import time -import json -import random import argparse import datetime -import numpy as np +import json +import os +import random +import time +import numpy as np import torch import torch.backends.cudnn as cudnn import torch.distributed as dist - -from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy -from timm.utils import accuracy, AverageMeter +from timm.loss import LabelSmoothingCrossEntropy +from timm.utils import AverageMeter from config import get_config -from models import build_model from data import build_loader +from hierarchical import ( + FineGrainedCrossEntropyLoss, + HierarchicalCrossEntropyLoss, + accuracy, +) +from logger import WandbWriter, create_logger from lr_scheduler import build_scheduler +from models import build_model from optimizer import build_optimizer -from logger import create_logger -from utils import load_checkpoint, load_pretrained, save_checkpoint, NativeScalerWithGradNormCount, auto_resume_helper, \ - reduce_tensor +from utils import ( + NativeScalerWithGradNormCount, + auto_resume_helper, + batch_size_of, + load_checkpoint, + load_pretrained, + reduce_tensor, + save_checkpoint, + find_experiments, +) + + + +class EarlyStopper: + def __init__(self, patience=1, min_delta=0): + self.patience = patience + self.min_delta = min_delta + self.counter = 0 + self.min_validation_loss = np.inf + + def early_stop(self, validation_loss): + if validation_loss < self.min_validation_loss: + self.min_validation_loss = validation_loss + self.counter = 0 + elif validation_loss > (self.min_validation_loss + self.min_delta): + self.counter += 1 + if self.counter >= self.patience: + return True + return False + def parse_option(): - parser = argparse.ArgumentParser('Swin Transformer training and evaluation script', add_help=False) - parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', ) + parser = argparse.ArgumentParser( + "Swin Transformer training and evaluation script", add_help=False + ) + parser.add_argument( + "--cfg", + nargs="+", + #type=str, + required=True, + #metavar="FILE", + help="Paths to directories containing config.yaml files OR just a config.yaml file. Directories will be searched for any nested config.yaml files.", + ) parser.add_argument( "--opts", help="Modify config options by adding 'KEY VALUE' pairs. 
", default=None, - nargs='+', + nargs="+", ) # easy config modification - parser.add_argument('--batch-size', type=int, help="batch size for single GPU") - parser.add_argument('--data-path', type=str, help='path to dataset') - parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset') - parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'], - help='no: no cache, ' - 'full: cache all data, ' - 'part: sharding the dataset into nonoverlapping pieces and only cache one piece') - parser.add_argument('--pretrained', - help='pretrained weight from checkpoint, could be imagenet22k pretrained weight') - parser.add_argument('--resume', help='resume from checkpoint') - parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps") - parser.add_argument('--use-checkpoint', action='store_true', - help="whether to use gradient checkpointing to save memory") - parser.add_argument('--disable_amp', action='store_true', help='Disable pytorch amp') - parser.add_argument('--amp-opt-level', type=str, choices=['O0', 'O1', 'O2'], - help='mixed precision opt level, if O0, no amp is used (deprecated!)') - parser.add_argument('--output', default='output', type=str, metavar='PATH', - help='root of output folder, the full path is // (default: output)') - parser.add_argument('--tag', help='tag of experiment') - parser.add_argument('--eval', action='store_true', help='Perform evaluation only') - parser.add_argument('--throughput', action='store_true', help='Test throughput only') - - # distributed training - parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel') + parser.add_argument("--batch-size", type=int, help="batch size for single GPU") + parser.add_argument("--data-path", type=str, help="path to dataset") + parser.add_argument( + "--zip", + action="store_true", + help="use zipped dataset instead of folder dataset", + ) + parser.add_argument( + "--cache-mode", + type=str, + default="part", + choices=["no", "full", "part"], + help="no: no cache, " + "full: cache all data, " + "part: sharding the dataset into nonoverlapping pieces and only cache one piece", + ) + parser.add_argument( + "--pretrained", + help="pretrained weight from checkpoint, could be imagenet22k pretrained weight", + ) + parser.add_argument("--resume", help="resume from checkpoint") + parser.add_argument( + "--accumulation-steps", type=int, help="gradient accumulation steps" + ) + parser.add_argument( + "--use-checkpoint", + action="store_true", + help="whether to use gradient checkpointing to save memory", + ) + parser.add_argument( + "--disable_amp", action="store_true", help="Disable pytorch amp" + ) + parser.add_argument( + "--amp-opt-level", + type=str, + choices=["O0", "O1", "O2"], + help="mixed precision opt level, if O0, no amp is used (deprecated!)", + ) + parser.add_argument( + "--output", + default="output", + type=str, + metavar="PATH", + help="root of output folder, the full path is // (default: output)", + ) + parser.add_argument("--tag", help="tag of experiment") + parser.add_argument("--eval", action="store_true", help="Perform evaluation only") + parser.add_argument( + "--throughput", action="store_true", help="Test throughput only" + ) # for acceleration - parser.add_argument('--fused_window_process', action='store_true', - help='Fused window shift & window partition, similar for reversed part.') - parser.add_argument('--fused_layernorm', action='store_true', help='Use 
fused layernorm.') - ## overwrite optimizer in config (*.yaml) if specified, e.g., fused_adam/fused_lamb - parser.add_argument('--optim', type=str, - help='overwrite optimizer if provided, can be adamw/sgd/fused_adam/fused_lamb.') + parser.add_argument( + "--fused_window_process", + action="store_true", + help="Fused window shift & window partition, similar for reversed part.", + ) + parser.add_argument( + "--fused_layernorm", action="store_true", help="Use fused layernorm." + ) + # overwrite optimizer in config (*.yaml) if specified, e.g., fused_adam/fused_lamb + parser.add_argument( + "--optim", + type=str, + help="overwrite optimizer if provided, can be adamw/sgd/fused_adam/fused_lamb.", + ) + # low-data-regieme; percentage of training data to be use for fine tuning + parser.add_argument( + "--low-data", + type=float, + help="percentage of training data (.01 to 1) to be use for fine tuning", + ) args, unparsed = parser.parse_known_args() - config = get_config(args) + #config = get_config(args) - return args, config + return args #, config def main(config): - dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config) + ( + dataset_train, + dataset_val, + data_loader_train, + data_loader_val, + mixup_fn, + ) = build_loader(config) logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}") model = build_model(config) @@ -90,7 +177,7 @@ def main(config): n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) logger.info(f"number of params: {n_parameters}") - if hasattr(model, 'flops'): + if hasattr(model, "flops"): flops = model.flops() logger.info(f"number of GFLOPs: {flops / 1e9}") @@ -98,21 +185,39 @@ def main(config): model_without_ddp = model optimizer = build_optimizer(config, model) - model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False) + model = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[config.LOCAL_RANK], + broadcast_buffers=False, + find_unused_parameters=False, + gradient_as_bucket_view=True, + static_graph=True, + ) loss_scaler = NativeScalerWithGradNormCount() if config.TRAIN.ACCUMULATION_STEPS > 1: - lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train) // config.TRAIN.ACCUMULATION_STEPS) + lr_scheduler = build_scheduler( + config, optimizer, len(data_loader_train) // config.TRAIN.ACCUMULATION_STEPS + ) else: lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train)) - if config.AUG.MIXUP > 0.: - # smoothing is handled with mixup label transform - criterion = SoftTargetCrossEntropy() - elif config.MODEL.LABEL_SMOOTHING > 0.: + if config.AUG.MIXUP == 0 and config.MODEL.LABEL_SMOOTHING > 0.0: + if config.HIERARCHICAL: + raise NotImplementedError( + "We don't support hierarhical loss with label smoothing and no mixup." 
+ ) criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING) else: - criterion = torch.nn.CrossEntropyLoss() + # If we have mixup, smoothing is handled with mixup label transform + if config.HIERARCHICAL: + criterion = HierarchicalCrossEntropyLoss( + coeffs=config.TRAIN.HIERARCHICAL_COEFFS + ).to(torch.cuda.current_device()) + else: + criterion = torch.nn.CrossEntropyLoss() + + logger.info("Loss function: %s", criterion) max_accuracy = 0.0 @@ -120,52 +225,102 @@ def main(config): resume_file = auto_resume_helper(config.OUTPUT) if resume_file: if config.MODEL.RESUME: - logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}") + logger.warning( + f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}" + ) config.defrost() config.MODEL.RESUME = resume_file config.freeze() - logger.info(f'auto resuming from {resume_file}') + logger.info(f"auto resuming from {resume_file}") else: - logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume') + logger.info(f"no checkpoint found in {config.OUTPUT}, ignoring auto resume") if config.MODEL.RESUME: - max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, loss_scaler, logger) - acc1, acc5, loss = validate(config, data_loader_val, model) - logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") + max_accuracy = load_checkpoint( + config, model_without_ddp, optimizer, lr_scheduler, loss_scaler, logger + ) + acc1, acc5, loss = validate( + config, data_loader_val, model, config.TRAIN.START_EPOCH - 1 + ) + logger.info( + f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%" + ) + logger.info("Previously reported best accuracy: %.2f", max_accuracy) if config.EVAL_MODE: return if config.MODEL.PRETRAINED and (not config.MODEL.RESUME): load_pretrained(config, model_without_ddp, logger) - acc1, acc5, loss = validate(config, data_loader_val, model) - logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") + acc1, acc5, loss = validate( + config, data_loader_val, model, config.TRAIN.START_EPOCH - 1 + ) + logger.info( + f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%" + ) if config.THROUGHPUT_MODE: throughput(data_loader_val, model, logger) return + early_stopper = EarlyStopper(patience=3, min_delta=0.00001) + logger.info("Start training") start_time = time.time() for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS): data_loader_train.sampler.set_epoch(epoch) - train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler, - loss_scaler) - if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)): - save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, loss_scaler, - logger) - - acc1, acc5, loss = validate(config, data_loader_val, model) - logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") + train_one_epoch( + config, + model, + criterion, + data_loader_train, + optimizer, + epoch, + mixup_fn, + lr_scheduler, + loss_scaler, + ) + if dist.get_rank() == 0 and ( + epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1) + ): + save_checkpoint( + config, + epoch, + model_without_ddp, + max_accuracy, + optimizer, + lr_scheduler, + loss_scaler, + logger, + ) + + acc1, acc5, loss = validate(config, data_loader_val, model, epoch) + logger.info( + 
f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%" + ) max_accuracy = max(max_accuracy, acc1) - logger.info(f'Max accuracy: {max_accuracy:.2f}%') + logger.info(f"Max accuracy: {max_accuracy:.2f}%") + + if early_stopper.early_stop(loss): + print ("early stop") + break total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) - logger.info('Training time {}'.format(total_time_str)) - - -def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, loss_scaler): + logger.info("Training time {}".format(total_time_str)) + + +def train_one_epoch( + config, + model, + criterion, + data_loader, + optimizer, + epoch, + mixup_fn, + lr_scheduler, + loss_scaler, +): model.train() optimizer.zero_grad() @@ -190,18 +345,33 @@ def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mix loss = loss / config.TRAIN.ACCUMULATION_STEPS # this attribute is added by timm on one optimizer (adahessian) - is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order - grad_norm = loss_scaler(loss, optimizer, clip_grad=config.TRAIN.CLIP_GRAD, - parameters=model.parameters(), create_graph=is_second_order, - update_grad=(idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0) + is_second_order = ( + hasattr(optimizer, "is_second_order") and optimizer.is_second_order + ) + grad_norm = loss_scaler( + loss, + optimizer, + clip_grad=config.TRAIN.CLIP_GRAD, + parameters=model.parameters(), + create_graph=is_second_order, + update_grad=(idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0, + ) if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0: optimizer.zero_grad() - lr_scheduler.step_update((epoch * num_steps + idx) // config.TRAIN.ACCUMULATION_STEPS) + if lr_scheduler is not None: + lr_scheduler.step_update( + (epoch * num_steps + idx) // config.TRAIN.ACCUMULATION_STEPS + ) loss_scale_value = loss_scaler.state_dict()["scale"] torch.cuda.synchronize() - loss_meter.update(loss.item(), targets.size(0)) + # We divide by accumulation steps (not sure why) but it makes + # the logged values look weird. So I multiply by it to fix that. 
+ loss_meter.update( + loss.item() * config.TRAIN.ACCUMULATION_STEPS, batch_size_of(targets) + ) + if grad_norm is not None: # loss_scaler return None if not update norm_meter.update(grad_norm) scaler_meter.update(loss_scale_value) @@ -209,25 +379,56 @@ def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mix end = time.time() if idx % config.PRINT_FREQ == 0: - lr = optimizer.param_groups[0]['lr'] - wd = optimizer.param_groups[0]['weight_decay'] + lr = optimizer.param_groups[0]["lr"] + wd = optimizer.param_groups[0]["weight_decay"] memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) etas = batch_time.avg * (num_steps - idx) logger.info( - f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t' - f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t wd {wd:.4f}\t' - f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t' - f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' - f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t' - f'loss_scale {scaler_meter.val:.4f} ({scaler_meter.avg:.4f})\t' - f'mem {memory_used:.0f}MB') - epoch_time = time.time() - start - logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}") + f"Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t" + f"eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t" + f"wd {wd:.4f}\t" + f"time {batch_time.val:.4f} ({batch_time.avg:.4f})\t" + f"loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t" + f"grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t" + f"loss_scale {scaler_meter.val:.4f} ({scaler_meter.avg:.4f})\t" + f"mem {memory_used:.0f}MB" + ) + stats = { + "train/batch_time": batch_time.val, + "train/batch_loss": loss_meter.val, + "train/grad_norm": norm_meter.val, + "train/loss_scale": scaler_meter.val, + "memory_mb": memory_used, + "train/learning_rate": config.TRAIN.BASE_LR, + } + if lr_scheduler is not None: + stats["train/learning_rate"] = lr_scheduler.get_update_values( + # Copied from line 326 + (epoch * num_steps + idx) + // config.TRAIN.ACCUMULATION_STEPS + )[0] + + wandb_writer.log( + {**stats, "step": epoch * num_steps + idx, "epoch": epoch, "batch": idx} + ) + epoch_time = time.time() - start + logger.info( + f"EPOCH {epoch} training took {datetime.timedelta(seconds=int(epoch_time))}" + ) + wandb_writer.log( + {"train/epoch_time": epoch_time, "train/loss": loss_meter.avg, "epoch": epoch}, + ) +val_loss=[] +val_acc1=[] +val_acc5=[] @torch.no_grad() -def validate(config, data_loader, model): - criterion = torch.nn.CrossEntropyLoss() +def validate(config, data_loader, model, epoch): + if config.HIERARCHICAL: + criterion = FineGrainedCrossEntropyLoss() + else: + criterion = torch.nn.CrossEntropyLoss() model.eval() batch_time = AverageMeter() @@ -246,6 +447,9 @@ def validate(config, data_loader, model): # measure accuracy and record loss loss = criterion(output, target) + # print (output) + # print (output.shape) + # exit() acc1, acc5 = accuracy(output, target, topk=(1, 5)) acc1 = reduce_tensor(acc1) @@ -263,13 +467,29 @@ def validate(config, data_loader, model): if idx % config.PRINT_FREQ == 0: memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) logger.info( - f'Test: [{idx}/{len(data_loader)}]\t' - f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' - f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' - f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t' - f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t' - f'Mem {memory_used:.0f}MB') - logger.info(f' * Acc@1 {acc1_meter.avg:.3f} 
Acc@5 {acc5_meter.avg:.3f}') + f"Test: [{idx}/{len(data_loader)}]\t" + f"Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t" + f"Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t" + f"Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t" + f"Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t" + f"Mem {memory_used:.0f}MB" + ) + logger.info(f" * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}") + val_loss.append(loss_meter.avg) + val_acc1.append(acc1_meter.avg) + val_acc5.append(acc5_meter.avg) + logger.info(f' * loss_avg {val_loss}') + logger.info(f' * acc1_avg {val_acc1}') + logger.info(f' * acc5_avg {val_acc5}') + wandb_writer.log( + { + "val/acc1": acc1_meter.avg, + "val/acc5": acc5_meter.avg, + "val/loss": loss_meter.avg, + "epoch": epoch, + }, + ) + return acc1_meter.avg, acc5_meter.avg, loss_meter.avg @@ -280,69 +500,108 @@ def throughput(data_loader, model, logger): for idx, (images, _) in enumerate(data_loader): images = images.cuda(non_blocking=True) batch_size = images.shape[0] - for i in range(50): + for _ in range(50): model(images) torch.cuda.synchronize() - logger.info(f"throughput averaged with 30 times") + logger.info("throughput averaged with 30 times") tic1 = time.time() - for i in range(30): + for _ in range(30): model(images) torch.cuda.synchronize() tic2 = time.time() - logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}") + logger.info( + f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}" + ) return -if __name__ == '__main__': - args, config = parse_option() +if __name__ == "__main__": - if config.AMP_OPT_LEVEL: - print("[warning] Apex amp has been deprecated, please use pytorch amp instead!") + args = parse_option() - if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: rank = int(os.environ["RANK"]) - world_size = int(os.environ['WORLD_SIZE']) + world_size = int(os.environ["WORLD_SIZE"]) print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}") else: rank = -1 world_size = -1 - torch.cuda.set_device(config.LOCAL_RANK) - torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank) + + # torch.cuda.set_device(config.LOCAL_RANK) + torch.distributed.init_process_group( + backend="nccl", init_method="env://", world_size=world_size, rank=rank + ) torch.distributed.barrier() - seed = config.SEED + dist.get_rank() - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - np.random.seed(seed) - random.seed(seed) - cudnn.benchmark = True - - # linear scale the learning rate according to total batch size, may not be optimal - linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 - linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 - linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 - # gradient accumulation also need to scale the learning rate - if config.TRAIN.ACCUMULATION_STEPS > 1: - linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS - linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS - linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS - config.defrost() - config.TRAIN.BASE_LR = linear_scaled_lr - config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr - config.TRAIN.MIN_LR = linear_scaled_min_lr - config.freeze() - - os.makedirs(config.OUTPUT, exist_ok=True) - logger = 
create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}") - - if dist.get_rank() == 0: - path = os.path.join(config.OUTPUT, "config.json") - with open(path, "w") as f: - f.write(config.dump()) - logger.info(f"Full config saved to {path}") - - # print config - logger.info(config.dump()) - logger.info(json.dumps(vars(args))) - - main(config) + for experiment_config in find_experiments(args.cfg): + + # print (find_experiments(args.cfg)) + # print ("experiment_config",experiment_config) + args.cfg=experiment_config + + config = get_config(args) + + if config.AMP_OPT_LEVEL: + print("[warning] Apex amp has been deprecated, please use pytorch amp instead!") + + torch.cuda.set_device(config.LOCAL_RANK) + + seed = config.SEED + dist.get_rank() + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + cudnn.benchmark = True + + # linear scale the learning rate according to total batch size, may not be optimal + linear_scaled_lr = ( + config.TRAIN.BASE_LR + * config.TRAIN.DEVICE_BATCH_SIZE + * dist.get_world_size() + / 512.0 + ) + linear_scaled_warmup_lr = ( + config.TRAIN.WARMUP_LR + * config.TRAIN.DEVICE_BATCH_SIZE + * dist.get_world_size() + / 512.0 + ) + linear_scaled_min_lr = ( + config.TRAIN.MIN_LR + * config.TRAIN.DEVICE_BATCH_SIZE + * dist.get_world_size() + / 512.0 + ) + # gradient accumulation also need to scale the learning rate + if config.TRAIN.ACCUMULATION_STEPS > 1: + linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS + linear_scaled_warmup_lr = ( + linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS + ) + linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS + config.defrost() + config.TRAIN.BASE_LR = linear_scaled_lr + config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr + config.TRAIN.MIN_LR = linear_scaled_min_lr + config.freeze() + + os.makedirs(config.OUTPUT, exist_ok=True) + logger = create_logger( + output_dir=config.OUTPUT, + dist_rank=dist.get_rank(), + name=f"{config.EXPERIMENT.NAME}", + ) + wandb_writer = WandbWriter(rank=dist.get_rank()) + wandb_writer.init(config) + + if dist.get_rank() == 0: + path = os.path.join(config.OUTPUT, "config.yaml") + with open(path, "w") as f: + f.write(config.dump()) + logger.info(f"Full config saved to {path}") + + # print config + logger.info(config.dump()) + logger.info(json.dumps(vars(args))) + + main(config) diff --git a/main_moe.py b/main_moe.py index acf5d205..9a611cfa 100644 --- a/main_moe.py +++ b/main_moe.py @@ -5,70 +5,117 @@ # Written by Ze Liu # -------------------------------------------------------- -from tutel import system - -import os -import time -import json -import random import argparse import datetime -import numpy as np +import json +import os +import random +import time from functools import partial + +import numpy as np import torch import torch.backends.cudnn as cudnn import torch.distributed as dist - from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy -from timm.utils import accuracy, AverageMeter +from timm.utils import AverageMeter, accuracy +from tutel import system from config import get_config -from models import build_model from data import build_loader +from logger import create_logger from lr_scheduler import build_scheduler +from models import build_model from optimizer import build_optimizer -from logger import create_logger from utils import NativeScalerWithGradNormCount, reduce_tensor -from utils_moe import load_checkpoint, load_pretrained, save_checkpoint, 
auto_resume_helper, hook_scale_grad +from utils_moe import ( + auto_resume_helper, + hook_scale_grad, + load_checkpoint, + load_pretrained, + save_checkpoint, +) -assert torch.__version__ >= '1.8.0', "DDP-based MoE requires Pytorch >= 1.8.0" +assert torch.__version__ >= "1.8.0", "DDP-based MoE requires Pytorch >= 1.8.0" def parse_option(): - parser = argparse.ArgumentParser('Swin Transformer training and evaluation script', add_help=False) - parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', ) + parser = argparse.ArgumentParser( + "Swin Transformer training and evaluation script", add_help=False + ) + parser.add_argument( + "--cfg", + type=str, + required=True, + metavar="FILE", + help="path to config file", + ) parser.add_argument( "--opts", help="Modify config options by adding 'KEY VALUE' pairs. ", default=None, - nargs='+', + nargs="+", ) # easy config modification - parser.add_argument('--batch-size', type=int, help="batch size for single GPU") - parser.add_argument('--data-path', type=str, help='path to dataset') - parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset') - parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'], - help='no: no cache, ' - 'full: cache all data, ' - 'part: sharding the dataset into nonoverlapping pieces and only cache one piece') - parser.add_argument('--pretrained', - help='pretrained weight from checkpoint, could be imagenet22k pretrained weight') - parser.add_argument('--resume', help='resume from checkpoint') - parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps") - parser.add_argument('--use-checkpoint', action='store_true', - help="whether to use gradient checkpointing to save memory") - parser.add_argument('--disable_amp', action='store_true', help='Disable pytorch amp') - parser.add_argument('--amp-opt-level', type=str, choices=['O0', 'O1', 'O2'], - help='mixed precision opt level, if O0, no amp is used (deprecated!)') - parser.add_argument('--output', default='output', type=str, metavar='PATH', - help='root of output folder, the full path is // (default: output)') - parser.add_argument('--tag', help='tag of experiment') - parser.add_argument('--eval', action='store_true', help='Perform evaluation only') - parser.add_argument('--throughput', action='store_true', help='Test throughput only') + parser.add_argument("--batch-size", type=int, help="batch size for single GPU") + parser.add_argument("--data-path", type=str, help="path to dataset") + parser.add_argument( + "--zip", + action="store_true", + help="use zipped dataset instead of folder dataset", + ) + parser.add_argument( + "--cache-mode", + type=str, + default="part", + choices=["no", "full", "part"], + help="no: no cache, " + "full: cache all data, " + "part: sharding the dataset into nonoverlapping pieces and only cache one piece", + ) + parser.add_argument( + "--pretrained", + help="pretrained weight from checkpoint, could be imagenet22k pretrained weight", + ) + parser.add_argument("--resume", help="resume from checkpoint") + parser.add_argument( + "--accumulation-steps", type=int, help="gradient accumulation steps" + ) + parser.add_argument( + "--use-checkpoint", + action="store_true", + help="whether to use gradient checkpointing to save memory", + ) + parser.add_argument( + "--disable_amp", action="store_true", help="Disable pytorch amp" + ) + parser.add_argument( + "--amp-opt-level", + type=str, + choices=["O0", "O1", 
"O2"], + help="mixed precision opt level, if O0, no amp is used (deprecated!)", + ) + parser.add_argument( + "--output", + default="output", + type=str, + metavar="PATH", + help="root of output folder, the full path is // (default: output)", + ) + parser.add_argument("--tag", help="tag of experiment") + parser.add_argument("--eval", action="store_true", help="Perform evaluation only") + parser.add_argument( + "--throughput", action="store_true", help="Test throughput only" + ) # distributed training - parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel') + parser.add_argument( + "--local_rank", + type=int, + required=True, + help="local rank for DistributedDataParallel", + ) args, unparsed = parser.parse_known_args() @@ -78,7 +125,13 @@ def parse_option(): def main(config): - dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config) + ( + dataset_train, + dataset_val, + data_loader_train, + data_loader_val, + mixup_fn, + ) = build_loader(config) logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}") model = build_model(config) @@ -86,18 +139,32 @@ def main(config): # For Tutel MoE for name, param in model.named_parameters(): - if param.requires_grad == True and hasattr(param, 'skip_allreduce') and param.skip_allreduce is True: + if ( + param.requires_grad == True + and hasattr(param, "skip_allreduce") + and param.skip_allreduce is True + ): model.add_param_to_skip_allreduce(name) param.register_hook(partial(hook_scale_grad, dist.get_world_size())) - logger.info(f"[rank{dist.get_rank()}] [{name}] skip all_reduce and div {dist.get_world_size()} for grad") + logger.info( + f"[rank{dist.get_rank()}] [{name}] skip all_reduce and div {dist.get_world_size()} for grad" + ) - n_parameters_single = sum(p.numel() * model.sharded_count if hasattr(p, 'skip_allreduce') - else p.numel() for p in model.parameters() if p.requires_grad) + n_parameters_single = sum( + p.numel() * model.sharded_count if hasattr(p, "skip_allreduce") else p.numel() + for p in model.parameters() + if p.requires_grad + ) logger.info(f"number of params single: {n_parameters_single}") - n_parameters_whole = sum(p.numel() * model.sharded_count * model.global_experts if hasattr(p, 'skip_allreduce') - else p.numel() for p in model.parameters() if p.requires_grad) + n_parameters_whole = sum( + p.numel() * model.sharded_count * model.global_experts + if hasattr(p, "skip_allreduce") + else p.numel() + for p in model.parameters() + if p.requires_grad + ) logger.info(f"number of params whole: {n_parameters_whole}") - if hasattr(model, 'flops'): + if hasattr(model, "flops"): flops = model.flops() logger.info(f"number of GFLOPs: {flops / 1e9}") @@ -105,18 +172,22 @@ def main(config): model_without_ddp = model optimizer = build_optimizer(config, model) - model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False) + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False + ) loss_scaler = NativeScalerWithGradNormCount() if config.TRAIN.ACCUMULATION_STEPS > 1: - lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train) // config.TRAIN.ACCUMULATION_STEPS) + lr_scheduler = build_scheduler( + config, optimizer, len(data_loader_train) // config.TRAIN.ACCUMULATION_STEPS + ) else: lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train)) - if config.AUG.MIXUP > 0.: + if config.AUG.MIXUP > 0.0: # 
smoothing is handled with mixup label transform criterion = SoftTargetCrossEntropy() - elif config.MODEL.LABEL_SMOOTHING > 0.: + elif config.MODEL.LABEL_SMOOTHING > 0.0: criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING) else: criterion = torch.nn.CrossEntropyLoss() @@ -127,25 +198,33 @@ def main(config): resume_file = auto_resume_helper(config.OUTPUT, config.TRAIN.MOE.SAVE_MASTER) if resume_file: if config.MODEL.RESUME: - logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}") + logger.warning( + f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}" + ) config.defrost() config.MODEL.RESUME = resume_file config.freeze() - logger.info(f'auto resuming from {resume_file}') + logger.info(f"auto resuming from {resume_file}") else: - logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume') + logger.info(f"no checkpoint found in {config.OUTPUT}, ignoring auto resume") if config.MODEL.RESUME: - max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, loss_scaler, logger) + max_accuracy = load_checkpoint( + config, model_without_ddp, optimizer, lr_scheduler, loss_scaler, logger + ) acc1, acc5, loss = validate(config, data_loader_val, model) - logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") + logger.info( + f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%" + ) if config.EVAL_MODE: return if config.MODEL.PRETRAINED and (not config.MODEL.RESUME): load_pretrained(config, model_without_ddp, logger) acc1, acc5, loss = validate(config, data_loader_val, model) - logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") + logger.info( + f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%" + ) if config.EVAL_MODE: return @@ -158,24 +237,62 @@ def main(config): for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS): data_loader_train.sampler.set_epoch(epoch) - train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler, - loss_scaler) - if (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)): - save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, loss_scaler, - logger) + train_one_epoch( + config, + model, + criterion, + data_loader_train, + optimizer, + epoch, + mixup_fn, + lr_scheduler, + loss_scaler, + ) + if epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1): + save_checkpoint( + config, + epoch, + model_without_ddp, + max_accuracy, + optimizer, + lr_scheduler, + loss_scaler, + logger, + ) acc1, acc5, loss = validate(config, data_loader_val, model) - logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") + logger.info( + f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%" + ) max_accuracy = max(max_accuracy, acc1) - logger.info(f'Max accuracy: {max_accuracy:.2f}%') - save_checkpoint(config, 'final', model_without_ddp, max_accuracy, optimizer, lr_scheduler, loss_scaler, - logger, zero_redundancy=True) + logger.info(f"Max accuracy: {max_accuracy:.2f}%") + save_checkpoint( + config, + "final", + model_without_ddp, + max_accuracy, + optimizer, + lr_scheduler, + loss_scaler, + logger, + zero_redundancy=True, + ) total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) - logger.info('Training time 
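
The expert parameters flagged with skip_allreduce earlier in main() are excluded from DDP gradient averaging; instead each one gets a backward hook (hook_scale_grad from utils_moe) that divides its locally computed gradient by the world size, as the log message "skip all_reduce and div ... for grad" indicates. A hedged sketch of such a hook, assuming the (world_size, grad) calling convention implied by the partial() call; the real helper in utils_moe may differ:

import torch

def hook_scale_grad(world_size, grad):
    # Assumption: expert grads are only produced locally, so dividing by the
    # number of processes keeps them on the same scale as averaged dense grads.
    return grad / world_size

expert_weight = torch.nn.Parameter(torch.randn(16, 16))
expert_weight.register_hook(lambda g: hook_scale_grad(8, g))  # world_size=8, illustrative
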
{}'.format(total_time_str)) - - -def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, loss_scaler): + logger.info("Training time {}".format(total_time_str)) + + +def train_one_epoch( + config, + model, + criterion, + data_loader, + optimizer, + epoch, + mixup_fn, + lr_scheduler, + loss_scaler, +): model.train() optimizer.zero_grad() @@ -203,20 +320,31 @@ def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mix loss = loss / config.TRAIN.ACCUMULATION_STEPS # this attribute is added by timm on one optimizer (adahessian) - is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order - grad_norm = loss_scaler(loss, optimizer, clip_grad=config.TRAIN.CLIP_GRAD, - parameters=model.parameters(), create_graph=is_second_order, - update_grad=(idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0) + is_second_order = ( + hasattr(optimizer, "is_second_order") and optimizer.is_second_order + ) + grad_norm = loss_scaler( + loss, + optimizer, + clip_grad=config.TRAIN.CLIP_GRAD, + parameters=model.parameters(), + create_graph=is_second_order, + update_grad=(idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0, + ) if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0: optimizer.zero_grad() - lr_scheduler.step_update((epoch * num_steps + idx) // config.TRAIN.ACCUMULATION_STEPS) + lr_scheduler.step_update( + (epoch * num_steps + idx) // config.TRAIN.ACCUMULATION_STEPS + ) loss_scale_value = loss_scaler.state_dict()["scale"] torch.cuda.synchronize() loss_meter.update(loss.item(), targets.size(0)) loss_cls_meter.update(l_cls.item(), targets.size(0)) - loss_aux_meter.update(l_aux if isinstance(l_aux, float) else l_aux.item(), targets.size(0)) + loss_aux_meter.update( + l_aux if isinstance(l_aux, float) else l_aux.item(), targets.size(0) + ) if grad_norm is not None: # loss_scaler return None if not update norm_meter.update(grad_norm) scaler_meter.update(loss_scale_value) @@ -224,22 +352,25 @@ def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mix end = time.time() if idx % config.PRINT_FREQ == 0: - lr = optimizer.param_groups[0]['lr'] - wd = optimizer.param_groups[0]['weight_decay'] + lr = optimizer.param_groups[0]["lr"] + wd = optimizer.param_groups[0]["weight_decay"] memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) etas = batch_time.avg * (num_steps - idx) logger.info( - f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t' - f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t wd {wd:.4f}\t' - f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t' - f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' - f'loss-cls {loss_cls_meter.val:.4f} ({loss_cls_meter.avg:.4f})\t' - f'loss-aux {loss_aux_meter.val:.4f} ({loss_aux_meter.avg:.4f})\t' - f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t' - f'loss_scale {scaler_meter.val:.4f} ({scaler_meter.avg:.4f})\t' - f'mem {memory_used:.0f}MB') + f"Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t" + f"eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t wd {wd:.4f}\t" + f"time {batch_time.val:.4f} ({batch_time.avg:.4f})\t" + f"loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t" + f"loss-cls {loss_cls_meter.val:.4f} ({loss_cls_meter.avg:.4f})\t" + f"loss-aux {loss_aux_meter.val:.4f} ({loss_aux_meter.avg:.4f})\t" + f"grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t" + f"loss_scale {scaler_meter.val:.4f} ({scaler_meter.avg:.4f})\t" + f"mem {memory_used:.0f}MB" + ) epoch_time = time.time() - start - 
logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}") + logger.info( + f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}" + ) @torch.no_grad() @@ -270,7 +401,9 @@ def validate(config, data_loader, model): acc5 = reduce_tensor(acc5) loss_cls_meter.update(l_cls.item(), target.size(0)) - loss_aux_meter.update(l_aux if isinstance(l_aux, float) else l_aux.item(), target.size(0)) + loss_aux_meter.update( + l_aux if isinstance(l_aux, float) else l_aux.item(), target.size(0) + ) acc1_meter.update(acc1.item(), target.size(0)) acc5_meter.update(acc5.item(), target.size(0)) @@ -281,14 +414,15 @@ def validate(config, data_loader, model): if idx % config.PRINT_FREQ == 0: memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) logger.info( - f'Test: [{idx}/{len(data_loader)}]\t' - f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' - f'Loss-Cls {loss_cls_meter.val:.4f} ({loss_cls_meter.avg:.4f})\t' - f'Loss-Aux {loss_aux_meter.val:.4f} ({loss_aux_meter.avg:.4f})\t' - f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t' - f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t' - f'Mem {memory_used:.0f}MB') - logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}') + f"Test: [{idx}/{len(data_loader)}]\t" + f"Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t" + f"Loss-Cls {loss_cls_meter.val:.4f} ({loss_cls_meter.avg:.4f})\t" + f"Loss-Aux {loss_aux_meter.val:.4f} ({loss_aux_meter.avg:.4f})\t" + f"Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t" + f"Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t" + f"Mem {memory_used:.0f}MB" + ) + logger.info(f" * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}") return acc1_meter.avg, acc5_meter.avg, loss_cls_meter.avg @@ -308,25 +442,29 @@ def throughput(data_loader, model, logger): model(images) torch.cuda.synchronize() tic2 = time.time() - logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}") + logger.info( + f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}" + ) return -if __name__ == '__main__': +if __name__ == "__main__": args, config = parse_option() if config.AMP_OPT_LEVEL: print("[warning] Apex amp has been deprecated, please use pytorch amp instead!") - if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: rank = int(os.environ["RANK"]) - world_size = int(os.environ['WORLD_SIZE']) + world_size = int(os.environ["WORLD_SIZE"]) print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}") else: rank = -1 world_size = -1 torch.cuda.set_device(config.LOCAL_RANK) - torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank) + torch.distributed.init_process_group( + backend="nccl", init_method="env://", world_size=world_size, rank=rank + ) torch.distributed.barrier() seed = config.SEED + dist.get_rank() @@ -337,13 +475,21 @@ def throughput(data_loader, model, logger): cudnn.benchmark = True # linear scale the learning rate according to total batch size, may not be optimal - linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 - linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 - linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 + linear_scaled_lr = ( + config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 + ) + 
linear_scaled_warmup_lr = ( + config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 + ) + linear_scaled_min_lr = ( + config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 + ) # gradient accumulation also need to scale the learning rate if config.TRAIN.ACCUMULATION_STEPS > 1: linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS - linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS + linear_scaled_warmup_lr = ( + linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS + ) linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS config.defrost() config.TRAIN.BASE_LR = linear_scaled_lr @@ -352,7 +498,11 @@ def throughput(data_loader, model, logger): config.freeze() os.makedirs(config.OUTPUT, exist_ok=True) - logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}") + logger = create_logger( + output_dir=config.OUTPUT, + dist_rank=dist.get_rank(), + name=f"{config.EXPERIMENT.NAME}", + ) if dist.get_rank() == 0: path = os.path.join(config.OUTPUT, "config.json") diff --git a/main_simmim_ft.py b/main_simmim_ft.py index 067dfbb0..e0ea09ad 100644 --- a/main_simmim_ft.py +++ b/main_simmim_ft.py @@ -6,58 +6,87 @@ # Modified by Zhenda Xie # -------------------------------------------------------- -import os -import time import argparse import datetime -import numpy as np +import os +import time +import numpy as np import torch import torch.backends.cudnn as cudnn -import torch.distributed as dist import torch.cuda.amp as amp - +import torch.distributed as dist from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy -from timm.utils import accuracy, AverageMeter +from timm.utils import AverageMeter, accuracy from config import get_config -from models import build_model from data import build_loader +from logger import create_logger from lr_scheduler import build_scheduler +from models import build_model from optimizer import build_optimizer -from logger import create_logger -from utils_simmim import load_checkpoint, load_pretrained, save_checkpoint, get_grad_norm, auto_resume_helper, reduce_tensor +from utils_simmim import ( + auto_resume_helper, + get_grad_norm, + load_checkpoint, + load_pretrained, + reduce_tensor, + save_checkpoint, +) def parse_option(): - parser = argparse.ArgumentParser('SimMIM fine-tuning script', add_help=False) - parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', ) + parser = argparse.ArgumentParser("SimMIM fine-tuning script", add_help=False) + parser.add_argument( + "--cfg", + type=str, + required=True, + metavar="FILE", + help="path to config file", + ) parser.add_argument( "--opts", help="Modify config options by adding 'KEY VALUE' pairs. 
", default=None, - nargs='+', + nargs="+", ) # easy config modification - parser.add_argument('--batch-size', type=int, help="batch size for single GPU") - parser.add_argument('--data-path', type=str, help='path to dataset') - parser.add_argument('--pretrained', type=str, help='path to pre-trained model') - parser.add_argument('--resume', help='resume from checkpoint') - parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps") - parser.add_argument('--use-checkpoint', action='store_true', - help="whether to use gradient checkpointing to save memory") - parser.add_argument('--enable-amp', action='store_true') - parser.add_argument('--disable-amp', action='store_false', dest='enable_amp') + parser.add_argument("--batch-size", type=int, help="batch size for single GPU") + parser.add_argument("--data-path", type=str, help="path to dataset") + parser.add_argument("--pretrained", type=str, help="path to pre-trained model") + parser.add_argument("--resume", help="resume from checkpoint") + parser.add_argument( + "--accumulation-steps", type=int, help="gradient accumulation steps" + ) + parser.add_argument( + "--use-checkpoint", + action="store_true", + help="whether to use gradient checkpointing to save memory", + ) + parser.add_argument("--enable-amp", action="store_true") + parser.add_argument("--disable-amp", action="store_false", dest="enable_amp") parser.set_defaults(enable_amp=True) - parser.add_argument('--output', default='output', type=str, metavar='PATH', - help='root of output folder, the full path is // (default: output)') - parser.add_argument('--tag', help='tag of experiment') - parser.add_argument('--eval', action='store_true', help='Perform evaluation only') - parser.add_argument('--throughput', action='store_true', help='Test throughput only') + parser.add_argument( + "--output", + default="output", + type=str, + metavar="PATH", + help="root of output folder, the full path is // (default: output)", + ) + parser.add_argument("--tag", help="tag of experiment") + parser.add_argument("--eval", action="store_true", help="Perform evaluation only") + parser.add_argument( + "--throughput", action="store_true", help="Test throughput only" + ) # distributed training - parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel') + parser.add_argument( + "--local_rank", + type=int, + required=True, + help="local rank for DistributedDataParallel", + ) args = parser.parse_args() @@ -67,7 +96,13 @@ def parse_option(): def main(config): - dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config, simmim=True, is_pretrain=False) + ( + dataset_train, + dataset_val, + data_loader_train, + data_loader_val, + mixup_fn, + ) = build_loader(config, simmim=True, is_pretrain=False) logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}") model = build_model(config, is_pretrain=False) @@ -75,22 +110,24 @@ def main(config): logger.info(str(model)) optimizer = build_optimizer(config, model, simmim=True, is_pretrain=False) - model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False) + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False + ) model_without_ddp = model.module n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) logger.info(f"number of params: {n_parameters}") - if hasattr(model_without_ddp, 'flops'): + if 
hasattr(model_without_ddp, "flops"): flops = model_without_ddp.flops() logger.info(f"number of GFLOPs: {flops / 1e9}") lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train)) scaler = amp.GradScaler() - if config.AUG.MIXUP > 0.: + if config.AUG.MIXUP > 0.0: # smoothing is handled with mixup label transform criterion = SoftTargetCrossEntropy() - elif config.MODEL.LABEL_SMOOTHING > 0.: + elif config.MODEL.LABEL_SMOOTHING > 0.0: criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING) else: criterion = torch.nn.CrossEntropyLoss() @@ -101,25 +138,33 @@ def main(config): resume_file = auto_resume_helper(config.OUTPUT, logger) if resume_file: if config.MODEL.RESUME: - logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}") + logger.warning( + f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}" + ) config.defrost() config.MODEL.RESUME = resume_file config.freeze() - logger.info(f'auto resuming from {resume_file}') + logger.info(f"auto resuming from {resume_file}") else: - logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume') + logger.info(f"no checkpoint found in {config.OUTPUT}, ignoring auto resume") if config.MODEL.RESUME: - max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, scaler, logger) + max_accuracy = load_checkpoint( + config, model_without_ddp, optimizer, lr_scheduler, scaler, logger + ) acc1, acc5, loss = validate(config, data_loader_val, model) - logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") + logger.info( + f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%" + ) if config.EVAL_MODE: return if config.MODEL.PRETRAINED and (not config.MODEL.RESUME): load_pretrained(config, model_without_ddp, logger) acc1, acc5, loss = validate(config, data_loader_val, model) - logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") + logger.info( + f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%" + ) if config.THROUGHPUT_MODE: throughput(data_loader_val, model, logger) @@ -130,25 +175,60 @@ def main(config): for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS): data_loader_train.sampler.set_epoch(epoch) - train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler, scaler) - if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)): - save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, scaler, logger) + train_one_epoch( + config, + model, + criterion, + data_loader_train, + optimizer, + epoch, + mixup_fn, + lr_scheduler, + scaler, + ) + if dist.get_rank() == 0 and ( + epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1) + ): + save_checkpoint( + config, + epoch, + model_without_ddp, + max_accuracy, + optimizer, + lr_scheduler, + scaler, + logger, + ) acc1, acc5, loss = validate(config, data_loader_val, model) - logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") + logger.info( + f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%" + ) max_accuracy = max(max_accuracy, acc1) - logger.info(f'Max accuracy: {max_accuracy:.2f}%') + logger.info(f"Max accuracy: {max_accuracy:.2f}%") total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) - logger.info('Training 
time {}'.format(total_time_str)) - - -def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, scaler): + logger.info("Training time {}".format(total_time_str)) + + +def train_one_epoch( + config, + model, + criterion, + data_loader, + optimizer, + epoch, + mixup_fn, + lr_scheduler, + scaler, +): model.train() optimizer.zero_grad() - - logger.info(f'Current learning rate for different parameter groups: {[it["lr"] for it in optimizer.param_groups]}') + + logger.info( + f'Current learning rate for different parameter groups: {[it["lr"] for it in optimizer.param_groups]}' + ) num_steps = len(data_loader) batch_time = AverageMeter() @@ -173,7 +253,9 @@ def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mix scaler.scale(loss).backward() if config.TRAIN.CLIP_GRAD: scaler.unscale_(optimizer) - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD) + grad_norm = torch.nn.utils.clip_grad_norm_( + model.parameters(), config.TRAIN.CLIP_GRAD + ) else: grad_norm = get_grad_norm(model.parameters()) if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0: @@ -187,7 +269,9 @@ def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mix scaler.scale(loss).backward() if config.TRAIN.CLIP_GRAD: scaler.unscale_(optimizer) - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD) + grad_norm = torch.nn.utils.clip_grad_norm_( + model.parameters(), config.TRAIN.CLIP_GRAD + ) else: grad_norm = get_grad_norm(model.parameters()) scaler.step(optimizer) @@ -203,19 +287,22 @@ def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mix end = time.time() if idx % config.PRINT_FREQ == 0: - lr = optimizer.param_groups[-1]['lr'] + lr = optimizer.param_groups[-1]["lr"] memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) etas = batch_time.avg * (num_steps - idx) logger.info( - f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t' - f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t' - f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t' - f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' - f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t' - f'loss_scale {loss_scale_meter.val:.4f} ({loss_scale_meter.avg:.4f})\t' - f'mem {memory_used:.0f}MB') + f"Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t" + f"eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t" + f"time {batch_time.val:.4f} ({batch_time.avg:.4f})\t" + f"loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t" + f"grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t" + f"loss_scale {loss_scale_meter.val:.4f} ({loss_scale_meter.avg:.4f})\t" + f"mem {memory_used:.0f}MB" + ) epoch_time = time.time() - start - logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}") + logger.info( + f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}" + ) @torch.no_grad() @@ -255,13 +342,14 @@ def validate(config, data_loader, model): if idx % config.PRINT_FREQ == 0: memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) logger.info( - f'Test: [{idx}/{len(data_loader)}]\t' - f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' - f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' - f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t' - f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t' - f'Mem {memory_used:.0f}MB') - logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 
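
The fine-tuning step above combines torch.cuda.amp loss scaling, optional gradient clipping, and gradient accumulation. A condensed, self-contained sketch of that update pattern (illustrative constants stand in for the config.TRAIN values; this is not the repo's exact code):

import torch
import torch.nn.functional as F

ACCUM_STEPS, CLIP_GRAD = 2, 5.0  # stand-ins for config.TRAIN.ACCUMULATION_STEPS / CLIP_GRAD

model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()

batches = [(torch.randn(4, 8).cuda(), torch.randn(4, 8).cuda()) for _ in range(4)]
for idx, (x, y) in enumerate(batches):
    with torch.cuda.amp.autocast():
        loss = F.mse_loss(model(x), y) / ACCUM_STEPS
    scaler.scale(loss).backward()
    if (idx + 1) % ACCUM_STEPS == 0:
        scaler.unscale_(optimizer)                                  # unscale before clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_GRAD)
        scaler.step(optimizer)                                      # skipped if grads overflowed
        scaler.update()
        optimizer.zero_grad()
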
{acc5_meter.avg:.3f}') + f"Test: [{idx}/{len(data_loader)}]\t" + f"Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t" + f"Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t" + f"Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t" + f"Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t" + f"Mem {memory_used:.0f}MB" + ) + logger.info(f" * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}") return acc1_meter.avg, acc5_meter.avg, loss_meter.avg @@ -281,22 +369,26 @@ def throughput(data_loader, model, logger): model(images) torch.cuda.synchronize() tic2 = time.time() - logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}") + logger.info( + f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}" + ) return -if __name__ == '__main__': +if __name__ == "__main__": _, config = parse_option() - if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: rank = int(os.environ["RANK"]) - world_size = int(os.environ['WORLD_SIZE']) + world_size = int(os.environ["WORLD_SIZE"]) print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}") else: rank = -1 world_size = -1 torch.cuda.set_device(config.LOCAL_RANK) - torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank) + torch.distributed.init_process_group( + backend="nccl", init_method="env://", world_size=world_size, rank=rank + ) torch.distributed.barrier() seed = config.SEED + dist.get_rank() @@ -305,13 +397,21 @@ def throughput(data_loader, model, logger): cudnn.benchmark = True # linear scale the learning rate according to total batch size, may not be optimal - linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 - linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 - linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 + linear_scaled_lr = ( + config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 + ) + linear_scaled_warmup_lr = ( + config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 + ) + linear_scaled_min_lr = ( + config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 + ) # gradient accumulation also need to scale the learning rate if config.TRAIN.ACCUMULATION_STEPS > 1: linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS - linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS + linear_scaled_warmup_lr = ( + linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS + ) linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS config.defrost() config.TRAIN.BASE_LR = linear_scaled_lr @@ -320,7 +420,11 @@ def throughput(data_loader, model, logger): config.freeze() os.makedirs(config.OUTPUT, exist_ok=True) - logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}") + logger = create_logger( + output_dir=config.OUTPUT, + dist_rank=dist.get_rank(), + name=f"{config.EXPERIMENT.NAME}", + ) if dist.get_rank() == 0: path = os.path.join(config.OUTPUT, "config.json") @@ -331,4 +435,4 @@ def throughput(data_loader, model, logger): # print config logger.info(config.dump()) - main(config) \ No newline at end of file + main(config) diff --git a/main_simmim_pt.py b/main_simmim_pt.py index 6591d214..decca083 100644 --- a/main_simmim_pt.py +++ b/main_simmim_pt.py @@ -6,53 +6,79 @@ # 
Modified by Zhenda Xie # -------------------------------------------------------- -import os -import time import argparse import datetime -import numpy as np +import os +import time +import numpy as np import torch import torch.backends.cudnn as cudnn -import torch.distributed as dist import torch.cuda.amp as amp +import torch.distributed as dist from timm.utils import AverageMeter from config import get_config -from models import build_model from data import build_loader +from logger import create_logger from lr_scheduler import build_scheduler +from models import build_model from optimizer import build_optimizer -from logger import create_logger -from utils_simmim import load_checkpoint, save_checkpoint, get_grad_norm, auto_resume_helper +from utils_simmim import ( + auto_resume_helper, + get_grad_norm, + load_checkpoint, + save_checkpoint, +) def parse_option(): - parser = argparse.ArgumentParser('SimMIM pre-training script', add_help=False) - parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', ) + parser = argparse.ArgumentParser("SimMIM pre-training script", add_help=False) + parser.add_argument( + "--cfg", + type=str, + required=True, + metavar="FILE", + help="path to config file", + ) parser.add_argument( "--opts", help="Modify config options by adding 'KEY VALUE' pairs. ", default=None, - nargs='+', + nargs="+", ) # easy config modification - parser.add_argument('--batch-size', type=int, help="batch size for single GPU") - parser.add_argument('--data-path', type=str, help='path to dataset') - parser.add_argument('--resume', help='resume from checkpoint') - parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps") - parser.add_argument('--use-checkpoint', action='store_true', - help="whether to use gradient checkpointing to save memory") - parser.add_argument('--enable-amp', action='store_true') - parser.add_argument('--disable-amp', action='store_false', dest='enable_amp') + parser.add_argument("--batch-size", type=int, help="batch size for single GPU") + parser.add_argument("--data-path", type=str, help="path to dataset") + parser.add_argument("--resume", help="resume from checkpoint") + parser.add_argument( + "--accumulation-steps", type=int, help="gradient accumulation steps" + ) + parser.add_argument( + "--use-checkpoint", + action="store_true", + help="whether to use gradient checkpointing to save memory", + ) + parser.add_argument("--enable-amp", action="store_true") + parser.add_argument("--disable-amp", action="store_false", dest="enable_amp") parser.set_defaults(enable_amp=True) - parser.add_argument('--output', default='output', type=str, metavar='PATH', - help='root of output folder, the full path is // (default: output)') - parser.add_argument('--tag', help='tag of experiment') + parser.add_argument( + "--output", + default="output", + type=str, + metavar="PATH", + help="root of output folder, the full path is // (default: output)", + ) + parser.add_argument("--tag", help="tag of experiment") # distributed training - parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel') + parser.add_argument( + "--local_rank", + type=int, + required=True, + help="local rank for DistributedDataParallel", + ) args = parser.parse_args() @@ -70,12 +96,14 @@ def main(config): logger.info(str(model)) optimizer = build_optimizer(config, model, simmim=True, is_pretrain=True) - model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], 
broadcast_buffers=False) + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False + ) model_without_ddp = model.module n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) logger.info(f"number of params: {n_parameters}") - if hasattr(model_without_ddp, 'flops'): + if hasattr(model_without_ddp, "flops"): flops = model_without_ddp.flops() logger.info(f"number of GFLOPs: {flops / 1e9}") @@ -86,29 +114,46 @@ def main(config): resume_file = auto_resume_helper(config.OUTPUT, logger) if resume_file: if config.MODEL.RESUME: - logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}") + logger.warning( + f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}" + ) config.defrost() config.MODEL.RESUME = resume_file config.freeze() - logger.info(f'auto resuming from {resume_file}') + logger.info(f"auto resuming from {resume_file}") else: - logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume') + logger.info(f"no checkpoint found in {config.OUTPUT}, ignoring auto resume") if config.MODEL.RESUME: - load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, scaler, logger) + load_checkpoint( + config, model_without_ddp, optimizer, lr_scheduler, scaler, logger + ) logger.info("Start training") start_time = time.time() for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS): data_loader_train.sampler.set_epoch(epoch) - train_one_epoch(config, model, data_loader_train, optimizer, epoch, lr_scheduler, scaler) - if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)): - save_checkpoint(config, epoch, model_without_ddp, 0., optimizer, lr_scheduler, scaler, logger) + train_one_epoch( + config, model, data_loader_train, optimizer, epoch, lr_scheduler, scaler + ) + if dist.get_rank() == 0 and ( + epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1) + ): + save_checkpoint( + config, + epoch, + model_without_ddp, + 0.0, + optimizer, + lr_scheduler, + scaler, + logger, + ) total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) - logger.info('Training time {}'.format(total_time_str)) + logger.info("Training time {}".format(total_time_str)) def train_one_epoch(config, model, data_loader, optimizer, epoch, lr_scheduler, scaler): @@ -135,7 +180,9 @@ def train_one_epoch(config, model, data_loader, optimizer, epoch, lr_scheduler, scaler.scale(loss).backward() if config.TRAIN.CLIP_GRAD: scaler.unscale_(optimizer) - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD) + grad_norm = torch.nn.utils.clip_grad_norm_( + model.parameters(), config.TRAIN.CLIP_GRAD + ) else: grad_norm = get_grad_norm(model.parameters()) if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0: @@ -148,7 +195,9 @@ def train_one_epoch(config, model, data_loader, optimizer, epoch, lr_scheduler, scaler.scale(loss).backward() if config.TRAIN.CLIP_GRAD: scaler.unscale_(optimizer) - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD) + grad_norm = torch.nn.utils.clip_grad_norm_( + model.parameters(), config.TRAIN.CLIP_GRAD + ) else: grad_norm = get_grad_norm(model.parameters()) scaler.step(optimizer) @@ -164,33 +213,38 @@ def train_one_epoch(config, model, data_loader, optimizer, epoch, lr_scheduler, end = time.time() if idx % config.PRINT_FREQ == 0: - lr = optimizer.param_groups[0]['lr'] + lr = 
optimizer.param_groups[0]["lr"] memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) etas = batch_time.avg * (num_steps - idx) logger.info( - f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t' - f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t' - f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t' - f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' - f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t' - f'loss_scale {loss_scale_meter.val:.4f} ({loss_scale_meter.avg:.4f})\t' - f'mem {memory_used:.0f}MB') + f"Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t" + f"eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t" + f"time {batch_time.val:.4f} ({batch_time.avg:.4f})\t" + f"loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t" + f"grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t" + f"loss_scale {loss_scale_meter.val:.4f} ({loss_scale_meter.avg:.4f})\t" + f"mem {memory_used:.0f}MB" + ) epoch_time = time.time() - start - logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}") + logger.info( + f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}" + ) -if __name__ == '__main__': +if __name__ == "__main__": _, config = parse_option() - if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: rank = int(os.environ["RANK"]) - world_size = int(os.environ['WORLD_SIZE']) + world_size = int(os.environ["WORLD_SIZE"]) print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}") else: rank = -1 world_size = -1 torch.cuda.set_device(config.LOCAL_RANK) - torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank) + torch.distributed.init_process_group( + backend="nccl", init_method="env://", world_size=world_size, rank=rank + ) torch.distributed.barrier() seed = config.SEED + dist.get_rank() @@ -199,13 +253,21 @@ def train_one_epoch(config, model, data_loader, optimizer, epoch, lr_scheduler, cudnn.benchmark = True # linear scale the learning rate according to total batch size, may not be optimal - linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 - linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 - linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 + linear_scaled_lr = ( + config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 + ) + linear_scaled_warmup_lr = ( + config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 + ) + linear_scaled_min_lr = ( + config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 + ) # gradient accumulation also need to scale the learning rate if config.TRAIN.ACCUMULATION_STEPS > 1: linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS - linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS + linear_scaled_warmup_lr = ( + linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS + ) linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS config.defrost() config.TRAIN.BASE_LR = linear_scaled_lr @@ -214,7 +276,11 @@ def train_one_epoch(config, model, data_loader, optimizer, epoch, lr_scheduler, config.freeze() os.makedirs(config.OUTPUT, exist_ok=True) - logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}") + logger = 
create_logger( + output_dir=config.OUTPUT, + dist_rank=dist.get_rank(), + name=f"{config.EXPERIMENT.NAME}", + ) if dist.get_rank() == 0: path = os.path.join(config.OUTPUT, "config.json") @@ -225,4 +291,4 @@ def train_one_epoch(config, model, data_loader, optimizer, epoch, lr_scheduler, # print config logger.info(config.dump()) - main(config) \ No newline at end of file + main(config) diff --git a/makefile b/makefile new file mode 100644 index 00000000..caf42d01 --- /dev/null +++ b/makefile @@ -0,0 +1,6 @@ +fmt: + fd -e py | xargs isort --profile black + fd -e py | xargs black + +lint: + fd -e py | xargs flake8 diff --git a/models/__init__.py b/models/__init__.py index 2d9c65e3..59774f75 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -1 +1 @@ -from .build import build_model \ No newline at end of file +from .build import build_model diff --git a/models/build.py b/models/build.py index c37384d2..55932d94 100644 --- a/models/build.py +++ b/models/build.py @@ -5,11 +5,11 @@ # Written by Ze Liu # -------------------------------------------------------- +from .simmim import build_simmim +from .swin_mlp import SwinMLP from .swin_transformer import SwinTransformer -from .swin_transformer_v2 import SwinTransformerV2 from .swin_transformer_moe import SwinTransformerMoE -from .swin_mlp import SwinMLP -from .simmim import build_simmim +from .swin_transformer_v2 import SwinTransformerV2 def build_model(config, is_pretrain=False): @@ -19,102 +19,112 @@ def build_model(config, is_pretrain=False): if config.FUSED_LAYERNORM: try: import apex as amp + layernorm = amp.normalization.FusedLayerNorm - except: + except ImportError: layernorm = None print("To use FusedLayerNorm, please install apex.") else: import torch.nn as nn + layernorm = nn.LayerNorm if is_pretrain: model = build_simmim(config) return model - if model_type == 'swin': - model = SwinTransformer(img_size=config.DATA.IMG_SIZE, - patch_size=config.MODEL.SWIN.PATCH_SIZE, - in_chans=config.MODEL.SWIN.IN_CHANS, - num_classes=config.MODEL.NUM_CLASSES, - embed_dim=config.MODEL.SWIN.EMBED_DIM, - depths=config.MODEL.SWIN.DEPTHS, - num_heads=config.MODEL.SWIN.NUM_HEADS, - window_size=config.MODEL.SWIN.WINDOW_SIZE, - mlp_ratio=config.MODEL.SWIN.MLP_RATIO, - qkv_bias=config.MODEL.SWIN.QKV_BIAS, - qk_scale=config.MODEL.SWIN.QK_SCALE, - drop_rate=config.MODEL.DROP_RATE, - drop_path_rate=config.MODEL.DROP_PATH_RATE, - ape=config.MODEL.SWIN.APE, - norm_layer=layernorm, - patch_norm=config.MODEL.SWIN.PATCH_NORM, - use_checkpoint=config.TRAIN.USE_CHECKPOINT, - fused_window_process=config.FUSED_WINDOW_PROCESS) - elif model_type == 'swinv2': - model = SwinTransformerV2(img_size=config.DATA.IMG_SIZE, - patch_size=config.MODEL.SWINV2.PATCH_SIZE, - in_chans=config.MODEL.SWINV2.IN_CHANS, - num_classes=config.MODEL.NUM_CLASSES, - embed_dim=config.MODEL.SWINV2.EMBED_DIM, - depths=config.MODEL.SWINV2.DEPTHS, - num_heads=config.MODEL.SWINV2.NUM_HEADS, - window_size=config.MODEL.SWINV2.WINDOW_SIZE, - mlp_ratio=config.MODEL.SWINV2.MLP_RATIO, - qkv_bias=config.MODEL.SWINV2.QKV_BIAS, - drop_rate=config.MODEL.DROP_RATE, - drop_path_rate=config.MODEL.DROP_PATH_RATE, - ape=config.MODEL.SWINV2.APE, - patch_norm=config.MODEL.SWINV2.PATCH_NORM, - use_checkpoint=config.TRAIN.USE_CHECKPOINT, - pretrained_window_sizes=config.MODEL.SWINV2.PRETRAINED_WINDOW_SIZES) - elif model_type == 'swin_moe': - model = SwinTransformerMoE(img_size=config.DATA.IMG_SIZE, - patch_size=config.MODEL.SWIN_MOE.PATCH_SIZE, - in_chans=config.MODEL.SWIN_MOE.IN_CHANS, - 
num_classes=config.MODEL.NUM_CLASSES, - embed_dim=config.MODEL.SWIN_MOE.EMBED_DIM, - depths=config.MODEL.SWIN_MOE.DEPTHS, - num_heads=config.MODEL.SWIN_MOE.NUM_HEADS, - window_size=config.MODEL.SWIN_MOE.WINDOW_SIZE, - mlp_ratio=config.MODEL.SWIN_MOE.MLP_RATIO, - qkv_bias=config.MODEL.SWIN_MOE.QKV_BIAS, - qk_scale=config.MODEL.SWIN_MOE.QK_SCALE, - drop_rate=config.MODEL.DROP_RATE, - drop_path_rate=config.MODEL.DROP_PATH_RATE, - ape=config.MODEL.SWIN_MOE.APE, - patch_norm=config.MODEL.SWIN_MOE.PATCH_NORM, - mlp_fc2_bias=config.MODEL.SWIN_MOE.MLP_FC2_BIAS, - init_std=config.MODEL.SWIN_MOE.INIT_STD, - use_checkpoint=config.TRAIN.USE_CHECKPOINT, - pretrained_window_sizes=config.MODEL.SWIN_MOE.PRETRAINED_WINDOW_SIZES, - moe_blocks=config.MODEL.SWIN_MOE.MOE_BLOCKS, - num_local_experts=config.MODEL.SWIN_MOE.NUM_LOCAL_EXPERTS, - top_value=config.MODEL.SWIN_MOE.TOP_VALUE, - capacity_factor=config.MODEL.SWIN_MOE.CAPACITY_FACTOR, - cosine_router=config.MODEL.SWIN_MOE.COSINE_ROUTER, - normalize_gate=config.MODEL.SWIN_MOE.NORMALIZE_GATE, - use_bpr=config.MODEL.SWIN_MOE.USE_BPR, - is_gshard_loss=config.MODEL.SWIN_MOE.IS_GSHARD_LOSS, - gate_noise=config.MODEL.SWIN_MOE.GATE_NOISE, - cosine_router_dim=config.MODEL.SWIN_MOE.COSINE_ROUTER_DIM, - cosine_router_init_t=config.MODEL.SWIN_MOE.COSINE_ROUTER_INIT_T, - moe_drop=config.MODEL.SWIN_MOE.MOE_DROP, - aux_loss_weight=config.MODEL.SWIN_MOE.AUX_LOSS_WEIGHT) - elif model_type == 'swin_mlp': - model = SwinMLP(img_size=config.DATA.IMG_SIZE, - patch_size=config.MODEL.SWIN_MLP.PATCH_SIZE, - in_chans=config.MODEL.SWIN_MLP.IN_CHANS, - num_classes=config.MODEL.NUM_CLASSES, - embed_dim=config.MODEL.SWIN_MLP.EMBED_DIM, - depths=config.MODEL.SWIN_MLP.DEPTHS, - num_heads=config.MODEL.SWIN_MLP.NUM_HEADS, - window_size=config.MODEL.SWIN_MLP.WINDOW_SIZE, - mlp_ratio=config.MODEL.SWIN_MLP.MLP_RATIO, - drop_rate=config.MODEL.DROP_RATE, - drop_path_rate=config.MODEL.DROP_PATH_RATE, - ape=config.MODEL.SWIN_MLP.APE, - patch_norm=config.MODEL.SWIN_MLP.PATCH_NORM, - use_checkpoint=config.TRAIN.USE_CHECKPOINT) + if model_type == "swin": + model = SwinTransformer( + img_size=config.DATA.IMG_SIZE, + patch_size=config.MODEL.SWIN.PATCH_SIZE, + in_chans=config.MODEL.SWIN.IN_CHANS, + num_classes=config.MODEL.NUM_CLASSES, + embed_dim=config.MODEL.SWIN.EMBED_DIM, + depths=config.MODEL.SWIN.DEPTHS, + num_heads=config.MODEL.SWIN.NUM_HEADS, + window_size=config.MODEL.SWIN.WINDOW_SIZE, + mlp_ratio=config.MODEL.SWIN.MLP_RATIO, + qkv_bias=config.MODEL.SWIN.QKV_BIAS, + qk_scale=config.MODEL.SWIN.QK_SCALE, + drop_rate=config.MODEL.DROP_RATE, + drop_path_rate=config.MODEL.DROP_PATH_RATE, + ape=config.MODEL.SWIN.APE, + norm_layer=layernorm, + patch_norm=config.MODEL.SWIN.PATCH_NORM, + use_checkpoint=config.TRAIN.USE_CHECKPOINT, + fused_window_process=config.FUSED_WINDOW_PROCESS, + ) + elif model_type == "swinv2": + model = SwinTransformerV2( + img_size=config.DATA.IMG_SIZE, + patch_size=config.MODEL.SWINV2.PATCH_SIZE, + in_chans=config.MODEL.SWINV2.IN_CHANS, + num_classes=config.MODEL.NUM_CLASSES, + embed_dim=config.MODEL.SWINV2.EMBED_DIM, + depths=config.MODEL.SWINV2.DEPTHS, + num_heads=config.MODEL.SWINV2.NUM_HEADS, + window_size=config.MODEL.SWINV2.WINDOW_SIZE, + mlp_ratio=config.MODEL.SWINV2.MLP_RATIO, + qkv_bias=config.MODEL.SWINV2.QKV_BIAS, + drop_rate=config.MODEL.DROP_RATE, + drop_path_rate=config.MODEL.DROP_PATH_RATE, + ape=config.MODEL.SWINV2.APE, + patch_norm=config.MODEL.SWINV2.PATCH_NORM, + use_checkpoint=config.TRAIN.USE_CHECKPOINT, + 
pretrained_window_sizes=config.MODEL.SWINV2.PRETRAINED_WINDOW_SIZES, + ) + elif model_type == "swin_moe": + model = SwinTransformerMoE( + img_size=config.DATA.IMG_SIZE, + patch_size=config.MODEL.SWIN_MOE.PATCH_SIZE, + in_chans=config.MODEL.SWIN_MOE.IN_CHANS, + num_classes=config.MODEL.NUM_CLASSES, + embed_dim=config.MODEL.SWIN_MOE.EMBED_DIM, + depths=config.MODEL.SWIN_MOE.DEPTHS, + num_heads=config.MODEL.SWIN_MOE.NUM_HEADS, + window_size=config.MODEL.SWIN_MOE.WINDOW_SIZE, + mlp_ratio=config.MODEL.SWIN_MOE.MLP_RATIO, + qkv_bias=config.MODEL.SWIN_MOE.QKV_BIAS, + qk_scale=config.MODEL.SWIN_MOE.QK_SCALE, + drop_rate=config.MODEL.DROP_RATE, + drop_path_rate=config.MODEL.DROP_PATH_RATE, + ape=config.MODEL.SWIN_MOE.APE, + patch_norm=config.MODEL.SWIN_MOE.PATCH_NORM, + mlp_fc2_bias=config.MODEL.SWIN_MOE.MLP_FC2_BIAS, + init_std=config.MODEL.SWIN_MOE.INIT_STD, + use_checkpoint=config.TRAIN.USE_CHECKPOINT, + pretrained_window_sizes=config.MODEL.SWIN_MOE.PRETRAINED_WINDOW_SIZES, + moe_blocks=config.MODEL.SWIN_MOE.MOE_BLOCKS, + num_local_experts=config.MODEL.SWIN_MOE.NUM_LOCAL_EXPERTS, + top_value=config.MODEL.SWIN_MOE.TOP_VALUE, + capacity_factor=config.MODEL.SWIN_MOE.CAPACITY_FACTOR, + cosine_router=config.MODEL.SWIN_MOE.COSINE_ROUTER, + normalize_gate=config.MODEL.SWIN_MOE.NORMALIZE_GATE, + use_bpr=config.MODEL.SWIN_MOE.USE_BPR, + is_gshard_loss=config.MODEL.SWIN_MOE.IS_GSHARD_LOSS, + gate_noise=config.MODEL.SWIN_MOE.GATE_NOISE, + cosine_router_dim=config.MODEL.SWIN_MOE.COSINE_ROUTER_DIM, + cosine_router_init_t=config.MODEL.SWIN_MOE.COSINE_ROUTER_INIT_T, + moe_drop=config.MODEL.SWIN_MOE.MOE_DROP, + aux_loss_weight=config.MODEL.SWIN_MOE.AUX_LOSS_WEIGHT, + ) + elif model_type == "swin_mlp": + model = SwinMLP( + img_size=config.DATA.IMG_SIZE, + patch_size=config.MODEL.SWIN_MLP.PATCH_SIZE, + in_chans=config.MODEL.SWIN_MLP.IN_CHANS, + num_classes=config.MODEL.NUM_CLASSES, + embed_dim=config.MODEL.SWIN_MLP.EMBED_DIM, + depths=config.MODEL.SWIN_MLP.DEPTHS, + num_heads=config.MODEL.SWIN_MLP.NUM_HEADS, + window_size=config.MODEL.SWIN_MLP.WINDOW_SIZE, + mlp_ratio=config.MODEL.SWIN_MLP.MLP_RATIO, + drop_rate=config.MODEL.DROP_RATE, + drop_path_rate=config.MODEL.DROP_PATH_RATE, + ape=config.MODEL.SWIN_MLP.APE, + patch_norm=config.MODEL.SWIN_MLP.PATCH_NORM, + use_checkpoint=config.TRAIN.USE_CHECKPOINT, + ) else: raise NotImplementedError(f"Unkown model: {model_type}") diff --git a/models/simmim.py b/models/simmim.py index fc482b80..731a1c06 100644 --- a/models/simmim.py +++ b/models/simmim.py @@ -1,5 +1,3 @@ - - # -------------------------------------------------------- # SimMIM # Copyright (c) 2021 Microsoft @@ -20,21 +18,41 @@ def norm_targets(targets, patch_size): assert patch_size % 2 == 1 - + targets_ = targets targets_count = torch.ones_like(targets) - targets_square = targets ** 2. - - targets_mean = F.avg_pool2d(targets, kernel_size=patch_size, stride=1, padding=patch_size // 2, count_include_pad=False) - targets_square_mean = F.avg_pool2d(targets_square, kernel_size=patch_size, stride=1, padding=patch_size // 2, count_include_pad=False) - targets_count = F.avg_pool2d(targets_count, kernel_size=patch_size, stride=1, padding=patch_size // 2, count_include_pad=True) * (patch_size ** 2) - - targets_var = (targets_square_mean - targets_mean ** 2.) * (targets_count / (targets_count - 1)) - targets_var = torch.clamp(targets_var, min=0.) 
- - targets_ = (targets_ - targets_mean) / (targets_var + 1.e-6) ** 0.5 - + targets_square = targets**2.0 + + targets_mean = F.avg_pool2d( + targets, + kernel_size=patch_size, + stride=1, + padding=patch_size // 2, + count_include_pad=False, + ) + targets_square_mean = F.avg_pool2d( + targets_square, + kernel_size=patch_size, + stride=1, + padding=patch_size // 2, + count_include_pad=False, + ) + targets_count = F.avg_pool2d( + targets_count, + kernel_size=patch_size, + stride=1, + padding=patch_size // 2, + count_include_pad=True, + ) * (patch_size**2) + + targets_var = (targets_square_mean - targets_mean**2.0) * ( + targets_count / (targets_count - 1) + ) + targets_var = torch.clamp(targets_var, min=0.0) + + targets_ = (targets_ - targets_mean) / (targets_var + 1.0e-6) ** 0.5 + return targets_ @@ -45,7 +63,7 @@ def __init__(self, **kwargs): assert self.num_classes == 0 self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) - trunc_normal_(self.mask_token, mean=0., std=.02) + trunc_normal_(self.mask_token, mean=0.0, std=0.02) def forward(self, x, mask): x = self.patch_embed(x) @@ -55,7 +73,7 @@ def forward(self, x, mask): mask_tokens = self.mask_token.expand(B, L, -1) w = mask.flatten(1).unsqueeze(-1).type_as(mask_tokens) - x = x * (1. - w) + mask_tokens * w + x = x * (1.0 - w) + mask_tokens * w if self.ape: x = x + self.absolute_pos_embed @@ -67,13 +85,13 @@ def forward(self, x, mask): x = x.transpose(1, 2) B, C, L = x.shape - H = W = int(L ** 0.5) + H = W = int(L**0.5) x = x.reshape(B, C, H, W) return x @torch.jit.ignore def no_weight_decay(self): - return super().no_weight_decay() | {'mask_token'} + return super().no_weight_decay() | {"mask_token"} class SwinTransformerV2ForSimMIM(SwinTransformerV2): @@ -83,7 +101,7 @@ def __init__(self, **kwargs): assert self.num_classes == 0 self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) - trunc_normal_(self.mask_token, mean=0., std=.02) + trunc_normal_(self.mask_token, mean=0.0, std=0.02) def forward(self, x, mask): x = self.patch_embed(x) @@ -93,7 +111,7 @@ def forward(self, x, mask): mask_tokens = self.mask_token.expand(B, L, -1) w = mask.flatten(1).unsqueeze(-1).type_as(mask_tokens) - x = x * (1. 
- w) + mask_tokens * w + x = x * (1.0 - w) + mask_tokens * w if self.ape: x = x + self.absolute_pos_embed @@ -105,13 +123,13 @@ def forward(self, x, mask): x = x.transpose(1, 2) B, C, L = x.shape - H = W = int(L ** 0.5) + H = W = int(L**0.5) x = x.reshape(B, C, H, W) return x @torch.jit.ignore def no_weight_decay(self): - return super().no_weight_decay() | {'mask_token'} + return super().no_weight_decay() | {"mask_token"} class SimMIM(nn.Module): @@ -124,7 +142,9 @@ def __init__(self, config, encoder, encoder_stride, in_chans, patch_size): self.decoder = nn.Sequential( nn.Conv2d( in_channels=self.encoder.num_features, - out_channels=self.encoder_stride ** 2 * 3, kernel_size=1), + out_channels=self.encoder_stride**2 * 3, + kernel_size=1, + ), nn.PixelShuffle(self.encoder_stride), ) @@ -135,32 +155,37 @@ def forward(self, x, mask): z = self.encoder(x, mask) x_rec = self.decoder(z) - mask = mask.repeat_interleave(self.patch_size, 1).repeat_interleave(self.patch_size, 2).unsqueeze(1).contiguous() - + mask = ( + mask.repeat_interleave(self.patch_size, 1) + .repeat_interleave(self.patch_size, 2) + .unsqueeze(1) + .contiguous() + ) + # norm target as prompted if self.config.NORM_TARGET.ENABLE: x = norm_targets(x, self.config.NORM_TARGET.PATCH_SIZE) - - loss_recon = F.l1_loss(x, x_rec, reduction='none') + + loss_recon = F.l1_loss(x, x_rec, reduction="none") loss = (loss_recon * mask).sum() / (mask.sum() + 1e-5) / self.in_chans return loss @torch.jit.ignore def no_weight_decay(self): - if hasattr(self.encoder, 'no_weight_decay'): - return {'encoder.' + i for i in self.encoder.no_weight_decay()} + if hasattr(self.encoder, "no_weight_decay"): + return {"encoder." + i for i in self.encoder.no_weight_decay()} return {} @torch.jit.ignore def no_weight_decay_keywords(self): - if hasattr(self.encoder, 'no_weight_decay_keywords'): - return {'encoder.' + i for i in self.encoder.no_weight_decay_keywords()} + if hasattr(self.encoder, "no_weight_decay_keywords"): + return {"encoder." 
+ i for i in self.encoder.no_weight_decay_keywords()} return {} def build_simmim(config): model_type = config.MODEL.TYPE - if model_type == 'swin': + if model_type == "swin": encoder = SwinTransformerForSimMIM( img_size=config.DATA.IMG_SIZE, patch_size=config.MODEL.SWIN.PATCH_SIZE, @@ -177,11 +202,12 @@ def build_simmim(config): drop_path_rate=config.MODEL.DROP_PATH_RATE, ape=config.MODEL.SWIN.APE, patch_norm=config.MODEL.SWIN.PATCH_NORM, - use_checkpoint=config.TRAIN.USE_CHECKPOINT) + use_checkpoint=config.TRAIN.USE_CHECKPOINT, + ) encoder_stride = 32 in_chans = config.MODEL.SWIN.IN_CHANS patch_size = config.MODEL.SWIN.PATCH_SIZE - elif model_type == 'swinv2': + elif model_type == "swinv2": encoder = SwinTransformerV2ForSimMIM( img_size=config.DATA.IMG_SIZE, patch_size=config.MODEL.SWINV2.PATCH_SIZE, @@ -197,13 +223,20 @@ def build_simmim(config): drop_path_rate=config.MODEL.DROP_PATH_RATE, ape=config.MODEL.SWINV2.APE, patch_norm=config.MODEL.SWINV2.PATCH_NORM, - use_checkpoint=config.TRAIN.USE_CHECKPOINT) + use_checkpoint=config.TRAIN.USE_CHECKPOINT, + ) encoder_stride = 32 in_chans = config.MODEL.SWINV2.IN_CHANS patch_size = config.MODEL.SWINV2.PATCH_SIZE else: raise NotImplementedError(f"Unknown pre-train model: {model_type}") - model = SimMIM(config=config.MODEL.SIMMIM, encoder=encoder, encoder_stride=encoder_stride, in_chans=in_chans, patch_size=patch_size) + model = SimMIM( + config=config.MODEL.SIMMIM, + encoder=encoder, + encoder_stride=encoder_stride, + in_chans=in_chans, + patch_size=patch_size, + ) - return model \ No newline at end of file + return model diff --git a/models/swin_mlp.py b/models/swin_mlp.py index 115c43cd..bd0ae21b 100644 --- a/models/swin_mlp.py +++ b/models/swin_mlp.py @@ -13,7 +13,14 @@ class Mlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features @@ -42,7 +49,9 @@ def window_partition(x, window_size): """ B, H, W, C = x.shape x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + windows = ( + x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + ) return windows @@ -58,13 +67,15 @@ def window_reverse(windows, window_size, H, W): x: (B, H, W, C) """ B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = windows.view( + B, H // window_size, W // window_size, window_size, window_size, -1 + ) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) return x class SwinMLPBlock(nn.Module): - r""" Swin MLP Block. + r"""Swin MLP Block. Args: dim (int): Number of input channels. @@ -79,9 +90,19 @@ class SwinMLPBlock(nn.Module): norm_layer (nn.Module, optional): Normalization layer. 
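# ---------------------------------------------------------------------------
# Illustrative sketch, not part of this patch: window_partition / window_reverse
# (reformatted just above) are inverse reshapes between (B, H, W, C) and
# (num_windows*B, window_size, window_size, C). A minimal round-trip check,
# assuming only torch and H, W divisible by the window size:
import torch

B, H, W, C, ws = 2, 8, 8, 4, 4
x = torch.arange(B * H * W * C, dtype=torch.float32).view(B, H, W, C)

# partition: same permutation as window_partition above
windows = (x.view(B, H // ws, ws, W // ws, ws, C)
            .permute(0, 1, 3, 2, 4, 5).contiguous()
            .view(-1, ws, ws, C))                    # (B * H/ws * W/ws, ws, ws, C)

# reverse: same permutation as window_reverse above
x_back = (windows.view(B, H // ws, W // ws, ws, ws, C)
                 .permute(0, 1, 3, 2, 4, 5).contiguous()
                 .view(B, H, W, C))

assert torch.equal(x, x_back)                        # exact round trip
print(windows.shape)                                 # torch.Size([8, 4, 4, 4])
# ---------------------------------------------------------------------------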
Default: nn.LayerNorm """ - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm): + def __init__( + self, + dim, + input_resolution, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4.0, + drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + ): super().__init__() self.dim = dim self.input_resolution = input_resolution @@ -93,22 +114,35 @@ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0 # if window size is larger than input resolution, we don't partition windows self.shift_size = 0 self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + assert ( + 0 <= self.shift_size < self.window_size + ), "shift_size must in 0-window_size" - self.padding = [self.window_size - self.shift_size, self.shift_size, - self.window_size - self.shift_size, self.shift_size] # P_l,P_r,P_t,P_b + self.padding = [ + self.window_size - self.shift_size, + self.shift_size, + self.window_size - self.shift_size, + self.shift_size, + ] # P_l,P_r,P_t,P_b self.norm1 = norm_layer(dim) # use group convolution to implement multi-head MLP - self.spatial_mlp = nn.Conv1d(self.num_heads * self.window_size ** 2, - self.num_heads * self.window_size ** 2, - kernel_size=1, - groups=self.num_heads) - - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.spatial_mlp = nn.Conv1d( + self.num_heads * self.window_size**2, + self.num_heads * self.window_size**2, + kernel_size=1, + groups=self.num_heads, + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) def forward(self, x): H, W = self.input_resolution @@ -128,22 +162,42 @@ def forward(self, x): _, _H, _W, _ = shifted_x.shape # partition windows - x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C - x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + x_windows = window_partition( + shifted_x, self.window_size + ) # nW*B, window_size, window_size, C + x_windows = x_windows.view( + -1, self.window_size * self.window_size, C + ) # nW*B, window_size*window_size, C # Window/Shifted-Window Spatial MLP - x_windows_heads = x_windows.view(-1, self.window_size * self.window_size, self.num_heads, C // self.num_heads) - x_windows_heads = x_windows_heads.transpose(1, 2) # nW*B, nH, window_size*window_size, C//nH - x_windows_heads = x_windows_heads.reshape(-1, self.num_heads * self.window_size * self.window_size, - C // self.num_heads) - spatial_mlp_windows = self.spatial_mlp(x_windows_heads) # nW*B, nH*window_size*window_size, C//nH - spatial_mlp_windows = spatial_mlp_windows.view(-1, self.num_heads, self.window_size * self.window_size, - C // self.num_heads).transpose(1, 2) - spatial_mlp_windows = spatial_mlp_windows.reshape(-1, self.window_size * self.window_size, C) + x_windows_heads = x_windows.view( + -1, self.window_size * self.window_size, self.num_heads, C // self.num_heads + ) + x_windows_heads = x_windows_heads.transpose( + 1, 2 + ) # nW*B, nH, window_size*window_size, C//nH + x_windows_heads = 
x_windows_heads.reshape( + -1, + self.num_heads * self.window_size * self.window_size, + C // self.num_heads, + ) + spatial_mlp_windows = self.spatial_mlp( + x_windows_heads + ) # nW*B, nH*window_size*window_size, C//nH + spatial_mlp_windows = spatial_mlp_windows.view( + -1, self.num_heads, self.window_size * self.window_size, C // self.num_heads + ).transpose(1, 2) + spatial_mlp_windows = spatial_mlp_windows.reshape( + -1, self.window_size * self.window_size, C + ) # merge windows - spatial_mlp_windows = spatial_mlp_windows.reshape(-1, self.window_size, self.window_size, C) - shifted_x = window_reverse(spatial_mlp_windows, self.window_size, _H, _W) # B H' W' C + spatial_mlp_windows = spatial_mlp_windows.reshape( + -1, self.window_size, self.window_size, C + ) + shifted_x = window_reverse( + spatial_mlp_windows, self.window_size, _H, _W + ) # B H' W' C # reverse shift if self.shift_size > 0: @@ -160,8 +214,10 @@ def forward(self, x): return x def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ - f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + return ( + f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + ) def flops(self): flops = 0 @@ -174,7 +230,12 @@ def flops(self): nW = (H / self.window_size + 1) * (W / self.window_size + 1) else: nW = H * W / self.window_size / self.window_size - flops += nW * self.dim * (self.window_size * self.window_size) * (self.window_size * self.window_size) + flops += ( + nW + * self.dim + * (self.window_size * self.window_size) + * (self.window_size * self.window_size) + ) # mlp flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio # norm2 @@ -183,7 +244,7 @@ def flops(self): class PatchMerging(nn.Module): - r""" Patch Merging Layer. + r"""Patch Merging Layer. Args: input_resolution (tuple[int]): Resolution of input feature. @@ -232,7 +293,7 @@ def flops(self): class BasicLayer(nn.Module): - """ A basic Swin MLP layer for one stage. + """A basic Swin MLP layer for one stage. Args: dim (int): Number of input channels. @@ -248,9 +309,20 @@ class BasicLayer(nn.Module): use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
""" - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., drop=0., drop_path=0., - norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + def __init__( + self, + dim, + input_resolution, + depth, + num_heads, + window_size, + mlp_ratio=4.0, + drop=0.0, + drop_path=0.0, + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + ): super().__init__() self.dim = dim @@ -259,19 +331,30 @@ def __init__(self, dim, input_resolution, depth, num_heads, window_size, self.use_checkpoint = use_checkpoint # build blocks - self.blocks = nn.ModuleList([ - SwinMLPBlock(dim=dim, input_resolution=input_resolution, - num_heads=num_heads, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - drop=drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer) - for i in range(depth)]) + self.blocks = nn.ModuleList( + [ + SwinMLPBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + drop=drop, + drop_path=drop_path[i] + if isinstance(drop_path, list) + else drop_path, + norm_layer=norm_layer, + ) + for i in range(depth) + ] + ) # patch merging layer if downsample is not None: - self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + self.downsample = downsample( + input_resolution, dim=dim, norm_layer=norm_layer + ) else: self.downsample = None @@ -298,7 +381,7 @@ def flops(self): class PatchEmbed(nn.Module): - r""" Image to Patch Embedding + r"""Image to Patch Embedding Args: img_size (int): Image size. Default: 224. @@ -308,11 +391,16 @@ class PatchEmbed(nn.Module): norm_layer (nn.Module, optional): Normalization layer. Default: None """ - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + def __init__( + self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None + ): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + patches_resolution = [ + img_size[0] // patch_size[0], + img_size[1] // patch_size[1], + ] self.img_size = img_size self.patch_size = patch_size self.patches_resolution = patches_resolution @@ -321,7 +409,9 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_la self.in_chans = in_chans self.embed_dim = embed_dim - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size + ) if norm_layer is not None: self.norm = norm_layer(embed_dim) else: @@ -330,8 +420,9 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_la def forward(self, x): B, C, H, W = x.shape # FIXME look at relaxing size constraints - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + assert ( + H == self.img_size[0] and W == self.img_size[1] + ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C if self.norm is not None: x = self.norm(x) @@ -339,14 +430,20 @@ def forward(self, x): def flops(self): Ho, Wo = self.patches_resolution - flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + flops = ( + Ho + * Wo + * self.embed_dim + * self.in_chans + * (self.patch_size[0] * self.patch_size[1]) + ) if self.norm is not None: flops += Ho * Wo * self.embed_dim return flops class SwinMLP(nn.Module): - r""" Swin MLP + r"""Swin MLP Args: img_size (int | tuple(int)): Input image size. Default 224 @@ -366,11 +463,25 @@ class SwinMLP(nn.Module): use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False """ - def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, - embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], - window_size=7, mlp_ratio=4., drop_rate=0., drop_path_rate=0.1, - norm_layer=nn.LayerNorm, ape=False, patch_norm=True, - use_checkpoint=False, **kwargs): + def __init__( + self, + img_size=224, + patch_size=4, + in_chans=3, + num_classes=1000, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4.0, + drop_rate=0.0, + drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + use_checkpoint=False, + **kwargs, + ): super().__init__() self.num_classes = num_classes @@ -383,48 +494,64 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, # split image into non-overlapping patches self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None, + ) num_patches = self.patch_embed.num_patches patches_resolution = self.patch_embed.patches_resolution self.patches_resolution = patches_resolution # absolute position embedding if self.ape: - self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) - trunc_normal_(self.absolute_pos_embed, std=.02) + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, num_patches, embed_dim) + ) + trunc_normal_(self.absolute_pos_embed, std=0.02) self.pos_drop = nn.Dropout(p=drop_rate) # stochastic depth - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule # build layers self.layers = nn.ModuleList() for i_layer in range(self.num_layers): - layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), - input_resolution=(patches_resolution[0] // (2 ** i_layer), - patches_resolution[1] // (2 ** i_layer)), - depth=depths[i_layer], - num_heads=num_heads[i_layer], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - drop=drop_rate, - drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], - norm_layer=norm_layer, - downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, - use_checkpoint=use_checkpoint) + layer = BasicLayer( + dim=int(embed_dim * 2**i_layer), + input_resolution=( + patches_resolution[0] // (2**i_layer), + patches_resolution[1] // (2**i_layer), + ), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + drop=drop_rate, + drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], + 
norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + ) self.layers.append(layer) self.norm = norm_layer(self.num_features) self.avgpool = nn.AdaptiveAvgPool1d(1) - self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.head = ( + nn.Linear(self.num_features, num_classes) + if num_classes > 0 + else nn.Identity() + ) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, (nn.Linear, nn.Conv1d)): - trunc_normal_(m.weight, std=.02) + trunc_normal_(m.weight, std=0.02) if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): @@ -433,11 +560,11 @@ def _init_weights(self, m): @torch.jit.ignore def no_weight_decay(self): - return {'absolute_pos_embed'} + return {"absolute_pos_embed"} @torch.jit.ignore def no_weight_decay_keywords(self): - return {'relative_position_bias_table'} + return {"relative_position_bias_table"} def forward_features(self, x): x = self.patch_embed(x) @@ -463,6 +590,11 @@ def flops(self): flops += self.patch_embed.flops() for i, layer in enumerate(self.layers): flops += layer.flops() - flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) + flops += ( + self.num_features + * self.patches_resolution[0] + * self.patches_resolution[1] + // (2**self.num_layers) + ) flops += self.num_features * self.num_classes return flops diff --git a/models/swin_transformer.py b/models/swin_transformer.py index dde06bc5..a7ba5f15 100644 --- a/models/swin_transformer.py +++ b/models/swin_transformer.py @@ -11,20 +11,33 @@ from timm.models.layers import DropPath, to_2tuple, trunc_normal_ try: - import os, sys + import os + import sys - kernel_path = os.path.abspath(os.path.join('..')) + kernel_path = os.path.abspath(os.path.join("..")) sys.path.append(kernel_path) - from kernels.window_process.window_process import WindowProcess, WindowProcessReverse + from kernels.window_process.window_process import ( + WindowProcess, + WindowProcessReverse, + ) except: WindowProcess = None WindowProcessReverse = None - print("[Warning] Fused window process have not been installed. Please refer to get_started.md for installation.") + print( + "[Warning] Fused window process have not been installed. Please refer to get_started.md for installation." 
+ ) class Mlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features @@ -53,7 +66,9 @@ def window_partition(x, window_size): """ B, H, W, C = x.shape x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + windows = ( + x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + ) return windows @@ -69,13 +84,15 @@ def window_reverse(windows, window_size, H, W): x: (B, H, W, C) """ B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = windows.view( + B, H // window_size, W // window_size, window_size, window_size, -1 + ) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) return x class WindowAttention(nn.Module): - r""" Window based multi-head self attention (W-MSA) module with relative position bias. + r"""Window based multi-head self attention (W-MSA) module with relative position bias. It supports both of shifted and non-shifted window. Args: @@ -88,26 +105,40 @@ class WindowAttention(nn.Module): proj_drop (float, optional): Dropout ratio of output. Default: 0.0 """ - def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + def __init__( + self, + dim, + window_size, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + ): super().__init__() self.dim = dim self.window_size = window_size # Wh, Ww self.num_heads = num_heads head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 + self.scale = qk_scale or head_dim**-0.5 # define a parameter table of relative position bias self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads) + ) # 2*Wh-1 * 2*Ww-1, nH # get pair-wise relative position index for each token inside the window coords_h = torch.arange(self.window_size[0]) coords_w = torch.arange(self.window_size[1]) coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords = ( + coords_flatten[:, :, None] - coords_flatten[:, None, :] + ) # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0 + ).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 relative_coords[:, :, 1] += self.window_size[1] - 1 relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 @@ -119,7 +150,7 @@ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, at self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) - trunc_normal_(self.relative_position_bias_table, std=.02) + trunc_normal_(self.relative_position_bias_table, std=0.02) self.softmax = nn.Softmax(dim=-1) def forward(self, x, mask=None): @@ -129,20 +160,37 @@ def forward(self, x, mask=None): 
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None """ B_, N, C = x.shape - qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + qkv = ( + self.qkv(x) + .reshape(B_, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = ( + qkv[0], + qkv[1], + qkv[2], + ) # make torchscript happy (cannot use tensor as tuple) q = q * self.scale - attn = (q @ k.transpose(-2, -1)) - - relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = q @ k.transpose(-2, -1) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1) + ].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1, + ) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1 + ).contiguous() # nH, Wh*Ww, Wh*Ww attn = attn + relative_position_bias.unsqueeze(0) if mask is not None: nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze( + 1 + ).unsqueeze(0) attn = attn.view(-1, self.num_heads, N, N) attn = self.softmax(attn) else: @@ -156,7 +204,7 @@ def forward(self, x, mask=None): return x def extra_repr(self) -> str: - return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + return f"dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}" def flops(self, N): # calculate flops for 1 window with token length of N @@ -173,7 +221,7 @@ def flops(self, N): class SwinTransformerBlock(nn.Module): - r""" Swin Transformer Block. + r"""Swin Transformer Block. Args: dim (int): Number of input channels. @@ -192,10 +240,23 @@ class SwinTransformerBlock(nn.Module): fused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. 
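# ---------------------------------------------------------------------------
# Illustrative sketch, not part of this patch: WindowAttention (above) adds a
# learned relative position bias to every q.k^T score. The bias table holds
# (2*Wh-1)*(2*Ww-1) entries per head, and relative_position_index maps each of
# the (Wh*Ww)^2 token pairs to one entry. Reproducing that index for a 2x2
# window, assuming only torch:
import torch

Wh = Ww = 2
coords = torch.stack(torch.meshgrid([torch.arange(Wh), torch.arange(Ww)]))  # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)                                   # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]   # 2, N, N
relative_coords = relative_coords.permute(1, 2, 0).contiguous()             # N, N, 2
relative_coords[:, :, 0] += Wh - 1          # shift to start from 0
relative_coords[:, :, 1] += Ww - 1
relative_coords[:, :, 0] *= 2 * Ww - 1
relative_position_index = relative_coords.sum(-1)   # (Wh*Ww, Wh*Ww), values in [0, 9)
print(relative_position_index)
# tensor([[4, 3, 1, 0],
#         [5, 4, 2, 1],
#         [7, 6, 4, 3],
#         [8, 7, 5, 4]])
# ---------------------------------------------------------------------------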
Default: False """ - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm, - fused_window_process=False): + def __init__( + self, + dim, + input_resolution, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + fused_window_process=False, + ): super().__init__() self.dim = dim self.input_resolution = input_resolution @@ -207,38 +268,59 @@ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0 # if window size is larger than input resolution, we don't partition windows self.shift_size = 0 self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + assert ( + 0 <= self.shift_size < self.window_size + ), "shift_size must in 0-window_size" self.norm1 = norm_layer(dim) self.attn = WindowAttention( - dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) - - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) if self.shift_size > 0: # calculate attention mask for SW-MSA H, W = self.input_resolution img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) + h_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + w_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) cnt = 0 for h in h_slices: for w in w_slices: img_mask[:, h, w, :] = cnt cnt += 1 - mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = window_partition( + img_mask, self.window_size + ) # nW, window_size, window_size, 1 mask_windows = mask_windows.view(-1, self.window_size * self.window_size) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + attn_mask = attn_mask.masked_fill( + attn_mask != 0, float(-100.0) + ).masked_fill(attn_mask == 0, float(0.0)) else: attn_mask = None @@ -257,20 +339,32 @@ def forward(self, x): # cyclic shift if self.shift_size > 0: if not self.fused_window_process: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + shifted_x = torch.roll( + x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2) + ) # partition windows - x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + 
x_windows = window_partition( + shifted_x, self.window_size + ) # nW*B, window_size, window_size, C else: - x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size) + x_windows = WindowProcess.apply( + x, B, H, W, C, -self.shift_size, self.window_size + ) else: shifted_x = x # partition windows - x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = window_partition( + shifted_x, self.window_size + ) # nW*B, window_size, window_size, C - x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + x_windows = x_windows.view( + -1, self.window_size * self.window_size, C + ) # nW*B, window_size*window_size, C # W-MSA/SW-MSA - attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + attn_windows = self.attn( + x_windows, mask=self.attn_mask + ) # nW*B, window_size*window_size, C # merge windows attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) @@ -278,12 +372,20 @@ def forward(self, x): # reverse cyclic shift if self.shift_size > 0: if not self.fused_window_process: - shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + shifted_x = window_reverse( + attn_windows, self.window_size, H, W + ) # B H' W' C + x = torch.roll( + shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2) + ) else: - x = WindowProcessReverse.apply(attn_windows, B, H, W, C, self.shift_size, self.window_size) + x = WindowProcessReverse.apply( + attn_windows, B, H, W, C, self.shift_size, self.window_size + ) else: - shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + shifted_x = window_reverse( + attn_windows, self.window_size, H, W + ) # B H' W' C x = shifted_x x = x.view(B, H * W, C) x = shortcut + self.drop_path(x) @@ -294,8 +396,10 @@ def forward(self, x): return x def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ - f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + return ( + f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + ) def flops(self): flops = 0 @@ -313,7 +417,7 @@ def flops(self): class PatchMerging(nn.Module): - r""" Patch Merging Layer. + r"""Patch Merging Layer. Args: input_resolution (tuple[int]): Resolution of input feature. @@ -362,7 +466,7 @@ def flops(self): class BasicLayer(nn.Module): - """ A basic Swin Transformer layer for one stage. + """A basic Swin Transformer layer for one stage. Args: dim (int): Number of input channels. @@ -382,10 +486,24 @@ class BasicLayer(nn.Module): fused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. 
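# ---------------------------------------------------------------------------
# Illustrative sketch, not part of this patch: after the cyclic shift above,
# tokens wrapped around by torch.roll share a window with tokens they should
# not attend to, so the block precomputes an additive mask: 0 between tokens
# of the same region, -100 otherwise, added to the logits before softmax.
# A minimal reproduction for H = W = 8, window_size = 4, shift_size = 2,
# assuming only torch:
import torch

H = W = 8
ws, shift = 4, 2
img_mask = torch.zeros((1, H, W, 1))
slices = (slice(0, -ws), slice(-ws, -shift), slice(-shift, None))
cnt = 0
for h in slices:
    for w in slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1

# window_partition, inlined (same permutation as in the functions above)
mask_windows = (img_mask.view(1, H // ws, ws, W // ws, ws, 1)
                        .permute(0, 1, 3, 2, 4, 5).contiguous()
                        .view(-1, ws * ws))                       # (nW, ws*ws) = (4, 16)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)
print(attn_mask.shape)   # torch.Size([4, 16, 16]); 0 within a region, -100 across regions
# ---------------------------------------------------------------------------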
Default: False """ - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, - fused_window_process=False): + def __init__( + self, + dim, + input_resolution, + depth, + num_heads, + window_size, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + fused_window_process=False, + ): super().__init__() self.dim = dim @@ -394,21 +512,34 @@ def __init__(self, dim, input_resolution, depth, num_heads, window_size, self.use_checkpoint = use_checkpoint # build blocks - self.blocks = nn.ModuleList([ - SwinTransformerBlock(dim=dim, input_resolution=input_resolution, - num_heads=num_heads, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer, - fused_window_process=fused_window_process) - for i in range(depth)]) + self.blocks = nn.ModuleList( + [ + SwinTransformerBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] + if isinstance(drop_path, list) + else drop_path, + norm_layer=norm_layer, + fused_window_process=fused_window_process, + ) + for i in range(depth) + ] + ) # patch merging layer if downsample is not None: - self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + self.downsample = downsample( + input_resolution, dim=dim, norm_layer=norm_layer + ) else: self.downsample = None @@ -435,7 +566,7 @@ def flops(self): class PatchEmbed(nn.Module): - r""" Image to Patch Embedding + r"""Image to Patch Embedding Args: img_size (int): Image size. Default: 224. @@ -445,11 +576,16 @@ class PatchEmbed(nn.Module): norm_layer (nn.Module, optional): Normalization layer. Default: None """ - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + def __init__( + self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None + ): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + patches_resolution = [ + img_size[0] // patch_size[0], + img_size[1] // patch_size[1], + ] self.img_size = img_size self.patch_size = patch_size self.patches_resolution = patches_resolution @@ -458,7 +594,9 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_la self.in_chans = in_chans self.embed_dim = embed_dim - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size + ) if norm_layer is not None: self.norm = norm_layer(embed_dim) else: @@ -467,8 +605,9 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_la def forward(self, x): B, C, H, W = x.shape # FIXME look at relaxing size constraints - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
+ assert ( + H == self.img_size[0] and W == self.img_size[1] + ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C if self.norm is not None: x = self.norm(x) @@ -476,14 +615,20 @@ def forward(self, x): def flops(self): Ho, Wo = self.patches_resolution - flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + flops = ( + Ho + * Wo + * self.embed_dim + * self.in_chans + * (self.patch_size[0] * self.patch_size[1]) + ) if self.norm is not None: flops += Ho * Wo * self.embed_dim return flops class SwinTransformer(nn.Module): - r""" Swin Transformer + r"""Swin Transformer A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - https://arxiv.org/pdf/2103.14030 @@ -509,12 +654,29 @@ class SwinTransformer(nn.Module): fused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False """ - def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, - embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], - window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, - norm_layer=nn.LayerNorm, ape=False, patch_norm=True, - use_checkpoint=False, fused_window_process=False, **kwargs): + def __init__( + self, + img_size=224, + patch_size=4, + in_chans=3, + num_classes=1000, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + use_checkpoint=False, + fused_window_process=False, + **kwargs, + ): super().__init__() self.num_classes = num_classes @@ -527,50 +689,68 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, # split image into non-overlapping patches self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None, + ) num_patches = self.patch_embed.num_patches patches_resolution = self.patch_embed.patches_resolution self.patches_resolution = patches_resolution # absolute position embedding if self.ape: - self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) - trunc_normal_(self.absolute_pos_embed, std=.02) + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, num_patches, embed_dim) + ) + trunc_normal_(self.absolute_pos_embed, std=0.02) self.pos_drop = nn.Dropout(p=drop_rate) # stochastic depth - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule # build layers self.layers = nn.ModuleList() for i_layer in range(self.num_layers): - layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), - input_resolution=(patches_resolution[0] // (2 ** i_layer), - patches_resolution[1] // (2 ** i_layer)), - depth=depths[i_layer], - num_heads=num_heads[i_layer], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, - 
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], - norm_layer=norm_layer, - downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, - use_checkpoint=use_checkpoint, - fused_window_process=fused_window_process) + layer = BasicLayer( + dim=int(embed_dim * 2**i_layer), + input_resolution=( + patches_resolution[0] // (2**i_layer), + patches_resolution[1] // (2**i_layer), + ), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + fused_window_process=fused_window_process, + ) self.layers.append(layer) self.norm = norm_layer(self.num_features) self.avgpool = nn.AdaptiveAvgPool1d(1) - self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.head = ( + nn.Linear(self.num_features, num_classes) + if num_classes > 0 + else nn.Identity() + ) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) + trunc_normal_(m.weight, std=0.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): @@ -579,11 +759,11 @@ def _init_weights(self, m): @torch.jit.ignore def no_weight_decay(self): - return {'absolute_pos_embed'} + return {"absolute_pos_embed"} @torch.jit.ignore def no_weight_decay_keywords(self): - return {'relative_position_bias_table'} + return {"relative_position_bias_table"} def forward_features(self, x): x = self.patch_embed(x) @@ -609,6 +789,11 @@ def flops(self): flops += self.patch_embed.flops() for i, layer in enumerate(self.layers): flops += layer.flops() - flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) + flops += ( + self.num_features + * self.patches_resolution[0] + * self.patches_resolution[1] + // (2**self.num_layers) + ) flops += self.num_features * self.num_classes return flops diff --git a/models/swin_transformer_moe.py b/models/swin_transformer_moe.py index e9f26d43..a714a029 100644 --- a/models/swin_transformer_moe.py +++ b/models/swin_transformer_moe.py @@ -5,24 +5,33 @@ # Written by Ze Liu # -------------------------------------------------------- +import numpy as np import torch +import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F -import torch.distributed as dist import torch.utils.checkpoint as checkpoint from timm.models.layers import DropPath, to_2tuple, trunc_normal_ -import numpy as np try: from tutel import moe as tutel_moe except: tutel_moe = None - print("Tutel has not been installed. To use Swin-MoE, please install Tutel; otherwise, just ignore this.") + print( + "Tutel has not been installed. To use Swin-MoE, please install Tutel; otherwise, just ignore this." 
+ ) class Mlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., - mlp_fc2_bias=True): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + mlp_fc2_bias=True, + ): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features @@ -41,10 +50,24 @@ def forward(self, x): class MoEMlp(nn.Module): - def __init__(self, in_features, hidden_features, num_local_experts, top_value, capacity_factor=1.25, - cosine_router=False, normalize_gate=False, use_bpr=True, is_gshard_loss=True, - gate_noise=1.0, cosine_router_dim=256, cosine_router_init_t=0.5, moe_drop=0.0, init_std=0.02, - mlp_fc2_bias=True): + def __init__( + self, + in_features, + hidden_features, + num_local_experts, + top_value, + capacity_factor=1.25, + cosine_router=False, + normalize_gate=False, + use_bpr=True, + is_gshard_loss=True, + gate_noise=1.0, + cosine_router_dim=256, + cosine_router_init_t=0.5, + moe_drop=0.0, + init_std=0.02, + mlp_fc2_bias=True, + ): super().__init__() self.in_features = in_features @@ -62,23 +85,30 @@ def __init__(self, in_features, hidden_features, num_local_experts, top_value, c self._dropout = nn.Dropout(p=moe_drop) - _gate_type = {'type': 'cosine_top' if cosine_router else 'top', - 'k': top_value, 'capacity_factor': capacity_factor, - 'gate_noise': gate_noise, 'fp32_gate': True} + _gate_type = { + "type": "cosine_top" if cosine_router else "top", + "k": top_value, + "capacity_factor": capacity_factor, + "gate_noise": gate_noise, + "fp32_gate": True, + } if cosine_router: - _gate_type['proj_dim'] = cosine_router_dim - _gate_type['init_t'] = cosine_router_init_t + _gate_type["proj_dim"] = cosine_router_dim + _gate_type["init_t"] = cosine_router_init_t self._moe_layer = tutel_moe.moe_layer( gate_type=_gate_type, model_dim=in_features, - experts={'type': 'ffn', 'count_per_node': num_local_experts, 'hidden_size_per_expert': hidden_features, - 'activation_fn': lambda x: self._dropout(F.gelu(x))}, - scan_expert_func=lambda name, param: setattr(param, 'skip_allreduce', True), + experts={ + "type": "ffn", + "count_per_node": num_local_experts, + "hidden_size_per_expert": hidden_features, + "activation_fn": lambda x: self._dropout(F.gelu(x)), + }, + scan_expert_func=lambda name, param: setattr(param, "skip_allreduce", True), seeds=(1, self.dist_rank + 1, self.dist_rank + 1), batch_prioritized_routing=use_bpr, normalize_gate=normalize_gate, is_gshard_loss=is_gshard_loss, - ) if not self.mlp_fc2_bias: self._moe_layer.experts.batched_fc2_bias.requires_grad = False @@ -88,10 +118,12 @@ def forward(self, x): return x, x.l_aux def extra_repr(self) -> str: - return f'[Statistics-{self.dist_rank}] param count for MoE, ' \ - f'in_features = {self.in_features}, hidden_features = {self.hidden_features}, ' \ - f'num_local_experts = {self.num_local_experts}, top_value = {self.top_value}, ' \ - f'cosine_router={self.cosine_router} normalize_gate={self.normalize_gate}, use_bpr = {self.use_bpr}' + return ( + f"[Statistics-{self.dist_rank}] param count for MoE, " + f"in_features = {self.in_features}, hidden_features = {self.hidden_features}, " + f"num_local_experts = {self.num_local_experts}, top_value = {self.top_value}, " + f"cosine_router={self.cosine_router} normalize_gate={self.normalize_gate}, use_bpr = {self.use_bpr}" + ) def _init_weights(self): if hasattr(self._moe_layer, "experts"): @@ -112,7 +144,9 @@ def window_partition(x, window_size): """ B, 
H, W, C = x.shape x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + windows = ( + x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + ) return windows @@ -128,13 +162,15 @@ def window_reverse(windows, window_size, H, W): x: (B, H, W, C) """ B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = windows.view( + B, H // window_size, W // window_size, window_size, window_size, -1 + ) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) return x class WindowAttention(nn.Module): - r""" Window based multi-head self attention (W-MSA) module with relative position bias. + r"""Window based multi-head self attention (W-MSA) module with relative position bias. It supports both of shifted and non-shifted window. Args: @@ -148,8 +184,17 @@ class WindowAttention(nn.Module): pretrained_window_size (tuple[int]): The height and width of the window in pre-training. """ - def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., - pretrained_window_size=[0, 0]): + def __init__( + self, + dim, + window_size, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + pretrained_window_size=[0, 0], + ): super().__init__() self.dim = dim @@ -158,28 +203,40 @@ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, at self.num_heads = num_heads head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 + self.scale = qk_scale or head_dim**-0.5 # mlp to generate continuous relative position bias - self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True), - nn.ReLU(inplace=True), - nn.Linear(512, num_heads, bias=False)) + self.cpb_mlp = nn.Sequential( + nn.Linear(2, 512, bias=True), + nn.ReLU(inplace=True), + nn.Linear(512, num_heads, bias=False), + ) # get relative_coords_table - relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32) - relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32) - relative_coords_table = torch.stack( - torch.meshgrid([relative_coords_h, - relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 + relative_coords_h = torch.arange( + -(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32 + ) + relative_coords_w = torch.arange( + -(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32 + ) + relative_coords_table = ( + torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w])) + .permute(1, 2, 0) + .contiguous() + .unsqueeze(0) + ) # 1, 2*Wh-1, 2*Ww-1, 2 if pretrained_window_size[0] > 0: - relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1) - relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1) + relative_coords_table[:, :, :, 0] /= pretrained_window_size[0] - 1 + relative_coords_table[:, :, :, 1] /= pretrained_window_size[1] - 1 else: - relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1) - relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1) + relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1 + relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1 relative_coords_table *= 8 # normalize to -8, 8 - relative_coords_table = torch.sign(relative_coords_table) * torch.log2( - 
torch.abs(relative_coords_table) + 1.0) / np.log2(8) + relative_coords_table = ( + torch.sign(relative_coords_table) + * torch.log2(torch.abs(relative_coords_table) + 1.0) + / np.log2(8) + ) self.register_buffer("relative_coords_table", relative_coords_table) @@ -188,8 +245,12 @@ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, at coords_w = torch.arange(self.window_size[1]) coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords = ( + coords_flatten[:, :, None] - coords_flatten[:, None, :] + ) # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0 + ).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 relative_coords[:, :, 1] += self.window_size[1] - 1 relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 @@ -209,21 +270,40 @@ def forward(self, x, mask=None): mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None """ B_, N, C = x.shape - qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + qkv = ( + self.qkv(x) + .reshape(B_, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = ( + qkv[0], + qkv[1], + qkv[2], + ) # make torchscript happy (cannot use tensor as tuple) q = q * self.scale - attn = (q @ k.transpose(-2, -1)) + attn = q @ k.transpose(-2, -1) - relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads) - relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view( + -1, self.num_heads + ) + relative_position_bias = relative_position_bias_table[ + self.relative_position_index.view(-1) + ].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1, + ) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1 + ).contiguous() # nH, Wh*Ww, Wh*Ww attn = attn + relative_position_bias.unsqueeze(0) if mask is not None: nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze( + 1 + ).unsqueeze(0) attn = attn.view(-1, self.num_heads, N, N) attn = self.softmax(attn) else: @@ -237,8 +317,10 @@ def forward(self, x, mask=None): return x def extra_repr(self) -> str: - return f'dim={self.dim}, window_size={self.window_size}, ' \ - f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}' + return ( + f"dim={self.dim}, window_size={self.window_size}, " + f"pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}" + ) def flops(self, N): # calculate flops for 1 window with token length of N @@ -255,7 +337,7 @@ def flops(self, N): class SwinTransformerBlock(nn.Module): - r""" Swin Transformer Block. + r"""Swin Transformer Block. 
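# ---------------------------------------------------------------------------
# Illustrative sketch, not part of this patch: the MoE WindowAttention above
# feeds log-spaced relative coordinates into a small MLP (cpb_mlp) instead of
# indexing a fixed bias table, which lets the bias generalize across window
# sizes. The coordinate transform, reproduced for one axis of an 8x8 window
# (torch + numpy only):
import numpy as np
import torch

ws = 8
coords = torch.arange(-(ws - 1), ws, dtype=torch.float32)   # -7 .. 7
coords = coords / (ws - 1) * 8                              # normalize, then scale to [-8, 8]
coords = torch.sign(coords) * torch.log2(torch.abs(coords) + 1.0) / np.log2(8)
print(coords.min().item(), coords.max().item())             # approx. -1.057 .. 1.057
# ---------------------------------------------------------------------------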
Args: dim (int): Number of input channels. @@ -289,12 +371,37 @@ class SwinTransformerBlock(nn.Module): moe_drop (float): Dropout rate in MoE. Default: 0.0 """ - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm, mlp_fc2_bias=True, init_std=0.02, pretrained_window_size=0, - is_moe=False, num_local_experts=1, top_value=1, capacity_factor=1.25, cosine_router=False, - normalize_gate=False, use_bpr=True, is_gshard_loss=True, gate_noise=1.0, - cosine_router_dim=256, cosine_router_init_t=0.5, moe_drop=0.0): + def __init__( + self, + dim, + input_resolution, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + mlp_fc2_bias=True, + init_std=0.02, + pretrained_window_size=0, + is_moe=False, + num_local_experts=1, + top_value=1, + capacity_factor=1.25, + cosine_router=False, + normalize_gate=False, + use_bpr=True, + is_gshard_loss=True, + gate_noise=1.0, + cosine_router_dim=256, + cosine_router_init_t=0.5, + moe_drop=0.0, + ): super().__init__() self.dim = dim self.input_resolution = input_resolution @@ -310,57 +417,80 @@ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0 # if window size is larger than input resolution, we don't partition windows self.shift_size = 0 self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + assert ( + 0 <= self.shift_size < self.window_size + ), "shift_size must in 0-window_size" self.norm1 = norm_layer(dim) self.attn = WindowAttention( - dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, - pretrained_window_size=to_2tuple(pretrained_window_size)) + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + pretrained_window_size=to_2tuple(pretrained_window_size), + ) - self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) if self.is_moe: - self.mlp = MoEMlp(in_features=dim, - hidden_features=mlp_hidden_dim, - num_local_experts=num_local_experts, - top_value=top_value, - capacity_factor=capacity_factor, - cosine_router=cosine_router, - normalize_gate=normalize_gate, - use_bpr=use_bpr, - is_gshard_loss=is_gshard_loss, - gate_noise=gate_noise, - cosine_router_dim=cosine_router_dim, - cosine_router_init_t=cosine_router_init_t, - moe_drop=moe_drop, - mlp_fc2_bias=mlp_fc2_bias, - init_std=init_std) + self.mlp = MoEMlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + num_local_experts=num_local_experts, + top_value=top_value, + capacity_factor=capacity_factor, + cosine_router=cosine_router, + normalize_gate=normalize_gate, + use_bpr=use_bpr, + is_gshard_loss=is_gshard_loss, + gate_noise=gate_noise, + cosine_router_dim=cosine_router_dim, + cosine_router_init_t=cosine_router_init_t, + moe_drop=moe_drop, + mlp_fc2_bias=mlp_fc2_bias, + init_std=init_std, + ) else: - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, - mlp_fc2_bias=mlp_fc2_bias) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + mlp_fc2_bias=mlp_fc2_bias, + ) if self.shift_size > 0: # calculate attention mask for SW-MSA H, W = self.input_resolution img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) + h_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + w_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) cnt = 0 for h in h_slices: for w in w_slices: img_mask[:, h, w, :] = cnt cnt += 1 - mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = window_partition( + img_mask, self.window_size + ) # nW, window_size, window_size, 1 mask_windows = mask_windows.view(-1, self.window_size * self.window_size) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + attn_mask = attn_mask.masked_fill( + attn_mask != 0, float(-100.0) + ).masked_fill(attn_mask == 0, float(0.0)) else: attn_mask = None @@ -377,16 +507,24 @@ def forward(self, x): # cyclic shift if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + shifted_x = torch.roll( + x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2) + ) else: shifted_x = x # partition windows - x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C - x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + x_windows = window_partition( + shifted_x, self.window_size + ) # nW*B, window_size, window_size, C + x_windows = x_windows.view( + -1, self.window_size * self.window_size, C + ) # nW*B, window_size*window_size, C # W-MSA/SW-MSA - attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + attn_windows = self.attn( + 
x_windows, mask=self.attn_mask + ) # nW*B, window_size*window_size, C # merge windows attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) @@ -394,7 +532,9 @@ def forward(self, x): # reverse cyclic shift if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + x = torch.roll( + shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2) + ) else: x = shifted_x x = x.view(B, H * W, C) @@ -412,8 +552,10 @@ def forward(self, x): return x def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ - f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + return ( + f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + ) def flops(self): flops = 0 @@ -425,7 +567,16 @@ def flops(self): flops += nW * self.attn.flops(self.window_size * self.window_size) # mlp if self.is_moe: - flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio * self.capacity_factor * self.top_value + flops += ( + 2 + * H + * W + * self.dim + * self.dim + * self.mlp_ratio + * self.capacity_factor + * self.top_value + ) else: flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio # norm2 @@ -434,7 +585,7 @@ def flops(self): class PatchMerging(nn.Module): - r""" Patch Merging Layer. + r"""Patch Merging Layer. Args: input_resolution (tuple[int]): Resolution of input feature. @@ -483,7 +634,7 @@ def flops(self): class BasicLayer(nn.Module): - """ A basic Swin Transformer layer for one stage. + """A basic Swin Transformer layer for one stage. Args: dim (int): Number of input channels. @@ -518,13 +669,38 @@ class BasicLayer(nn.Module): moe_drop (float): Dropout rate in MoE. 
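[Editor's note: a quick numeric sanity check of the MoE branch of `flops()` earlier in this hunk, which scales the dense MLP cost (2·H·W·dim²·mlp_ratio) by `capacity_factor * top_value`. All numbers below are illustrative only, not taken from any shipped config.]

```python
# Illustrative-only check of the MoE FLOPs term used in SwinTransformerBlock.flops() above.
H = W = 56                                          # hypothetical 56x56 feature map
dim, mlp_ratio = 96, 4.0
capacity_factor, top_value = 1.25, 1                # example routing settings
dense_mlp = 2 * H * W * dim * dim * mlp_ratio       # fc1 + fc2 of a plain MLP
moe_mlp = dense_mlp * capacity_factor * top_value   # cost grows with routed capacity per token
print(f"dense MLP: {dense_mlp:.3e} FLOPs, MoE estimate: {moe_mlp:.3e} FLOPs")
```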
Default: 0.0 """ - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, - mlp_fc2_bias=True, init_std=0.02, use_checkpoint=False, pretrained_window_size=0, - moe_block=[-1], num_local_experts=1, top_value=1, capacity_factor=1.25, cosine_router=False, - normalize_gate=False, use_bpr=True, is_gshard_loss=True, - cosine_router_dim=256, cosine_router_init_t=0.5, gate_noise=1.0, moe_drop=0.0): + def __init__( + self, + dim, + input_resolution, + depth, + num_heads, + window_size, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + norm_layer=nn.LayerNorm, + downsample=None, + mlp_fc2_bias=True, + init_std=0.02, + use_checkpoint=False, + pretrained_window_size=0, + moe_block=[-1], + num_local_experts=1, + top_value=1, + capacity_factor=1.25, + cosine_router=False, + normalize_gate=False, + use_bpr=True, + is_gshard_loss=True, + cosine_router_dim=256, + cosine_router_init_t=0.5, + gate_noise=1.0, + moe_drop=0.0, + ): super().__init__() self.dim = dim @@ -533,36 +709,48 @@ def __init__(self, dim, input_resolution, depth, num_heads, window_size, self.use_checkpoint = use_checkpoint # build blocks - self.blocks = nn.ModuleList([ - SwinTransformerBlock(dim=dim, input_resolution=input_resolution, - num_heads=num_heads, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer, - mlp_fc2_bias=mlp_fc2_bias, - init_std=init_std, - pretrained_window_size=pretrained_window_size, - - is_moe=True if i in moe_block else False, - num_local_experts=num_local_experts, - top_value=top_value, - capacity_factor=capacity_factor, - cosine_router=cosine_router, - normalize_gate=normalize_gate, - use_bpr=use_bpr, - is_gshard_loss=is_gshard_loss, - gate_noise=gate_noise, - cosine_router_dim=cosine_router_dim, - cosine_router_init_t=cosine_router_init_t, - moe_drop=moe_drop) - for i in range(depth)]) + self.blocks = nn.ModuleList( + [ + SwinTransformerBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] + if isinstance(drop_path, list) + else drop_path, + norm_layer=norm_layer, + mlp_fc2_bias=mlp_fc2_bias, + init_std=init_std, + pretrained_window_size=pretrained_window_size, + is_moe=True if i in moe_block else False, + num_local_experts=num_local_experts, + top_value=top_value, + capacity_factor=capacity_factor, + cosine_router=cosine_router, + normalize_gate=normalize_gate, + use_bpr=use_bpr, + is_gshard_loss=is_gshard_loss, + gate_noise=gate_noise, + cosine_router_dim=cosine_router_dim, + cosine_router_init_t=cosine_router_init_t, + moe_drop=moe_drop, + ) + for i in range(depth) + ] + ) # patch merging layer if downsample is not None: - self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + self.downsample = downsample( + input_resolution, dim=dim, norm_layer=norm_layer + ) else: self.downsample = None @@ -597,7 +785,7 @@ def flops(self): class PatchEmbed(nn.Module): - r""" Image to Patch Embedding + r"""Image to Patch Embedding Args: img_size (int): Image size. 
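[Editor's note: the `shift_size=0 if (i % 2 == 0) else window_size // 2` expression used when `BasicLayer` builds its blocks above alternates plain W-MSA and shifted SW-MSA blocks; a minimal sketch with example values.]

```python
# Even-indexed blocks use shift 0 (W-MSA); odd-indexed blocks shift by window_size // 2 (SW-MSA).
window_size, depth = 7, 6                            # example values only
shifts = [0 if i % 2 == 0 else window_size // 2 for i in range(depth)]
assert shifts == [0, 3, 0, 3, 0, 3]
```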
Default: 224. @@ -607,11 +795,16 @@ class PatchEmbed(nn.Module): norm_layer (nn.Module, optional): Normalization layer. Default: None """ - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + def __init__( + self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None + ): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + patches_resolution = [ + img_size[0] // patch_size[0], + img_size[1] // patch_size[1], + ] self.img_size = img_size self.patch_size = patch_size self.patches_resolution = patches_resolution @@ -620,7 +813,9 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_la self.in_chans = in_chans self.embed_dim = embed_dim - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size + ) if norm_layer is not None: self.norm = norm_layer(embed_dim) else: @@ -629,8 +824,9 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_la def forward(self, x): B, C, H, W = x.shape # FIXME look at relaxing size constraints - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + assert ( + H == self.img_size[0] and W == self.img_size[1] + ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C if self.norm is not None: x = self.norm(x) @@ -638,14 +834,20 @@ def forward(self, x): def flops(self): Ho, Wo = self.patches_resolution - flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + flops = ( + Ho + * Wo + * self.embed_dim + * self.in_chans + * (self.patch_size[0] * self.patch_size[1]) + ) if self.norm is not None: flops += Ho * Wo * self.embed_dim return flops class SwinTransformerMoE(nn.Module): - r""" Swin Transformer + r"""Swin Transformer A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - https://arxiv.org/pdf/2103.14030 @@ -687,15 +889,44 @@ class SwinTransformerMoE(nn.Module): aux_loss_weight (float): auxiliary loss weight. 
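[Editor's note: a minimal shape check of the `PatchEmbed` module above; the strided `Conv2d` with `kernel_size == stride == patch_size` turns the image into Ph·Pw tokens of width `embed_dim`. The default sizes (224 image, patch 4, embed_dim 96) are assumed here.]

```python
import torch
import torch.nn as nn

# Same projection as PatchEmbed.proj above, with the default sizes assumed.
proj = nn.Conv2d(3, 96, kernel_size=4, stride=4)
x = torch.randn(2, 3, 224, 224)                      # B, C, H, W
tokens = proj(x).flatten(2).transpose(1, 2)          # B, Ph*Pw, C
assert tokens.shape == (2, 56 * 56, 96)              # 224 // 4 = 56 patches per side
```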
Default: 0.1 """ - def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, - embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], - window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, - norm_layer=nn.LayerNorm, ape=False, patch_norm=True, - mlp_fc2_bias=True, init_std=0.02, use_checkpoint=False, pretrained_window_sizes=[0, 0, 0, 0], - moe_blocks=[[-1], [-1], [-1], [-1]], num_local_experts=1, top_value=1, capacity_factor=1.25, - cosine_router=False, normalize_gate=False, use_bpr=True, is_gshard_loss=True, gate_noise=1.0, - cosine_router_dim=256, cosine_router_init_t=0.5, moe_drop=0.0, aux_loss_weight=0.01, **kwargs): + def __init__( + self, + img_size=224, + patch_size=4, + in_chans=3, + num_classes=1000, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + mlp_fc2_bias=True, + init_std=0.02, + use_checkpoint=False, + pretrained_window_sizes=[0, 0, 0, 0], + moe_blocks=[[-1], [-1], [-1], [-1]], + num_local_experts=1, + top_value=1, + capacity_factor=1.25, + cosine_router=False, + normalize_gate=False, + use_bpr=True, + is_gshard_loss=True, + gate_noise=1.0, + cosine_router_dim=256, + cosine_router_init_t=0.5, + moe_drop=0.0, + aux_loss_weight=0.01, + **kwargs, + ): super().__init__() self._ddp_params_and_buffers_to_ignore = list() @@ -709,65 +940,87 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, self.init_std = init_std self.aux_loss_weight = aux_loss_weight self.num_local_experts = num_local_experts - self.global_experts = num_local_experts * dist.get_world_size() if num_local_experts > 0 \ + self.global_experts = ( + num_local_experts * dist.get_world_size() + if num_local_experts > 0 else dist.get_world_size() // (-num_local_experts) - self.sharded_count = (1.0 / num_local_experts) if num_local_experts > 0 else (-num_local_experts) + ) + self.sharded_count = ( + (1.0 / num_local_experts) if num_local_experts > 0 else (-num_local_experts) + ) # split image into non-overlapping patches self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None, + ) num_patches = self.patch_embed.num_patches patches_resolution = self.patch_embed.patches_resolution self.patches_resolution = patches_resolution # absolute position embedding if self.ape: - self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, num_patches, embed_dim) + ) trunc_normal_(self.absolute_pos_embed, std=self.init_std) self.pos_drop = nn.Dropout(p=drop_rate) # stochastic depth - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule # build layers self.layers = nn.ModuleList() for i_layer in range(self.num_layers): - layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), - input_resolution=(patches_resolution[0] // (2 ** i_layer), - patches_resolution[1] // (2 ** i_layer)), - depth=depths[i_layer], - num_heads=num_heads[i_layer], 
- window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], - norm_layer=norm_layer, - downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, - mlp_fc2_bias=mlp_fc2_bias, - init_std=init_std, - use_checkpoint=use_checkpoint, - pretrained_window_size=pretrained_window_sizes[i_layer], - - moe_block=moe_blocks[i_layer], - num_local_experts=num_local_experts, - top_value=top_value, - capacity_factor=capacity_factor, - cosine_router=cosine_router, - normalize_gate=normalize_gate, - use_bpr=use_bpr, - is_gshard_loss=is_gshard_loss, - gate_noise=gate_noise, - cosine_router_dim=cosine_router_dim, - cosine_router_init_t=cosine_router_init_t, - moe_drop=moe_drop) + layer = BasicLayer( + dim=int(embed_dim * 2**i_layer), + input_resolution=( + patches_resolution[0] // (2**i_layer), + patches_resolution[1] // (2**i_layer), + ), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + mlp_fc2_bias=mlp_fc2_bias, + init_std=init_std, + use_checkpoint=use_checkpoint, + pretrained_window_size=pretrained_window_sizes[i_layer], + moe_block=moe_blocks[i_layer], + num_local_experts=num_local_experts, + top_value=top_value, + capacity_factor=capacity_factor, + cosine_router=cosine_router, + normalize_gate=normalize_gate, + use_bpr=use_bpr, + is_gshard_loss=is_gshard_loss, + gate_noise=gate_noise, + cosine_router_dim=cosine_router_dim, + cosine_router_init_t=cosine_router_init_t, + moe_drop=moe_drop, + ) self.layers.append(layer) self.norm = norm_layer(self.num_features) self.avgpool = nn.AdaptiveAvgPool1d(1) - self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.head = ( + nn.Linear(self.num_features, num_classes) + if num_classes > 0 + else nn.Identity() + ) self.apply(self._init_weights) @@ -784,12 +1037,19 @@ def _init_weights(self, m): @torch.jit.ignore def no_weight_decay(self): - return {'absolute_pos_embed'} + return {"absolute_pos_embed"} @torch.jit.ignore def no_weight_decay_keywords(self): - return {"cpb_mlp", 'relative_position_bias_table', 'fc1_bias', 'fc2_bias', - 'temperature', 'cosine_projector', 'sim_matrix'} + return { + "cpb_mlp", + "relative_position_bias_table", + "fc1_bias", + "fc2_bias", + "temperature", + "cosine_projector", + "sim_matrix", + } def forward_features(self, x): x = self.patch_embed(x) @@ -819,6 +1079,11 @@ def flops(self): flops += self.patch_embed.flops() for i, layer in enumerate(self.layers): flops += layer.flops() - flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) + flops += ( + self.num_features + * self.patches_resolution[0] + * self.patches_resolution[1] + // (2**self.num_layers) + ) flops += self.num_features * self.num_classes return flops diff --git a/models/swin_transformer_v2.py b/models/swin_transformer_v2.py index a429d0a2..99a7a2cb 100644 --- a/models/swin_transformer_v2.py +++ b/models/swin_transformer_v2.py @@ -5,16 +5,41 @@ # Written by Ze Liu # -------------------------------------------------------- +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F 
import torch.utils.checkpoint as checkpoint from timm.models.layers import DropPath, to_2tuple, trunc_normal_ -import numpy as np + + +class HierarchicalHead(nn.Module): + def __init__(self, num_features, num_classes): + super().__init__() + self.num_classes = tuple(num_classes) + for num_class in self.num_classes: + assert num_class > 0 + + self.heads = nn.ModuleList( + [nn.Linear(num_features, num_class) for num_class in self.num_classes] + ) + + def forward(self, x): + # we do not want to use self.heads(x) because that would feed them through + # each element in the list sequentially, whereas we want x through each head + # individually. + return [head(x) for head in self.heads] class Mlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features @@ -43,7 +68,9 @@ def window_partition(x, window_size): """ B, H, W, C = x.shape x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + windows = ( + x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + ) return windows @@ -59,13 +86,15 @@ def window_reverse(windows, window_size, H, W): x: (B, H, W, C) """ B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = windows.view( + B, H // window_size, W // window_size, window_size, window_size, -1 + ) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) return x class WindowAttention(nn.Module): - r""" Window based multi-head self attention (W-MSA) module with relative position bias. + r"""Window based multi-head self attention (W-MSA) module with relative position bias. It supports both of shifted and non-shifted window. Args: @@ -78,8 +107,16 @@ class WindowAttention(nn.Module): pretrained_window_size (tuple[int]): The height and width of the window in pre-training. 
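[Editor's note: the `HierarchicalHead` added near the top of this file returns one logit tensor per label level rather than a single tensor. A hedged usage sketch; the two level sizes (13, 121) are made up for illustration, and the class is assumed importable from models/swin_transformer_v2.py as added in this diff.]

```python
import torch

# Two hypothetical label levels, e.g. 13 coarse classes and 121 fine classes.
head = HierarchicalHead(num_features=768, num_classes=(13, 121))
features = torch.randn(4, 768)                        # pooled features, B x num_features
coarse_logits, fine_logits = head(features)           # one logit tensor per level
assert coarse_logits.shape == (4, 13) and fine_logits.shape == (4, 121)
```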
""" - def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0., - pretrained_window_size=[0, 0]): + def __init__( + self, + dim, + window_size, + num_heads, + qkv_bias=True, + attn_drop=0.0, + proj_drop=0.0, + pretrained_window_size=[0, 0], + ): super().__init__() self.dim = dim @@ -87,28 +124,43 @@ def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., pro self.pretrained_window_size = pretrained_window_size self.num_heads = num_heads - self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True) + self.logit_scale = nn.Parameter( + torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True + ) + self.register_buffer("logit_clamp_max", torch.log(torch.tensor(1.0 / 0.01))) # mlp to generate continuous relative position bias - self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True), - nn.ReLU(inplace=True), - nn.Linear(512, num_heads, bias=False)) + self.cpb_mlp = nn.Sequential( + nn.Linear(2, 512, bias=True), + nn.ReLU(inplace=True), + nn.Linear(512, num_heads, bias=False), + ) # get relative_coords_table - relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32) - relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32) - relative_coords_table = torch.stack( - torch.meshgrid([relative_coords_h, - relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 + relative_coords_h = torch.arange( + -(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32 + ) + relative_coords_w = torch.arange( + -(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32 + ) + relative_coords_table = ( + torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w])) + .permute(1, 2, 0) + .contiguous() + .unsqueeze(0) + ) # 1, 2*Wh-1, 2*Ww-1, 2 if pretrained_window_size[0] > 0: - relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1) - relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1) + relative_coords_table[:, :, :, 0] /= pretrained_window_size[0] - 1 + relative_coords_table[:, :, :, 1] /= pretrained_window_size[1] - 1 else: - relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1) - relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1) + relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1 + relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1 relative_coords_table *= 8 # normalize to -8, 8 - relative_coords_table = torch.sign(relative_coords_table) * torch.log2( - torch.abs(relative_coords_table) + 1.0) / np.log2(8) + relative_coords_table = ( + torch.sign(relative_coords_table) + * torch.log2(torch.abs(relative_coords_table) + 1.0) + / np.log2(8) + ) self.register_buffer("relative_coords_table", relative_coords_table) @@ -117,8 +169,12 @@ def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., pro coords_w = torch.arange(self.window_size[1]) coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords = ( + coords_flatten[:, :, None] - coords_flatten[:, None, :] + ) # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0 + ).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += self.window_size[0] - 1 # 
shift to start from 0 relative_coords[:, :, 1] += self.window_size[1] - 1 relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 @@ -146,26 +202,47 @@ def forward(self, x, mask=None): B_, N, C = x.shape qkv_bias = None if self.q_bias is not None: - qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) + qkv_bias = torch.cat( + ( + self.q_bias, + torch.zeros_like(self.v_bias, requires_grad=False), + self.v_bias, + ) + ) qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + q, k, v = ( + qkv[0], + qkv[1], + qkv[2], + ) # make torchscript happy (cannot use tensor as tuple) # cosine attention - attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) - logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01))).exp() + attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1) + logit_scale = torch.clamp(self.logit_scale, max=self.logit_clamp_max).exp() attn = attn * logit_scale - relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads) - relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view( + -1, self.num_heads + ) + relative_position_bias = relative_position_bias_table[ + self.relative_position_index.view(-1) + ].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1, + ) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1 + ).contiguous() # nH, Wh*Ww, Wh*Ww relative_position_bias = 16 * torch.sigmoid(relative_position_bias) attn = attn + relative_position_bias.unsqueeze(0) if mask is not None: nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze( + 1 + ).unsqueeze(0) attn = attn.view(-1, self.num_heads, N, N) attn = self.softmax(attn) else: @@ -179,8 +256,10 @@ def forward(self, x, mask=None): return x def extra_repr(self) -> str: - return f'dim={self.dim}, window_size={self.window_size}, ' \ - f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}' + return ( + f"dim={self.dim}, window_size={self.window_size}, " + f"pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}" + ) def flops(self, N): # calculate flops for 1 window with token length of N @@ -197,7 +276,7 @@ def flops(self, N): class SwinTransformerBlock(nn.Module): - r""" Swin Transformer Block. + r"""Swin Transformer Block. Args: dim (int): Number of input channels. @@ -215,9 +294,22 @@ class SwinTransformerBlock(nn.Module): pretrained_window_size (int): Window size in pre-training. 
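[Editor's note: a minimal, self-contained sketch of the cosine attention computed in `WindowAttention.forward` above; logits are cosine similarities between q and k, scaled by a learned per-head temperature whose exponentiated value is clamped at 1/0.01 = 100. Shapes are illustrative.]

```python
import torch
import torch.nn.functional as F

num_heads, N, head_dim = 3, 49, 32                    # e.g. a 7x7 window
q = torch.randn(1, num_heads, N, head_dim)
k = torch.randn(1, num_heads, N, head_dim)
logit_scale = torch.log(10 * torch.ones(num_heads, 1, 1))   # a learnable parameter in the model
attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
attn = attn * torch.clamp(logit_scale, max=torch.log(torch.tensor(100.0))).exp()
attn = attn.softmax(dim=-1)
assert attn.shape == (1, num_heads, N, N)
```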
""" - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0): + def __init__( + self, + dim, + input_resolution, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4.0, + qkv_bias=True, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + pretrained_window_size=0, + ): super().__init__() self.dim = dim self.input_resolution = input_resolution @@ -229,39 +321,59 @@ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0 # if window size is larger than input resolution, we don't partition windows self.shift_size = 0 self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + assert ( + 0 <= self.shift_size < self.window_size + ), "shift_size must in 0-window_size" self.norm1 = norm_layer(dim) self.attn = WindowAttention( - dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, - pretrained_window_size=to_2tuple(pretrained_window_size)) - - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop, + pretrained_window_size=to_2tuple(pretrained_window_size), + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) if self.shift_size > 0: # calculate attention mask for SW-MSA H, W = self.input_resolution img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) + h_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + w_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) cnt = 0 for h in h_slices: for w in w_slices: img_mask[:, h, w, :] = cnt cnt += 1 - mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = window_partition( + img_mask, self.window_size + ) # nW, window_size, window_size, 1 mask_windows = mask_windows.view(-1, self.window_size * self.window_size) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + attn_mask = attn_mask.masked_fill( + attn_mask != 0, float(-100.0) + ).masked_fill(attn_mask == 0, float(0.0)) else: attn_mask = None @@ -277,16 +389,24 @@ def forward(self, x): # cyclic shift if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + shifted_x = torch.roll( + x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2) + ) else: shifted_x = x # partition windows - x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, 
window_size, C - x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + x_windows = window_partition( + shifted_x, self.window_size + ) # nW*B, window_size, window_size, C + x_windows = x_windows.view( + -1, self.window_size * self.window_size, C + ) # nW*B, window_size*window_size, C # W-MSA/SW-MSA - attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + attn_windows = self.attn( + x_windows, mask=self.attn_mask + ) # nW*B, window_size*window_size, C # merge windows attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) @@ -294,7 +414,9 @@ def forward(self, x): # reverse cyclic shift if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + x = torch.roll( + shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2) + ) else: x = shifted_x x = x.view(B, H * W, C) @@ -306,8 +428,10 @@ def forward(self, x): return x def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ - f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + return ( + f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + ) def flops(self): flops = 0 @@ -325,7 +449,7 @@ def flops(self): class PatchMerging(nn.Module): - r""" Patch Merging Layer. + r"""Patch Merging Layer. Args: input_resolution (tuple[int]): Resolution of input feature. @@ -374,7 +498,7 @@ def flops(self): class BasicLayer(nn.Module): - """ A basic Swin Transformer layer for one stage. + """A basic Swin Transformer layer for one stage. Args: dim (int): Number of input channels. @@ -393,10 +517,23 @@ class BasicLayer(nn.Module): pretrained_window_size (int): Local window size in pre-training. 
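[Editor's note: a toy walk-through of the cyclic shift plus window partition performed in the block's forward pass above, for an 8x8 map with window_size=4 and shift_size=2; the values are just token indices.]

```python
import torch

H = W = 8
window_size, shift_size = 4, 2
x = torch.arange(H * W, dtype=torch.float32).view(1, H, W, 1)        # B, H, W, C
shifted = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))
windows = shifted.view(1, H // window_size, window_size, W // window_size, window_size, 1)
windows = windows.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size * window_size, 1)
assert windows.shape == (4, 16, 1)                                    # nW*B, window_size**2, C
# window_reverse followed by torch.roll with +shift_size undoes both steps.
```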
""" - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, - pretrained_window_size=0): + def __init__( + self, + dim, + input_resolution, + depth, + num_heads, + window_size, + mlp_ratio=4.0, + qkv_bias=True, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + pretrained_window_size=0, + ): super().__init__() self.dim = dim @@ -405,21 +542,33 @@ def __init__(self, dim, input_resolution, depth, num_heads, window_size, self.use_checkpoint = use_checkpoint # build blocks - self.blocks = nn.ModuleList([ - SwinTransformerBlock(dim=dim, input_resolution=input_resolution, - num_heads=num_heads, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer, - pretrained_window_size=pretrained_window_size) - for i in range(depth)]) + self.blocks = nn.ModuleList( + [ + SwinTransformerBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] + if isinstance(drop_path, list) + else drop_path, + norm_layer=norm_layer, + pretrained_window_size=pretrained_window_size, + ) + for i in range(depth) + ] + ) # patch merging layer if downsample is not None: - self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + self.downsample = downsample( + input_resolution, dim=dim, norm_layer=norm_layer + ) else: self.downsample = None @@ -453,7 +602,7 @@ def _init_respostnorm(self): class PatchEmbed(nn.Module): - r""" Image to Patch Embedding + r"""Image to Patch Embedding Args: img_size (int): Image size. Default: 224. @@ -463,11 +612,16 @@ class PatchEmbed(nn.Module): norm_layer (nn.Module, optional): Normalization layer. Default: None """ - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + def __init__( + self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None + ): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + patches_resolution = [ + img_size[0] // patch_size[0], + img_size[1] // patch_size[1], + ] self.img_size = img_size self.patch_size = patch_size self.patches_resolution = patches_resolution @@ -476,7 +630,9 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_la self.in_chans = in_chans self.embed_dim = embed_dim - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size + ) if norm_layer is not None: self.norm = norm_layer(embed_dim) else: @@ -485,8 +641,9 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_la def forward(self, x): B, C, H, W = x.shape # FIXME look at relaxing size constraints - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
+ assert ( + H == self.img_size[0] and W == self.img_size[1] + ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C if self.norm is not None: x = self.norm(x) @@ -494,14 +651,20 @@ def forward(self, x): def flops(self): Ho, Wo = self.patches_resolution - flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + flops = ( + Ho + * Wo + * self.embed_dim + * self.in_chans + * (self.patch_size[0] * self.patch_size[1]) + ) if self.norm is not None: flops += Ho * Wo * self.embed_dim return flops class SwinTransformerV2(nn.Module): - r""" Swin Transformer + r"""Swin Transformer A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - https://arxiv.org/pdf/2103.14030 @@ -526,12 +689,28 @@ class SwinTransformerV2(nn.Module): pretrained_window_sizes (tuple(int)): Pretrained window sizes of each layer. """ - def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, - embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], - window_size=7, mlp_ratio=4., qkv_bias=True, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, - norm_layer=nn.LayerNorm, ape=False, patch_norm=True, - use_checkpoint=False, pretrained_window_sizes=[0, 0, 0, 0], **kwargs): + def __init__( + self, + img_size=224, + patch_size=4, + in_chans=3, + num_classes=1000, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4.0, + qkv_bias=True, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + use_checkpoint=False, + pretrained_window_sizes=[0, 0, 0, 0], + **kwargs, + ): super().__init__() self.num_classes = num_classes @@ -544,44 +723,69 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, # split image into non-overlapping patches self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None, + ) num_patches = self.patch_embed.num_patches patches_resolution = self.patch_embed.patches_resolution self.patches_resolution = patches_resolution # absolute position embedding if self.ape: - self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) - trunc_normal_(self.absolute_pos_embed, std=.02) + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, num_patches, embed_dim) + ) + trunc_normal_(self.absolute_pos_embed, std=0.02) self.pos_drop = nn.Dropout(p=drop_rate) # stochastic depth - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule # build layers self.layers = nn.ModuleList() for i_layer in range(self.num_layers): - layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), - input_resolution=(patches_resolution[0] // (2 ** i_layer), - patches_resolution[1] // (2 ** i_layer)), - depth=depths[i_layer], - num_heads=num_heads[i_layer], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], - norm_layer=norm_layer, - downsample=PatchMerging if (i_layer < 
self.num_layers - 1) else None, - use_checkpoint=use_checkpoint, - pretrained_window_size=pretrained_window_sizes[i_layer]) + layer = BasicLayer( + dim=int(embed_dim * 2**i_layer), + input_resolution=( + patches_resolution[0] // (2**i_layer), + patches_resolution[1] // (2**i_layer), + ), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + pretrained_window_size=pretrained_window_sizes[i_layer], + ) self.layers.append(layer) self.norm = norm_layer(self.num_features) self.avgpool = nn.AdaptiveAvgPool1d(1) - self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + # Checks if we are doing hierarchical classification or not. + if isinstance(num_classes, int): + self.head = ( + nn.Linear(self.num_features, num_classes) + if num_classes > 0 + else nn.Identity() + ) + self.hierarchical = False + else: + self.num_classes = tuple(num_classes) + self.head = HierarchicalHead(self.num_features, num_classes) + self.hierarchical = True self.apply(self._init_weights) for bly in self.layers: @@ -589,7 +793,7 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, def _init_weights(self, m): if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) + trunc_normal_(m.weight, std=0.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): @@ -598,11 +802,11 @@ def _init_weights(self, m): @torch.jit.ignore def no_weight_decay(self): - return {'absolute_pos_embed'} + return {"absolute_pos_embed"} @torch.jit.ignore def no_weight_decay_keywords(self): - return {"cpb_mlp", "logit_scale", 'relative_position_bias_table'} + return {"cpb_mlp", "logit_scale", "relative_position_bias_table"} def forward_features(self, x): x = self.patch_embed(x) @@ -626,8 +830,21 @@ def forward(self, x): def flops(self): flops = 0 flops += self.patch_embed.flops() - for i, layer in enumerate(self.layers): + for layer in self.layers: flops += layer.flops() - flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) - flops += self.num_features * self.num_classes + flops += ( + self.num_features + * self.patches_resolution[0] + * self.patches_resolution[1] + // (2**self.num_layers) + ) + if isinstance(self.num_classes, int): + flops += self.num_features * self.num_classes + elif isinstance(self.num_classes, tuple): + for num_class in self.num_classes: + flops += self.num_features * num_class + else: + raise RuntimeError( + f"Internal error: self.num_classes should be int or tuple, not {type(self.num_classes)}" + ) return flops diff --git a/optimizer.py b/optimizer.py index 44317019..ddbe6f75 100644 --- a/optimizer.py +++ b/optimizer.py @@ -6,11 +6,12 @@ # -------------------------------------------------------- from functools import partial + from torch import optim as optim try: from apex.optimizers import FusedAdam, FusedLAMB -except: +except ImportError: FusedAdam = None FusedLAMB = None print("To use FusedLAMB or FusedAdam, please install apex.") @@ -22,36 +23,72 @@ def build_optimizer(config, model, simmim=False, is_pretrain=False): """ skip = {} skip_keywords = {} - if hasattr(model, 'no_weight_decay'): + if hasattr(model, 
"no_weight_decay"): skip = model.no_weight_decay() - if hasattr(model, 'no_weight_decay_keywords'): + if hasattr(model, "no_weight_decay_keywords"): skip_keywords = model.no_weight_decay_keywords() if simmim: if is_pretrain: parameters = get_pretrain_param_groups(model, skip, skip_keywords) else: - depths = config.MODEL.SWIN.DEPTHS if config.MODEL.TYPE == 'swin' else config.MODEL.SWINV2.DEPTHS + depths = ( + config.MODEL.SWIN.DEPTHS + if config.MODEL.TYPE == "swin" + else config.MODEL.SWINV2.DEPTHS + ) num_layers = sum(depths) - get_layer_func = partial(get_swin_layer, num_layers=num_layers + 2, depths=depths) - scales = list(config.TRAIN.LAYER_DECAY ** i for i in reversed(range(num_layers + 2))) - parameters = get_finetune_param_groups(model, config.TRAIN.BASE_LR, config.TRAIN.WEIGHT_DECAY, get_layer_func, scales, skip, skip_keywords) + get_layer_func = partial( + get_swin_layer, num_layers=num_layers + 2, depths=depths + ) + scales = list( + config.TRAIN.LAYER_DECAY**i for i in reversed(range(num_layers + 2)) + ) + parameters = get_finetune_param_groups( + model, + config.TRAIN.BASE_LR, + config.TRAIN.WEIGHT_DECAY, + get_layer_func, + scales, + skip, + skip_keywords, + ) else: parameters = set_weight_decay(model, skip, skip_keywords) opt_lower = config.TRAIN.OPTIMIZER.NAME.lower() optimizer = None - if opt_lower == 'sgd': - optimizer = optim.SGD(parameters, momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True, - lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) - elif opt_lower == 'adamw': - optimizer = optim.AdamW(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS, - lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) - elif opt_lower == 'fused_adam': - optimizer = FusedAdam(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS, - lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) - elif opt_lower == 'fused_lamb': - optimizer = FusedLAMB(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS, - lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) + if opt_lower == "sgd": + optimizer = optim.SGD( + parameters, + momentum=config.TRAIN.OPTIMIZER.MOMENTUM, + nesterov=True, + lr=config.TRAIN.BASE_LR, + weight_decay=config.TRAIN.WEIGHT_DECAY, + ) + elif opt_lower == "adamw": + optimizer = optim.AdamW( + parameters, + eps=config.TRAIN.OPTIMIZER.EPS, + betas=config.TRAIN.OPTIMIZER.BETAS, + lr=config.TRAIN.BASE_LR, + weight_decay=config.TRAIN.WEIGHT_DECAY, + ) + elif opt_lower == "fused_adam": + optimizer = FusedAdam( + parameters, + eps=config.TRAIN.OPTIMIZER.EPS, + betas=config.TRAIN.OPTIMIZER.BETAS, + lr=config.TRAIN.BASE_LR, + weight_decay=config.TRAIN.WEIGHT_DECAY, + ) + elif opt_lower == "fused_lamb": + optimizer = FusedLAMB( + parameters, + eps=config.TRAIN.OPTIMIZER.EPS, + betas=config.TRAIN.OPTIMIZER.BETAS, + lr=config.TRAIN.BASE_LR, + weight_decay=config.TRAIN.WEIGHT_DECAY, + ) return optimizer @@ -63,14 +100,17 @@ def set_weight_decay(model, skip_list=(), skip_keywords=()): for name, param in model.named_parameters(): if not param.requires_grad: continue # frozen weights - if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \ - check_keywords_in_name(name, skip_keywords): + if ( + len(param.shape) == 1 + or name.endswith(".bias") + or (name in skip_list) + or check_keywords_in_name(name, skip_keywords) + ): no_decay.append(param) # print(f"{name} has no weight decay") else: has_decay.append(param) - return [{'params': 
has_decay}, - {'params': no_decay, 'weight_decay': 0.}] + return [{"params": has_decay}, {"params": no_decay, "weight_decay": 0.0}] def check_keywords_in_name(name, keywords=()): @@ -86,19 +126,22 @@ def get_pretrain_param_groups(model, skip_list=(), skip_keywords=()): no_decay = [] has_decay_name = [] no_decay_name = [] - + for name, param in model.named_parameters(): if not param.requires_grad: continue - if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \ - check_keywords_in_name(name, skip_keywords): + if ( + len(param.shape) == 1 + or name.endswith(".bias") + or (name in skip_list) + or check_keywords_in_name(name, skip_keywords) + ): no_decay.append(param) no_decay_name.append(name) else: has_decay.append(param) has_decay_name.append(name) - return [{'params': has_decay}, - {'params': no_decay, 'weight_decay': 0.}] + return [{"params": has_decay}, {"params": no_decay, "weight_decay": 0.0}] def get_swin_layer(name, num_layers, depths): @@ -107,27 +150,33 @@ def get_swin_layer(name, num_layers, depths): elif name.startswith("patch_embed"): return 0 elif name.startswith("layers"): - layer_id = int(name.split('.')[1]) - block_id = name.split('.')[3] - if block_id == 'reduction' or block_id == 'norm': - return sum(depths[:layer_id + 1]) + layer_id = int(name.split(".")[1]) + block_id = name.split(".")[3] + if block_id == "reduction" or block_id == "norm": + return sum(depths[: layer_id + 1]) layer_id = sum(depths[:layer_id]) + int(block_id) return layer_id + 1 else: return num_layers - 1 -def get_finetune_param_groups(model, lr, weight_decay, get_layer_func, scales, skip_list=(), skip_keywords=()): +def get_finetune_param_groups( + model, lr, weight_decay, get_layer_func, scales, skip_list=(), skip_keywords=() +): parameter_group_names = {} parameter_group_vars = {} for name, param in model.named_parameters(): if not param.requires_grad: continue - if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \ - check_keywords_in_name(name, skip_keywords): + if ( + len(param.shape) == 1 + or name.endswith(".bias") + or (name in skip_list) + or check_keywords_in_name(name, skip_keywords) + ): group_name = "no_decay" - this_weight_decay = 0. + this_weight_decay = 0.0 else: group_name = "decay" this_weight_decay = weight_decay @@ -141,7 +190,7 @@ def get_finetune_param_groups(model, lr, weight_decay, get_layer_func, scales, s if scales is not None: scale = scales[layer_id] else: - scale = 1. + scale = 1.0 parameter_group_names[group_name] = { "group_name": group_name, @@ -155,7 +204,7 @@ def get_finetune_param_groups(model, lr, weight_decay, get_layer_func, scales, s "weight_decay": this_weight_decay, "params": [], "lr": lr * scale, - "lr_scale": scale + "lr_scale": scale, } parameter_group_vars[group_name]["params"].append(param) diff --git a/scripts/generate_wandb_id.py b/scripts/generate_wandb_id.py new file mode 100644 index 00000000..b32fc512 --- /dev/null +++ b/scripts/generate_wandb_id.py @@ -0,0 +1,3 @@ +import wandb.util + +print(wandb.util.generate_id()) diff --git a/scripts/parse_logs.py b/scripts/parse_logs.py new file mode 100644 index 00000000..73f0c59d --- /dev/null +++ b/scripts/parse_logs.py @@ -0,0 +1,344 @@ +""" +This scripts parses the training logs to graph both training loss and validation accuracy over time. +""" + +import argparse +import dataclasses +import re + +import matplotlib.pyplot as plt +import preface + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("file", help="Log file to parse. 
Typically named log_rank0.txt") + parser.add_argument( + "--last", help="How many of the latest epochs to look at.", default=10, type=int + ) + parser.add_argument("--name", help="Run name") + + return parser.parse_args() + + +@dataclasses.dataclass +class ValidationLine: + epoch: int + batch: int + batch_max: int + loss: float + mean_loss: float + acc1: float + mean_acc1: float + acc5: float + mean_acc5: float + + pattern0 = r""" + ^\[.*?\]\ + INFO:\ Test:\ \ + \[\ ?(?P\d+)/(?P\d+)\] + \ \ + eta:\ \d+:\d\d:\d\d + \ \ + loss:\ (?P[\w.]+)\ \((?P[\w.]+)\) + \ \ + acc1:\ (?P[\w.]+)\ \((?P[\w.]+)\) + \ \ + acc5:\ (?P[\w.]+)\ \((?P[\w.]+)\) + $ + """ + + pattern1 = r""" + ^\[.*?\]\ + \(main.py\ \d+\):\ + INFO\ Test:\ + \[(?P\d+)/(?P\d+)\] + \t + Time\ \d+.\d+\ \(\d+.\d+\) + \t + Loss\ (?P[\w.]+)\ \((?P[\w.]+)\) + \t + Acc@1\ (?P[\w.]+)\ \((?P[\w.]+)\) + \t + Acc@5\ (?P[\w.]+)\ \((?P[\w.]+)\) + \t + Mem\ (?P.*) + $ + """ + + @classmethod + def from_raw_line(cls, line, last_train): + if "Test" not in line: + return None + + # Example line: + # [2022-06-08 07:34:50 swinv2_large_patch4_window7_224_inat21](main.py 258): INFO Test: [0/196] Time 1.895 (1.895) Loss 0.8066 (0.8066) Acc@1 84.766 (84.766) Acc@5 95.312 (95.312) Mem 36916MB + + match = re.match(cls.pattern0, line, re.VERBOSE) or re.match( + cls.pattern1, line, re.VERBOSE + ) + + if not match: + print("Couldn't match validation line:", repr(line)) + return None + + epoch = 0 + if last_train: + epoch = last_train.epoch + + return cls( + epoch, + int(match.group("batch")), + int(match.group("batch_max")), + float(match.group("loss")), + float(match.group("mean_loss")), + float(match.group("acc1")), + float(match.group("mean_acc1")), + float(match.group("acc5")), + float(match.group("mean_acc5")), + ) + + +@dataclasses.dataclass +class TrainLine: + epoch: int + epoch_max: int + batch: int + batch_max: int + lr: float + wd: float + loss: float + mean_loss: float + grad_norm: float + mean_grad_norm: float + loss_scale: float + mean_loss_scale: float + + pattern0 = r""" + ^\[.*?\]\ # [2022-06-08 08:35:04 ...] + INFO:\ Epoch:\ \[(?P\d+)\] + \ \ + \[\ *(?P\d+)/(?P\d+)\] # [700/5247] + \ \ # + eta:\ (\d\ day,\ )?\d+:\d\d:\d\d # eta 3:11:57 + \ \ # + lr:\ (?P\d\.\d+) # lr 0.000040 + \ \ # + loss:\ (?P[\w.]+)\ \((?P[\w.]+)\) # loss 4.3632 (3.7333) + \ \ # + grad_norm:\ (?P[\w.]+)\ \((?P[\w.]+)\) # grad_norm 9.5996 (inf) + \ \ + time:\ (?P