diff --git a/.gitignore b/.gitignore
index 24f67f2e..9317223a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -87,9 +87,6 @@ target/
profile_default/
ipython_config.py
-# pyenv
-.python-version
-
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
@@ -133,3 +130,13 @@ dmypy.json
# Pyre type checker
.pyre/
+
+# Apex
+apex/
+
+# Logging
+runs/
+wandb/
+
+# host-specific values
+scripts/env.fish
diff --git a/.python-version b/.python-version
new file mode 100644
index 00000000..1281604a
--- /dev/null
+++ b/.python-version
@@ -0,0 +1 @@
+3.10.7
diff --git a/README.md b/README.md
index 816bf731..e333743d 100644
--- a/README.md
+++ b/README.md
@@ -1,310 +1,114 @@
# Swin Transformer
-[](https://paperswithcode.com/sota/object-detection-on-coco?p=swin-transformer-v2-scaling-up-capacity-and)
-[](https://paperswithcode.com/sota/instance-segmentation-on-coco?p=swin-transformer-v2-scaling-up-capacity-and)
-[](https://paperswithcode.com/sota/semantic-segmentation-on-ade20k?p=swin-transformer-v2-scaling-up-capacity-and)
-[](https://paperswithcode.com/sota/action-classification-on-kinetics-400?p=swin-transformer-v2-scaling-up-capacity-and)
+[Link to original Swin Transformer project](https://github.com/microsoft/Swin-Transformer)
-This repo is the official implementation of ["Swin Transformer: Hierarchical Vision Transformer using Shifted Windows"](https://arxiv.org/pdf/2103.14030.pdf) as well as the follow-ups. It currently includes code and models for the following tasks:
+## Installation Instructions
-> **Image Classification**: Included in this repo. See [get_started.md](get_started.md) for a quick start.
+1. Set up Python packages
-> **Object Detection and Instance Segmentation**: See [Swin Transformer for Object Detection](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection).
-
-> **Semantic Segmentation**: See [Swin Transformer for Semantic Segmentation](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation).
-
-> **Video Action Recognition**: See [Video Swin Transformer](https://github.com/SwinTransformer/Video-Swin-Transformer).
-
-> **Semi-Supervised Object Detection**: See [Soft Teacher](https://github.com/microsoft/SoftTeacher).
-
-> **SSL: Contrasitive Learning**: See [Transformer-SSL](https://github.com/SwinTransformer/Transformer-SSL).
-
-> **SSL: Masked Image Modeling**: See [get_started.md#simmim-support](https://github.com/microsoft/Swin-Transformer/blob/main/get_started.md#simmim-support).
-
-> **Mixture-of-Experts**: See [get_started](get_started.md#mixture-of-experts-support) for more instructions.
-
-> **Feature-Distillation**: Will appear in [Feature-Distillation](https://github.com/SwinTransformer/Feature-Distillation).
-
-## Activity notification
-
-* 09/18/2022: Organizing ECCV Workshop [*Computer Vision in the Wild (CVinW)*](https://computer-vision-in-the-wild.github.io/eccv-2022/), where two challenges are hosted to evaluate the zero-shot, few-shot and full-shot performance of pre-trained vision models in downstream tasks:
- - [``*Image Classification in the Wild (ICinW)*''](https://eval.ai/web/challenges/challenge-page/1832/overview) Challenge evaluates on 20 image classification tasks.
- - [``*Object Detection in the Wild (ODinW)*''](https://eval.ai/web/challenges/challenge-page/1839/overview) Challenge evaluates on 35 object detection tasks.
-
-
-$\qquad$ [[Workshop]](https://computer-vision-in-the-wild.github.io/eccv-2022/) $\qquad$ [[IC Challenge]](https://eval.ai/web/challenges/challenge-page/1832/overview)
-$\qquad$ [[OD Challenge]](https://eval.ai/web/challenges/challenge-page/1839/overview)
-
-## Updates
-
-***09/24/2022***
-
-1. Merged [SimMIM](https://github.com/microsoft/SimMIM), which is a **Masked Image Modeling** based pre-training approach applicable to Swin and SwinV2 (and also applicable for ViT and ResNet). Please refer to [get started with SimMIM](get_started.md#simmim-support) to play with SimMIM pre-training.
-
-2. Released a series of Swin and SwinV2 models pre-trained using the SimMIM approach (see [MODELHUB for SimMIM](MODELHUB.md#simmim-pretrained-swin-v2-models)), with model size ranging from SwinV2-Small-50M to SwinV2-giant-1B, data size ranging from ImageNet-1K-10% to ImageNet-22K, and iterations from 125k to 500k. You may leverage these models to study the properties of MIM methods. Please look into the [data scaling](https://arxiv.org/abs/2206.04664) paper for more details.
-
-***07/09/2022***
-
-`News`:
-
-1. SwinV2-G achieves `61.4 mIoU` on ADE20K semantic segmentation (+1.5 mIoU over the previous SwinV2-G model), using an additional [feature distillation (FD)](https://github.com/SwinTransformer/Feature-Distillation) approach, **setting a new recrod** on this benchmark. FD is an approach that can generally improve the fine-tuning performance of various pre-trained models, including DeiT, DINO, and CLIP. Particularly, it improves CLIP pre-trained ViT-L by +1.6% to reach `89.0%` on ImageNet-1K image classification, which is **the most accurate ViT-L model**.
-2. Merged a PR from **Nvidia** that links to faster Swin Transformer inference that have significant speed improvements on `T4 and A100 GPUs`.
-3. Merged a PR from **Nvidia** that enables an option to use `pure FP16 (Apex O2)` in training, while almost maintaining the accuracy.
-
-***06/03/2022***
-
-1. Added **Swin-MoE**, the Mixture-of-Experts variant of Swin Transformer implemented using [Tutel](https://github.com/microsoft/tutel) (an optimized Mixture-of-Experts implementation). **Swin-MoE** is introduced in the [TuTel](https://arxiv.org/abs/2206.03382) paper.
-
-***05/12/2022***
-
-1. Pretrained models of [Swin Transformer V2](https://arxiv.org/abs/2111.09883) on ImageNet-1K and ImageNet-22K are released.
-2. ImageNet-22K pretrained models for Swin-V1-Tiny and Swin-V2-Small are released.
-
-***03/02/2022***
-
-1. Swin Transformer V2 and SimMIM got accepted by CVPR 2022. [SimMIM](https://github.com/microsoft/SimMIM) is a self-supervised pre-training approach based on masked image modeling, a key technique that works out the 3-billion-parameter Swin V2 model using `40x less labelled data` than that of previous billion-scale models based on JFT-3B.
-
-***02/09/2022***
-
-1. Integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio). Try out the Web Demo [](https://huggingface.co/spaces/akhaliq/Swin-Transformer)
-
-***10/12/2021***
-
-1. Swin Transformer received ICCV 2021 best paper award (Marr Prize).
-
-***08/09/2021***
-1. [Soft Teacher](https://arxiv.org/pdf/2106.09018v2.pdf) will appear at ICCV2021. The code will be released at [GitHub Repo](https://github.com/microsoft/SoftTeacher). `Soft Teacher` is an end-to-end semi-supervisd object detection method, achieving a new record on the COCO test-dev: `61.3 box AP` and `53.0 mask AP`.
-
-***07/03/2021***
-1. Add **Swin MLP**, which is an adaption of `Swin Transformer` by replacing all multi-head self-attention (MHSA) blocks by MLP layers (more precisely it is a group linear layer). The shifted window configuration can also significantly improve the performance of vanilla MLP architectures.
-
-***06/25/2021***
-1. [Video Swin Transformer](https://arxiv.org/abs/2106.13230) is released at [Video-Swin-Transformer](https://github.com/SwinTransformer/Video-Swin-Transformer).
-`Video Swin Transformer` achieves state-of-the-art accuracy on a broad range of video recognition benchmarks, including action recognition (`84.9` top-1 accuracy on Kinetics-400 and `86.1` top-1 accuracy on Kinetics-600 with `~20x` less pre-training data and `~3x` smaller model size) and temporal modeling (`69.6` top-1 accuracy on Something-Something v2).
-
-***05/12/2021***
-1. Used as a backbone for `Self-Supervised Learning`: [Transformer-SSL](https://github.com/SwinTransformer/Transformer-SSL)
-
-Using Swin-Transformer as the backbone for self-supervised learning enables us to evaluate the transferring performance of the learnt representations on down-stream tasks, which is missing in previous works due to the use of ViT/DeiT, which has not been well tamed for down-stream tasks.
+```sh
+python -m venv venv
+# Activate your virtual environment (fish example; bash/zsh users see the note after this block)
+source venv/bin/activate.fish
+```
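+
+If you use bash or zsh instead of fish, the standard venv activation script is:
+
+```sh
+source venv/bin/activate
+```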
-***04/12/2021***
+CUDA 11.6
-Initial commits:
+```sh
+pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
+```
-1. Pretrained models on ImageNet-1K ([Swin-T-IN1K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth), [Swin-S-IN1K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth), [Swin-B-IN1K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth)) and ImageNet-22K ([Swin-B-IN22K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth), [Swin-L-IN22K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth)) are provided.
-2. The supported code and models for ImageNet-1K image classification, COCO object detection and ADE20K semantic segmentation are provided.
-3. The cuda kernel implementation for the [local relation layer](https://arxiv.org/pdf/1904.11491.pdf) is provided in branch [LR-Net](https://github.com/microsoft/Swin-Transformer/tree/LR-Net).
+CUDA 11.3
-## Introduction
+```sh
+pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
+```
-**Swin Transformer** (the name `Swin` stands for **S**hifted **win**dow) is initially described in [arxiv](https://arxiv.org/abs/2103.14030), which capably serves as a
-general-purpose backbone for computer vision. It is basically a hierarchical Transformer whose representation is
-computed with shifted windows. The shifted windowing scheme brings greater efficiency by limiting self-attention
-computation to non-overlapping local windows while also allowing for cross-window connection.
+Python packages
-Swin Transformer achieves strong performance on COCO object detection (`58.7 box AP` and `51.1 mask AP` on test-dev) and
-ADE20K semantic segmentation (`53.5 mIoU` on val), surpassing previous models by a large margin.
+```sh
+pip install matplotlib yacs timm einops black isort flake8 flake8-bugbear termcolor wandb preface opencv-python
+```
-
+2. Install Apex
-## Main Results on ImageNet with Pretrained Models
+```sh
+git clone https://github.com/NVIDIA/apex.git
+cd apex
+pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
+```
-**ImageNet-1K and ImageNet-22K Pretrained Swin-V1 Models**
+```sh
+cd kernels/window_process
+python setup.py install
+```
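+
+As an optional sanity check that both extensions built and import cleanly (the fused window-process kernel from this repo is assumed to install under the module name `swin_window_process`):
+
+```sh
+python -c "import apex"
+python -c "import swin_window_process"
+```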
-| name | pretrain | resolution |acc@1 | acc@5 | #params | FLOPs | FPS| 22K model | 1K model |
-| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |:---: |:---: |
-| Swin-T | ImageNet-1K | 224x224 | 81.2 | 95.5 | 28M | 4.5G | 755 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/156nWJy4Q28rDlrX-rRbI3w)/[config](configs/swin/swin_tiny_patch4_window7_224.yaml)/[log](https://github.com/SwinTransformer/storage/files/7745562/log_swin_tiny_patch4_window7_224.txt) |
-| Swin-S | ImageNet-1K | 224x224 | 83.2 | 96.2 | 50M | 8.7G | 437 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/1KFjpj3Efey3LmtE1QqPeQg)/[config](configs/swin/swin_small_patch4_window7_224.yaml)/[log](https://github.com/SwinTransformer/storage/files/7745563/log_swin_small_patch4_window7_224.txt) |
-| Swin-B | ImageNet-1K | 224x224 | 83.5 | 96.5 | 88M | 15.4G | 278 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/16bqCTEc70nC_isSsgBSaqQ)/[config](configs/swin/swin_base_patch4_window7_224.yaml)/[log](https://github.com/SwinTransformer/storage/files/7745564/log_swin_base_patch4_window7_224.txt) |
-| Swin-B | ImageNet-1K | 384x384 | 84.5 | 97.0 | 88M | 47.1G | 85 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth)/[baidu](https://pan.baidu.com/s/1xT1cu740-ejW7htUdVLnmw)/[config](configs/swin/swin_base_patch4_window12_384_finetune.yaml) |
-| Swin-T | ImageNet-22K | 224x224 | 80.9 | 96.0 | 28M | 4.5G | 755 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_tiny_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1vct0VYwwQQ8PYkBjwSSBZQ?pwd=swin)/[config](configs/swin/swin_tiny_patch4_window7_224_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_tiny_patch4_window7_224_22kto1k_finetune.pth)/[baidu](https://pan.baidu.com/s/1K0OO-nGZDPkR8fm_r83e8Q?pwd=swin)/[config](configs/swin/swin_tiny_patch4_window7_224_22kto1k_finetune.yaml) |
-| Swin-S | ImageNet-22K | 224x224 | 83.2 | 97.0 | 50M | 8.7G | 437 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_small_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/11NC1xdT5BAGBgazdTme5Sg?pwd=swin)/[config](configs/swin/swin_small_patch4_window7_224_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_small_patch4_window7_224_22kto1k_finetune.pth)/[baidu](https://pan.baidu.com/s/10RFVfjQJhwPfeHrmxQUaLw?pwd=swin)/[config](configs/swin/swin_small_patch4_window7_224_22kto1k_finetune.yaml) |
-| Swin-B | ImageNet-22K | 224x224 | 85.2 | 97.5 | 88M | 15.4G | 278 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1y1Ec3UlrKSI8IMtEs-oBXA)/[config](configs/swin/swin_base_patch4_window7_224_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1n_wNkcbRxVXit8r_KrfAVg)/[config](configs/swin/swin_base_patch4_window7_224_22kto1k_finetune.yaml) |
-| Swin-B | ImageNet-22K | 384x384 | 86.4 | 98.0 | 88M | 47.1G | 85 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth)/[baidu](https://pan.baidu.com/s/1vwJxnJcVqcLZAw9HaqiR6g) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1caKTSdoLJYoi4WBcnmWuWg)/[config](configs/swin/swin_base_patch4_window12_384_22kto1k_finetune.yaml) |
-| Swin-L | ImageNet-22K | 224x224 | 86.3 | 97.9 | 197M | 34.5G | 141 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1pws3rOTFuOebBYP3h6Kx8w)/[config](configs/swin/swin_large_patch4_window7_224_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1NkQApMWUhxBGjk1ne6VqBQ)/[config](configs/swin/swin_large_patch4_window7_224_22kto1k_finetune.yaml) |
-| Swin-L | ImageNet-22K | 384x384 | 87.3 | 98.2 | 197M | 103.9G | 42 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth)/[baidu](https://pan.baidu.com/s/1sl7o_bJA143OD7UqSLAMoA) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1X0FLHQyPOC6Kmv2CmgxJvA)/[config](configs/swin/swin_large_patch4_window12_384_22kto1k_finetune.yaml) |
+3. Download Data
-**ImageNet-1K and ImageNet-22K Pretrained Swin-V2 Models**
+We use the iNat21 dataset, available on [GitHub](https://github.com/visipedia/inat_comp/tree/master/2021).
-| name | pretrain | resolution | window |acc@1 | acc@5 | #params | FLOPs | FPS |22K model | 1K model |
-|:---------------------:| :---: | :---: | :---: | :---: | :---: | :---: | :---: |:---:|:---: |:---: |
-| SwinV2-T | ImageNet-1K | 256x256 | 8x8 | 81.8 | 95.9 | 28M | 5.9G | 572 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/1RzLkAH_5OtfRCJe6Vlg6rg?pwd=swin)/[config](configs/swinv2/swinv2_tiny_patch4_window8_256.yaml) |
-| SwinV2-S | ImageNet-1K | 256x256 | 8x8 | 83.7 | 96.6 | 50M | 11.5G | 327 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/195PdA41szEduW3jEtRSa4Q?pwd=swin)/[config](configs/swinv2/swinv2_small_patch4_window8_256.yaml) |
-| SwinV2-B | ImageNet-1K | 256x256 | 8x8 | 84.2 | 96.9 | 88M | 20.3G | 217 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/18AfMSz3dPyzIvP1dKuERvQ?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window8_256.yaml) |
-| SwinV2-T | ImageNet-1K | 256x256 | 16x16 | 82.8 | 96.2 | 28M | 6.6G | 437 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window16_256.pth)/[baidu](https://pan.baidu.com/s/1dyK3cK9Xipmv6RnTtrPocw?pwd=swin)/[config](configs/swinv2/swinv2_tiny_patch4_window16_256.yaml) |
-| SwinV2-S | ImageNet-1K | 256x256 | 16x16 | 84.1 | 96.8 | 50M | 12.6G | 257 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window16_256.pth)/[baidu](https://pan.baidu.com/s/1ZIPiSfWNKTPp821Ka-Mifw?pwd=swin)/[config](configs/swinv2/swinv2_small_patch4_window16_256.yaml) |
-| SwinV2-B | ImageNet-1K | 256x256 | 16x16 | 84.6 | 97.0 | 88M | 21.8G | 174 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window16_256.pth)/[baidu](https://pan.baidu.com/s/1dlDQGn8BXCmnh7wQSM5Nhw?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window16_256.yaml) |
-| SwinV2-B\* | ImageNet-22K | 256x256 | 16x16 | 86.2 | 97.9 | 88M | 21.8G | 174 | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth)/[baidu](https://pan.baidu.com/s/1Xc2rsSsRQz_sy5mjgfxrMQ?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window12_192_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.pth)/[baidu](https://pan.baidu.com/s/1sgstld4MgGsZxhUAW7MlmQ?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.yaml) |
-| SwinV2-B\* | ImageNet-22K | 384x384 | 24x24 | 87.1 | 98.2 | 88M | 54.7G | 57 | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth)/[baidu](https://pan.baidu.com/s/1Xc2rsSsRQz_sy5mjgfxrMQ?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window12_192_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.pth)/[baidu](https://pan.baidu.com/s/17u3sEQaUYlvfL195rrORzQ?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.yaml) |
-| SwinV2-L\* | ImageNet-22K | 256x256 | 16x16 | 86.9 | 98.0 | 197M | 47.5G | 95 | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth)/[baidu](https://pan.baidu.com/s/11PhCV7qAGXtZ8dXNgyiGOw?pwd=swin)/[config](configs/swinv2/swinv2_large_patch4_window12_192_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to16_192to256_22kto1k_ft.pth)/[baidu](https://pan.baidu.com/s/1pqp31N80qIWjFPbudzB6Bw?pwd=swin)/[config](configs/swinv2/swinv2_large_patch4_window12to16_192to256_22kto1k_ft.yaml) |
-| SwinV2-L\* | ImageNet-22K | 384x384 | 24x24 | 87.6 | 98.3 | 197M | 115.4G | 33 | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth)/[baidu](https://pan.baidu.com/s/11PhCV7qAGXtZ8dXNgyiGOw?pwd=swin)/[config](configs/swinv2/swinv2_large_patch4_window12_192_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.pth)/[baidu](https://pan.baidu.com/s/13URdNkygr3Xn0N3e6IwjgA?pwd=swin)/[config](configs/swinv2/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.yaml) |
+```sh
+cd /mnt/10tb
+mkdir -p data/inat21
+cd data/inat21
+mkdir compressed raw
+cd compressed
+wget https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train.tar.gz
+wget https://ml-inat-competition-datasets.s3.amazonaws.com/2021/val.tar.gz
+
+# pv is just a progress bar
+pv val.tar.gz | tar -xz
+mv val ../raw/  # tar's -C flag can extract directly into raw/; see the variant after this block
+
+pv train.tar.gz | tar -xz
+mv train ../raw/
+```
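+
+`tar` can also extract straight into `raw/` with the `-C` flag, which avoids the `mv` steps above:
+
+```sh
+pv val.tar.gz | tar -xz -C ../raw/
+pv train.tar.gz | tar -xz -C ../raw/
+```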
-Note:
-- SwinV2-B\* (SwinV2-L\*) with input resolution of 256x256 and 384x384 both fine-tuned from the same pre-training model using a smaller input resolution of 192x192.
-- SwinV2-B\* (384x384) achieves 78.08 acc@1 on ImageNet-1K-V2 while SwinV2-L\* (384x384) achieves 78.31.
+4. Preprocess iNat21
-**ImageNet-1K Pretrained Swin MLP Models**
+Use your root data folder and the image size(s) of your choice.
-| name | pretrain | resolution |acc@1 | acc@5 | #params | FLOPs | FPS | 1K model |
-| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
-| [Mixer-B/16](https://arxiv.org/pdf/2105.01601.pdf) | ImageNet-1K | 224x224 | 76.4 | - | 59M | 12.7G | - | [official repo](https://github.com/google-research/vision_transformer) |
-| [ResMLP-S24](https://arxiv.org/abs/2105.03404) | ImageNet-1K | 224x224 | 79.4 | - | 30M | 6.0G | 715 | [timm](https://github.com/rwightman/pytorch-image-models) |
-| [ResMLP-B24](https://arxiv.org/abs/2105.03404) | ImageNet-1K | 224x224 | 81.0 | - | 116M | 23.0G | 231 | [timm](https://github.com/rwightman/pytorch-image-models) |
-| Swin-T/C24 | ImageNet-1K | 256x256 | 81.6 | 95.7 | 28M | 5.9G | 563 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_tiny_c24_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/17k-7l6Sxt7uZ7IV0f26GNQ)/[config](configs/swin/swin_tiny_c24_patch4_window8_256.yaml) |
-| SwinMLP-T/C24 | ImageNet-1K | 256x256 | 79.4 | 94.6 | 20M | 4.0G | 807 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_mlp_tiny_c24_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/1Sa4vP5R0M2RjfIe9HIga-Q)/[config](configs/swin/swin_mlp_tiny_c24_patch4_window8_256.yaml) |
-| SwinMLP-T/C12 | ImageNet-1K | 256x256 | 79.6 | 94.7 | 21M | 4.0G | 792 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_mlp_tiny_c12_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/1mM9J2_DEVZHUB5ASIpFl0w)/[config](configs/swin/swin_mlp_tiny_c12_patch4_window8_256.yaml) |
-| SwinMLP-T/C6 | ImageNet-1K | 256x256 | 79.7 | 94.9 | 23M | 4.0G | 766 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_mlp_tiny_c6_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/1hUTYVT2W1CsjICw-3W-Vjg)/[config](configs/swin/swin_mlp_tiny_c6_patch4_window8_256.yaml) |
-| SwinMLP-B | ImageNet-1K | 224x224 | 81.3 | 95.3 | 61M | 10.4G | 409 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_mlp_base_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/1zww3dnbX3GxNiGfb-GwyUg)/[config](configs/swin/swin_mlp_base_patch4_window7_224.yaml) |
+```sh
+export DATA_DIR=/mnt/10tb/data/inat21/
+python -m data.inat preprocess $DATA_DIR val resize 192
+python -m data.inat preprocess $DATA_DIR train resize 192
+python -m data.inat preprocess $DATA_DIR val resize 256
+python -m data.inat preprocess $DATA_DIR train resize 256
+```
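+
+After preprocessing, the resized splits should land under `$DATA_DIR/resize-<size>`; that is the layout the provided configs point at (e.g. `DATA_PATH: /mnt/10tb/data/inat21/resize-192`), so the directory is expected to look roughly like:
+
+```sh
+ls $DATA_DIR
+# compressed  raw  resize-192  resize-256
+```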
-Note: access code for `baidu` is `swin`. C24 means each head has 24 channels.
+5. Log in to Wandb
-**ImageNet-22K Pretrained Swin-MoE Models**
+```sh
+wandb login
+```
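+
+On machines where interactive login is awkward, wandb also reads an API key from the `WANDB_API_KEY` environment variable:
+
+```sh
+export WANDB_API_KEY=<your-api-key>
+```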
-- Please refer to [get_started](get_started.md#mixture-of-experts-support) for instructions on running Swin-MoE.
-- Pretrained models for Swin-MoE can be found in [MODEL HUB](MODELHUB.md#imagenet-22k-pretrained-swin-moe-models)
+6. Set up an `env.fish` file:
-## Main Results on Downstream Tasks
+You need to provide the `$VENV` and `$RUN_OUTPUT` environment variables.
+I recommend saving these variables in a file.
-**COCO Object Detection (2017 val)**
+In fish:
-| Backbone | Method | pretrain | Lr Schd | box mAP | mask mAP | #params | FLOPs |
-| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
-| Swin-T | Mask R-CNN | ImageNet-1K | 3x | 46.0 | 41.6 | 48M | 267G |
-| Swin-S | Mask R-CNN | ImageNet-1K | 3x | 48.5 | 43.3 | 69M | 359G |
-| Swin-T | Cascade Mask R-CNN | ImageNet-1K | 3x | 50.4 | 43.7 | 86M | 745G |
-| Swin-S | Cascade Mask R-CNN | ImageNet-1K | 3x | 51.9 | 45.0 | 107M | 838G |
-| Swin-B | Cascade Mask R-CNN | ImageNet-1K | 3x | 51.9 | 45.0 | 145M | 982G |
-| Swin-T | RepPoints V2 | ImageNet-1K | 3x | 50.0 | - | 45M | 283G |
-| Swin-T | Mask RepPoints V2 | ImageNet-1K | 3x | 50.3 | 43.6 | 47M | 292G |
-| Swin-B | HTC++ | ImageNet-22K | 6x | 56.4 | 49.1 | 160M | 1043G |
-| Swin-L | HTC++ | ImageNet-22K | 3x | 57.1 | 49.5 | 284M | 1470G |
-| Swin-L | HTC++* | ImageNet-22K | 3x | 58.0 | 50.4 | 284M | - |
+```fish
+# scripts/env.fish
+set -gx VENV venv
+set -gx RUN_OUTPUT /mnt/10tb/models/hierarchical-vision
+```
-Note: * indicates multi-scale testing.
+Then run `source scripts/env.fish`
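+
+If you work in bash or zsh instead, a hypothetical `scripts/env.sh` with the same values would look like this (and is sourced the same way):
+
+```sh
+# scripts/env.sh (bash/zsh equivalent of scripts/env.fish)
+export VENV=venv
+export RUN_OUTPUT=/mnt/10tb/models/hierarchical-vision
+```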
-**ADE20K Semantic Segmentation (val)**
+## AWS Helpers
-| Backbone | Method | pretrain | Crop Size | Lr Schd | mIoU | mIoU (ms+flip) | #params | FLOPs |
-| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
-| Swin-T | UPerNet | ImageNet-1K | 512x512 | 160K | 44.51 | 45.81 | 60M | 945G |
-| Swin-S | UperNet | ImageNet-1K | 512x512 | 160K | 47.64 | 49.47 | 81M | 1038G |
-| Swin-B | UperNet | ImageNet-1K | 512x512 | 160K | 48.13 | 49.72 | 121M | 1188G |
-| Swin-B | UPerNet | ImageNet-22K | 640x640 | 160K | 50.04 | 51.66 | 121M | 1841G |
-| Swin-L | UperNet | ImageNet-22K | 640x640 | 160K | 52.05 | 53.53 | 234M | 3230G |
-
-## Citing Swin Transformer
+Uninstall v1 of awscli:
```
-@inproceedings{liu2021Swin,
- title={Swin Transformer: Hierarchical Vision Transformer using Shifted Windows},
- author={Liu, Ze and Lin, Yutong and Cao, Yue and Hu, Han and Wei, Yixuan and Zhang, Zheng and Lin, Stephen and Guo, Baining},
- booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
- year={2021}
-}
-```
-## Citing Local Relation Networks (the first full-attention visual backbone)
-```
-@inproceedings{hu2019local,
- title={Local Relation Networks for Image Recognition},
- author={Hu, Han and Zhang, Zheng and Xie, Zhenda and Lin, Stephen},
- booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
- pages={3464--3473},
- year={2019}
-}
-```
-## Citing Swin Transformer V2
+sudo /usr/local/bin/pip uninstall awscli
```
-@inproceedings{liu2021swinv2,
- title={Swin Transformer V2: Scaling Up Capacity and Resolution},
- author={Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
- booktitle={International Conference on Computer Vision and Pattern Recognition (CVPR)},
- year={2022}
-}
-```
-## Citing SimMIM (a self-supervised approach that enables SwinV2-G)
-```
-@inproceedings{xie2021simmim,
- title={SimMIM: A Simple Framework for Masked Image Modeling},
- author={Xie, Zhenda and Zhang, Zheng and Cao, Yue and Lin, Yutong and Bao, Jianmin and Yao, Zhuliang and Dai, Qi and Hu, Han},
- booktitle={International Conference on Computer Vision and Pattern Recognition (CVPR)},
- year={2022}
-}
-```
-## Citing SimMIM-data-scaling
-```
-@article{xie2022data,
- title={On Data Scaling in Masked Image Modeling},
- author={Xie, Zhenda and Zhang, Zheng and Cao, Yue and Lin, Yutong and Wei, Yixuan and Dai, Qi and Hu, Han},
- journal={arXiv preprint arXiv:2206.04664},
- year={2022}
-}
-```
-## Citing Swin-MoE
+
+Install v2:
```
-@misc{hwang2022tutel,
- title={Tutel: Adaptive Mixture-of-Experts at Scale},
- author={Changho Hwang and Wei Cui and Yifan Xiong and Ziyue Yang and Ze Liu and Han Hu and Zilong Wang and Rafael Salas and Jithin Jose and Prabhat Ram and Joe Chau and Peng Cheng and Fan Yang and Mao Yang and Yongqiang Xiong},
- year={2022},
- eprint={2206.03382},
- archivePrefix={arXiv}
-}
+cd ~/pkg
+curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip"
+unzip awscliv2.zip
+./aws/install --bin-dir ~/.local/bin --install-dir ~/.local/aws-cli
```
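+
+Verify that the install reports a v2 release:
+
+```sh
+aws --version
+```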
-
-## Getting Started
-
-- For **Image Classification**, please see [get_started.md](get_started.md) for detailed instructions.
-- For **Object Detection and Instance Segmentation**, please see [Swin Transformer for Object Detection](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection).
-- For **Semantic Segmentation**, please see [Swin Transformer for Semantic Segmentation](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation).
-- For **Self-Supervised Learning**, please see [Transformer-SSL](https://github.com/SwinTransformer/Transformer-SSL).
-- For **Video Recognition**, please see [Video Swin Transformer](https://github.com/SwinTransformer/Video-Swin-Transformer).
-
-## Third-party Usage and Experiments
-
-***In this pargraph, we cross link third-party repositories which use Swin and report results. You can let us know by raising an issue***
-
-(`Note please report accuracy numbers and provide trained models in your new repository to facilitate others to get sense of correctness and model behavior`)
-
-[06/30/2022] Swin Transformers (V1) inference implemented in FasterTransformer: [FasterTransformer](https://github.com/NVIDIA/FasterTransformer/blob/main/docs/swin_guide.md)
-
-[05/12/2022] Swin Transformers (V1) implemented in TensorFlow with the pre-trained parameters ported into them. Find the implementation,
-TensorFlow weights, code example here in [this repository](https://github.com/sayakpaul/swin-transformers-tf/).
-
-[04/06/2022] Swin Transformer for Audio Classification: [Hierarchical Token Semantic Audio Transformer](https://github.com/RetroCirce/HTS-Audio-Transformer).
-
-[12/21/2021] Swin Transformer for StyleGAN: [StyleSwin](https://github.com/microsoft/StyleSwin)
-
-[12/13/2021] Swin Transformer for Face Recognition: [FaceX-Zoo](https://github.com/JDAI-CV/FaceX-Zoo)
-
-[08/29/2021] Swin Transformer for Image Restoration: [SwinIR](https://github.com/JingyunLiang/SwinIR)
-
-[08/12/2021] Swin Transformer for person reID: [https://github.com/layumi/Person_reID_baseline_pytorch](https://github.com/layumi/Person_reID_baseline_pytorch)
-
-[06/29/2021] Swin-Transformer in PaddleClas and inference based on whl package: [https://github.com/PaddlePaddle/PaddleClas](https://github.com/PaddlePaddle/PaddleClas)
-
-[04/14/2021] Swin for RetinaNet in Detectron: https://github.com/xiaohu2015/SwinT_detectron2.
-
-[04/16/2021] Included in a famous model zoo: https://github.com/rwightman/pytorch-image-models.
-
-[04/20/2021] Swin-Transformer classifier inference using TorchServe: https://github.com/kamalkraj/Swin-Transformer-Serve
-
-## Contributing
-
-This project welcomes contributions and suggestions. Most contributions require you to agree to a
-Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
-the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
-
-When you submit a pull request, a CLA bot will automatically determine whether you need to provide
-a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
-provided by the bot. You will only need to do this once across all repos using our CLA.
-
-This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
-For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
-contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
-
-## Trademarks
-
-This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
-trademarks or logos is subject to and must follow
-[Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
-Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
-Any use of third-party trademarks or logos are subject to those third-party's policies.
diff --git a/config.py b/config.py
index 1671ec34..27e3d295 100644
--- a/config.py
+++ b/config.py
@@ -6,33 +6,34 @@
# --------------------------------------------------------'
import os
+
import yaml
from yacs.config import CfgNode as CN
+
+
_C = CN()
# Base config files
-_C.BASE = ['']
+_C.BASE = [""]
# -----------------------------------------------------------------------------
# Data settings
# -----------------------------------------------------------------------------
_C.DATA = CN()
-# Batch size for a single GPU, could be overwritten by command line argument
-_C.DATA.BATCH_SIZE = 128
# Path to dataset, could be overwritten by command line argument
-_C.DATA.DATA_PATH = ''
+_C.DATA.DATA_PATH = ""
# Dataset name
-_C.DATA.DATASET = 'imagenet'
+_C.DATA.DATASET = "imagenet"
# Input image size
_C.DATA.IMG_SIZE = 224
# Interpolation to resize image (random, bilinear, bicubic)
-_C.DATA.INTERPOLATION = 'bicubic'
+_C.DATA.INTERPOLATION = "bicubic"
# Use zipped dataset instead of folder dataset
# could be overwritten by command line argument
_C.DATA.ZIP_MODE = False
# Cache Data in Memory, could be overwritten by command line argument
-_C.DATA.CACHE_MODE = 'part'
+_C.DATA.CACHE_MODE = "part"
# Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
_C.DATA.PIN_MEMORY = True
# Number of data loading threads
@@ -48,14 +49,14 @@
# -----------------------------------------------------------------------------
_C.MODEL = CN()
# Model type
-_C.MODEL.TYPE = 'swin'
+_C.MODEL.TYPE = "swin"
# Model name
-_C.MODEL.NAME = 'swin_tiny_patch4_window7_224'
+_C.MODEL.NAME = "swin_tiny_patch4_window7_224"
# Pretrained weight from checkpoint, could be imagenet22k pretrained weight
# could be overwritten by command line argument
-_C.MODEL.PRETRAINED = ''
+_C.MODEL.PRETRAINED = ""
# Checkpoint to resume, could be overwritten by command line argument
-_C.MODEL.RESUME = ''
+_C.MODEL.RESUME = ""
# Number of classes, overwritten in data preparation
_C.MODEL.NUM_CLASSES = 1000
# Dropout rate
@@ -73,7 +74,7 @@
_C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
_C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
_C.MODEL.SWIN.WINDOW_SIZE = 7
-_C.MODEL.SWIN.MLP_RATIO = 4.
+_C.MODEL.SWIN.MLP_RATIO = 4.0
_C.MODEL.SWIN.QKV_BIAS = True
_C.MODEL.SWIN.QK_SCALE = None
_C.MODEL.SWIN.APE = False
@@ -87,7 +88,7 @@
_C.MODEL.SWINV2.DEPTHS = [2, 2, 6, 2]
_C.MODEL.SWINV2.NUM_HEADS = [3, 6, 12, 24]
_C.MODEL.SWINV2.WINDOW_SIZE = 7
-_C.MODEL.SWINV2.MLP_RATIO = 4.
+_C.MODEL.SWINV2.MLP_RATIO = 4.0
_C.MODEL.SWINV2.QKV_BIAS = True
_C.MODEL.SWINV2.APE = False
_C.MODEL.SWINV2.PATCH_NORM = True
@@ -101,7 +102,7 @@
_C.MODEL.SWIN_MOE.DEPTHS = [2, 2, 6, 2]
_C.MODEL.SWIN_MOE.NUM_HEADS = [3, 6, 12, 24]
_C.MODEL.SWIN_MOE.WINDOW_SIZE = 7
-_C.MODEL.SWIN_MOE.MLP_RATIO = 4.
+_C.MODEL.SWIN_MOE.MLP_RATIO = 4.0
_C.MODEL.SWIN_MOE.QKV_BIAS = True
_C.MODEL.SWIN_MOE.QK_SCALE = None
_C.MODEL.SWIN_MOE.APE = False
@@ -131,7 +132,7 @@
_C.MODEL.SWIN_MLP.DEPTHS = [2, 2, 6, 2]
_C.MODEL.SWIN_MLP.NUM_HEADS = [3, 6, 12, 24]
_C.MODEL.SWIN_MLP.WINDOW_SIZE = 7
-_C.MODEL.SWIN_MLP.MLP_RATIO = 4.
+_C.MODEL.SWIN_MLP.MLP_RATIO = 4.0
_C.MODEL.SWIN_MLP.APE = False
_C.MODEL.SWIN_MLP.PATCH_NORM = True
@@ -145,6 +146,10 @@
# Training settings
# -----------------------------------------------------------------------------
_C.TRAIN = CN()
+# Batch size for a single GPU, could be overwritten by command line argument
+_C.TRAIN.DEVICE_BATCH_SIZE = 128
+# Global batch size = DEVICE_BATCH_SIZE * N_PROCS * ACCUMULATION_STEPS
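+# e.g. 64 per-device batch x 8 processes x 2 accumulation steps = 1024 global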
+_C.TRAIN.GLOBAL_BATCH_SIZE = 1024
_C.TRAIN.START_EPOCH = 0
_C.TRAIN.EPOCHS = 300
_C.TRAIN.WARMUP_EPOCHS = 20
@@ -165,7 +170,7 @@
# LR scheduler
_C.TRAIN.LR_SCHEDULER = CN()
-_C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
+_C.TRAIN.LR_SCHEDULER.NAME = "cosine"
# Epoch interval to decay LR, used in StepLRScheduler
_C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
# LR decay rate, used in StepLRScheduler
@@ -178,7 +183,7 @@
# Optimizer
_C.TRAIN.OPTIMIZER = CN()
-_C.TRAIN.OPTIMIZER.NAME = 'adamw'
+_C.TRAIN.OPTIMIZER.NAME = "adamw"
# Optimizer Epsilon
_C.TRAIN.OPTIMIZER.EPS = 1e-8
# Optimizer Betas
@@ -193,6 +198,16 @@
_C.TRAIN.MOE = CN()
# Only save model on master device
_C.TRAIN.MOE.SAVE_MASTER = False
+
+# Hierarchical coefficients for loss
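+# One coefficient per level of the label hierarchy (the iNat21 configs in this repo use seven, one per taxonomic rank)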
+_C.TRAIN.HIERARCHICAL_COEFFS = (1,)
+
+# [Debugging] How many batches of the training data to overfit.
+_C.TRAIN.OVERFIT_BATCHES = 0
+
+# Percentage of data to use for the low-data regime
+_C.TRAIN.DATA_PERCENTAGE = 1
+
# -----------------------------------------------------------------------------
# Augmentation settings
# -----------------------------------------------------------------------------
@@ -200,11 +215,11 @@
# Color jitter factor
_C.AUG.COLOR_JITTER = 0.4
# Use AutoAugment policy. "v0" or "original"
-_C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
+_C.AUG.AUTO_AUGMENT = "rand-m9-mstd0.5-inc1"
# Random erase prob
_C.AUG.REPROB = 0.25
# Random erase mode
-_C.AUG.REMODE = 'pixel'
+_C.AUG.REMODE = "pixel"
# Random erase count
_C.AUG.RECOUNT = 1
# Mixup alpha, mixup enabled if > 0
@@ -218,7 +233,7 @@
# Probability of switching to cutmix when both mixup and cutmix enabled
_C.AUG.MIXUP_SWITCH_PROB = 0.5
# How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
-_C.AUG.MIXUP_MODE = 'batch'
+_C.AUG.MIXUP_MODE = "batch"
# -----------------------------------------------------------------------------
# Testing settings
@@ -230,20 +245,32 @@
_C.TEST.SEQUENTIAL = False
_C.TEST.SHUFFLE = False
+# -----------------------------------------------------------------------------
+# Experiment Settings
+# -----------------------------------------------------------------------------
+_C.EXPERIMENT = CN()
+# The experiment name: a short, human-readable identifier for the run.
+_C.EXPERIMENT.NAME = "default-dragonfruit"
+# The wandb id for logging.
+# Generate this id with scripts/generate_wandb_id
+_C.EXPERIMENT.WANDB_ID = ""
+
# -----------------------------------------------------------------------------
# Misc
# -----------------------------------------------------------------------------
+
+# Whether we are doing hierarchical classification
+_C.HIERARCHICAL = False
+
# [SimMIM] Whether to enable pytorch amp, overwritten by command line argument
_C.ENABLE_AMP = False
# Enable Pytorch automatic mixed precision (amp).
_C.AMP_ENABLE = True
# [Deprecated] Mixed precision opt level of apex, if O0, no apex amp is used ('O0', 'O1', 'O2')
-_C.AMP_OPT_LEVEL = ''
+_C.AMP_OPT_LEVEL = ""
# Path to output folder, overwritten by command line argument
-_C.OUTPUT = ''
-# Tag of experiment, overwritten by command line argument
-_C.TAG = 'default'
+_C.OUTPUT = ""
# Frequency to save checkpoint
_C.SAVE_FREQ = 1
# Frequency to logging info
@@ -263,15 +290,15 @@
def _update_config_from_file(config, cfg_file):
config.defrost()
- with open(cfg_file, 'r') as f:
+ with open(cfg_file, "r") as f:
yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)
- for cfg in yaml_cfg.setdefault('BASE', ['']):
+ for cfg in yaml_cfg.setdefault("BASE", [""]):
if cfg:
_update_config_from_file(
config, os.path.join(os.path.dirname(cfg_file), cfg)
)
- print('=> merge config from {}'.format(cfg_file))
+ print("=> merge config from {}".format(cfg_file))
config.merge_from_file(cfg_file)
config.freeze()
@@ -284,60 +311,91 @@ def update_config(config, args):
config.merge_from_list(args.opts)
def _check_args(name):
- if hasattr(args, name) and eval(f'args.{name}'):
+ if hasattr(args, name) and eval(f"args.{name}"):
return True
return False
# merge from specific arguments
- if _check_args('batch_size'):
- config.DATA.BATCH_SIZE = args.batch_size
- if _check_args('data_path'):
- config.DATA.DATA_PATH = args.data_path
- if _check_args('zip'):
+ if _check_args("batch_size"):
+ config.TRAIN.DEVICE_BATCH_SIZE = args.batch_size
+ if _check_args("data_path"):
+ config.DATA.DATA_PATH = os.path.abspath(args.data_path)
+ if _check_args("zip"):
config.DATA.ZIP_MODE = True
- if _check_args('cache_mode'):
+ if _check_args("cache_mode"):
config.DATA.CACHE_MODE = args.cache_mode
- if _check_args('pretrained'):
+ if _check_args("pretrained"):
config.MODEL.PRETRAINED = args.pretrained
- if _check_args('resume'):
+ if _check_args("resume"):
config.MODEL.RESUME = args.resume
- if _check_args('accumulation_steps'):
- config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps
- if _check_args('use_checkpoint'):
+ if _check_args("use_checkpoint"):
config.TRAIN.USE_CHECKPOINT = True
- if _check_args('amp_opt_level'):
+ if _check_args("amp_opt_level"):
print("[warning] Apex amp has been deprecated, please use pytorch amp instead!")
- if args.amp_opt_level == 'O0':
+ if args.amp_opt_level == "O0":
config.AMP_ENABLE = False
- if _check_args('disable_amp'):
+ if _check_args("disable_amp"):
config.AMP_ENABLE = False
- if _check_args('output'):
+ if _check_args("output"):
config.OUTPUT = args.output
- if _check_args('tag'):
- config.TAG = args.tag
- if _check_args('eval'):
+ if _check_args("eval"):
config.EVAL_MODE = True
- if _check_args('throughput'):
+ if _check_args("throughput"):
config.THROUGHPUT_MODE = True
# [SimMIM]
- if _check_args('enable_amp'):
+ if _check_args("enable_amp"):
config.ENABLE_AMP = args.enable_amp
# for acceleration
- if _check_args('fused_window_process'):
+ if _check_args("fused_window_process"):
config.FUSED_WINDOW_PROCESS = True
- if _check_args('fused_layernorm'):
+ if _check_args("fused_layernorm"):
config.FUSED_LAYERNORM = True
- ## Overwrite optimizer if not None, currently we use it for [fused_adam, fused_lamb]
- if _check_args('optim'):
+ # Overwrite optimizer if not None, currently we use it for [fused_adam, fused_lamb]
+ if _check_args("optim"):
config.TRAIN.OPTIMIZER.NAME = args.optim
- # set local rank for distributed training
- config.LOCAL_RANK = args.local_rank
+ if _check_args("low_data"):
+ config.TRAIN.DATA_PERCENTAGE = args.low_data
+
+ # Use os.environ["LOCAL_RANK"] rather than --local_rank
+ if "LOCAL_RANK" in os.environ:
+ # set local rank for distributed training
+ config.LOCAL_RANK = int(os.environ["LOCAL_RANK"])
+
+ # Use this to calculate accumulation steps
+ if "LOCAL_WORLD_SIZE" in os.environ:
+
+ def divide_cleanly(a, b):
+ assert a % b == 0, f"{a} / {b} has remainder {a % b}"
+ return a // b
+
+ n_procs = int(os.environ["LOCAL_WORLD_SIZE"])
+ desired_device_batch_size = divide_cleanly(
+ config.TRAIN.GLOBAL_BATCH_SIZE, n_procs
+ )
+ actual_device_batch_size = config.TRAIN.DEVICE_BATCH_SIZE
+
+ if actual_device_batch_size > desired_device_batch_size:
+ print(
+                f"Decreasing device batch size from {actual_device_batch_size} to {desired_device_batch_size} so your global batch size is {config.TRAIN.GLOBAL_BATCH_SIZE}, not {actual_device_batch_size * n_procs}!"
+ )
+ config.TRAIN.ACCUMULATION_STEPS = 1
+ config.TRAIN.DEVICE_BATCH_SIZE = desired_device_batch_size
+ elif desired_device_batch_size == actual_device_batch_size:
+ config.TRAIN.ACCUMULATION_STEPS = 1
+ else:
+ assert desired_device_batch_size > actual_device_batch_size
+ config.TRAIN.ACCUMULATION_STEPS = divide_cleanly(
+ desired_device_batch_size, actual_device_batch_size
+ )
+ print(
+ f"Using {config.TRAIN.ACCUMULATION_STEPS} accumulation steps so your global batch size is {config.TRAIN.GLOBAL_BATCH_SIZE}, not {actual_device_batch_size * n_procs}!"
+ )
# output folder
- config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG)
+ config.OUTPUT = os.path.join(config.OUTPUT, config.EXPERIMENT.NAME)
config.freeze()
@@ -347,6 +405,231 @@ def get_config(args):
# Return a clone so that the defaults will not be altered
# This is for the "local variable" use pattern
config = _C.clone()
+
update_config(config, args)
+
return config
+
+
+
+#################################### Added script for yaml file generation #############################################
+
+import copy
+import dataclasses
+import json
+import logging
+from typing import Any, Dict, Iterator, List, Optional, Type, TypeVar, Union, get_args
+
+import tomli
+from typing_extensions import Literal
+
+import utils
+
+logger = logging.getLogger(__name__)
+
+Distribution = Literal["normal", "uniform", "loguniform"]
+T = TypeVar("T", bound="Config")
+
+class Config:
+ @classmethod
+ def from_dict(cls: Type[T], dct: Dict[str, Any]) -> T:
+ for field in dataclasses.fields(cls):
+ if (
+ isinstance(field.type, type)
+ and issubclass(field.type, Config)
+ and field.name in dct
+ and not isinstance(dct[field.name], field.type)
+ ):
+ if not isinstance(dct[field.name], dict):
+ logger.warn(
+ "Subdict is not a dict! [cls: %s, field name: %s, field type: %s, actual type: %s]",
+ cls,
+ field.name,
+ field.type,
+ type(dct[field.name]),
+ )
+ dct[field.name] = field.type.from_dict(dct[field.name])
+
+ return cls(**dct)
+
+ @classmethod
+ def get_toml_name(cls) -> str:
+ # Because I'm a bad programmer and I do hacky things.
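+        # Strips the "Config" suffix: e.g. "ModelConfig" -> "model", "RandomVecConfig" -> "randomvec"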
+ return cls.__name__[: cls.__name__.lower().find("config")].lower()
+
+ @classmethod
+    def from_existing(cls: Type[T], other: T, **overrides) -> T:
+ kwargs = {**dataclasses.asdict(other), **overrides}
+
+ return cls(**kwargs)
+
+ @property
+ def pretty(self) -> str:
+ return json.dumps(dataclasses.asdict(self), indent=4)
+
+ def __str__(self) -> str:
+ return json.dumps(dataclasses.asdict(self))
+
+ def validate_field(self, fname: str, ftype) -> None:
+ choices = get_args(ftype)
+ if getattr(self, fname) not in choices:
+ raise ValueError(f"self.{fname} must be one of {', '.join(choices)}")
+
+@dataclasses.dataclass(frozen=True)
+class RandomVecConfig(Config):
+ distribution: Optional[Distribution] = None
+
+ # Distribution keyword args.
+ dist_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
+
+ def __post_init__(self) -> None:
+ if self.distribution is not None:
+ self.validate_field("distribution", Distribution)
+Layer = Union[
+ int,
+ Literal[
+ "sigmoid",
+ "tanh",
+ "output",
+ "cos",
+ "sine",
+ "layernorm",
+ "groupnorm",
+ "1/x",
+ "nonlinear-wht",
+ "dropout",
+ ],
+]
+
+@dataclasses.dataclass(frozen=True)
+class ProjectionConfig(Config):
+ layers: List[Layer] = dataclasses.field(default_factory=lambda: ["output"])
+ layer_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
+
+ @classmethod
+ def from_dict(cls, dct) -> "ProjectionConfig":
+ """
+ I reimplement this method because the toml dict will have a string for layers that needs to be evaluated to a real Python list.
+ """
+ for key in dct:
+ if key == "layers" and isinstance(dct[key], str):
+ dct[key] = eval(dct[key])
+
+ return cls(**dct)
+
+PromptType = Literal["uuid", "token", "vocab", "chunk-n", "natural-n"]
+
+@dataclasses.dataclass(frozen=True)
+class DataConfig(Config):
+ file: str
+ overwrite_cache: bool = False
+ """
+    Can be one of 'uuid', 'token', 'vocab', 'chunk-n', or 'natural-n'.
+ * uuid: encodes a uuid as the prompt (typically between 20-30 tokens for GPT2).
+ * token: adds a new token to the vocabulary for each chunk (<|start0|>, <|start1|>, etc.)
+    * vocab: finds an existing token in the vocabulary that's not in any of the examples and uses it as the prompt.
+ * chunk-n: "Chunk 1: ", "Chunk 2: ", ...
+ """
+ prompt_type: PromptType = "uuid"
+
+ chunk_length: Union[Literal["longest"], int] = "longest"
+
+ def __post_init__(self) -> None:
+ if not os.path.exists(self.file):
+ raise ValueError(f"{self.file} does not exist!")
+
+ self.validate_field("prompt_type", PromptType)
+
+ if self.chunk_length != "longest":
+ assert isinstance(self.chunk_length, int)
+
+ def get_text(self) -> str:
+ assert self.file is not None
+
+ with open(self.file, "r") as file:
+ return file.read()
+
+@dataclasses.dataclass(frozen=True)
+class ModelConfig(Config):
+ language_model_name_or_path: str
+ intrinsic_dimension: Optional[int] = None
+
+ # Structure-aware intrinsic dimension (SAID)
+ # Has no effect when intrinsic_dimension is None.
+ intrinsic_dimension_said: bool = False
+
+ # temperature of 1.0 has no effect, lower tend toward greedy sampling
+ temperature: float = 1.0
+
+ # The number of highest probability vocabulary tokens to keep for top-k-filtering.
+ top_k: int = 0
+
+ # If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.
+ top_p: float = 0.9
+
+ # primarily useful for CTRL model; in that case, use 1.2
+ repetition_penalty: float = 1.0
+
+ # optional stop token (ignore text generated after this token)
+ stop_token: Optional[str] = None
+
+ # context window size
+ context_window: int = 1024
+
+ # dropout probability for fully connected layers in embeddings, encoder, and pooler, embeddings, and attention.
+ dropout: float = 0.0
+
+ # dropout probability for the intrinsic dimension layer(s)
+ int_dim_dropout: float = 0.0
+
+ # Whether to use pre-trained weights.
+ pretrained: bool = True
+
+ random_vector: RandomVecConfig = dataclasses.field(default_factory=RandomVecConfig)
+
+ projection: ProjectionConfig = dataclasses.field(default_factory=ProjectionConfig)
+
+ normalized: bool = True
+ scaled: bool = False
+ scaling_factor: float = 1
+
+ def __post_init__(self) -> None:
+ assert isinstance(self.random_vector, RandomVecConfig), str(
+ type(self.random_vector)
+ )
+ assert isinstance(self.projection, ProjectionConfig), str(type(self.projection))
+
+SeedSource = Literal["trial", "config", "random"]
+
+@dataclasses.dataclass(frozen=True)
+class ExperimentConfig(Config):
+ model: ModelConfig
+ # tokenizer: TokenizerConfig
+
+    # The tokenizer and training sub-configs are commented out and not used here.
+ data: DataConfig
+ # training: TrainingConfig
+
+ trials: int = 3
+ save_weights: bool = True
+ seed_source: SeedSource = "trial"
+ seed: int = 0
+
+ def __post_init__(self) -> None:
+ self.validate_field("seed_source", SeedSource)
+
+
+# def load_configs(config_file: str) -> Iterator[ExperimentConfig]:
+# """
+# A config file could contain many experiments. For any field in a config file, if it is a list, then it turns into multiple experiments. If there are multiple lists, then each combination of elements from each list forms a new experiment.
+# """
+# with open(config_file, "r") as file:
+# config_dict = tomli.loads(file.read())
+
+# for flat in utils.flattened(config_dict):
+# yield ExperimentConfig.from_dict(copy.deepcopy(flat))
\ No newline at end of file
diff --git a/configs/hierarchical-vision-project/bold-banana-192.yaml b/configs/hierarchical-vision-project/bold-banana-192.yaml
new file mode 100644
index 00000000..208d989e
--- /dev/null
+++ b/configs/hierarchical-vision-project/bold-banana-192.yaml
@@ -0,0 +1,29 @@
+DATA:
+ DATASET: inat21
+ IMG_SIZE: 192
+ NUM_WORKERS: 32
+MODEL:
+ TYPE: swinv2
+ NAME: swinv2_base_window12
+ DROP_PATH_RATE: 0.2
+ SWINV2:
+ EMBED_DIM: 128
+ DEPTHS: [ 2, 2, 18, 2 ]
+ NUM_HEADS: [ 4, 8, 16, 32 ]
+ WINDOW_SIZE: 12
+TRAIN:
+ # Want a global batch size of 2048 because SwinV2 was trained on 16 V100s with batch size 128 (I think)
+ # But we are going to use a global batch size of 1024 because it's faster (throughput).
+ GLOBAL_BATCH_SIZE: 1024
+
+ # We are using limited epochs based on pre-training configs for imagenet22k
+ # Then we will pre-train on 256x256 for 30 epochs
+ EPOCHS: 90
+ WARMUP_EPOCHS: 5
+ WEIGHT_DECAY: 0.1
+
+SAVE_FREQ: 4
+
+EXPERIMENT:
+ NAME: bold-banana-192
+ WANDB_ID: 12907g41
diff --git a/configs/hierarchical-vision-project/fuzzy-fig-192.yaml b/configs/hierarchical-vision-project/fuzzy-fig-192.yaml
new file mode 100644
index 00000000..d68232a5
--- /dev/null
+++ b/configs/hierarchical-vision-project/fuzzy-fig-192.yaml
@@ -0,0 +1,32 @@
+DATA:
+ DATASET: inat21
+ IMG_SIZE: 192
+ NUM_WORKERS: 32
+MODEL:
+ TYPE: swinv2
+ NAME: swinv2_base_window12
+ DROP_PATH_RATE: 0.2
+ SWINV2:
+ EMBED_DIM: 128
+ DEPTHS: [ 2, 2, 18, 2 ]
+ NUM_HEADS: [ 4, 8, 16, 32 ]
+ WINDOW_SIZE: 12
+TRAIN:
+ # Want a global batch size of 2048 because SwinV2 was trained on 16 V100s with batch size 128 (I think)
+ # But we are going to use a global batch size of 1024 because it's faster (throughput).
+ GLOBAL_BATCH_SIZE: 1024
+
+ # We are using limited epochs based on pre-training configs for imagenet22k
+ # Then we will pre-train on 256x256 for 30 epochs
+ EPOCHS: 90
+ WARMUP_EPOCHS: 5
+ WEIGHT_DECAY: 0.1
+
+ # Use 1/4 of original learning rates
+ BASE_LR: 1.25e-4
+ WARMUP_LR: 1.25e-7
+ MIN_LR: 1.25e-6
+
+EXPERIMENT:
+ NAME: fuzzy-fig-192
+ WANDB_ID: 2c0bq1h2
diff --git a/configs/hierarchical-vision-project/groovy-grape-192.yaml b/configs/hierarchical-vision-project/groovy-grape-192.yaml
new file mode 100644
index 00000000..02673338
--- /dev/null
+++ b/configs/hierarchical-vision-project/groovy-grape-192.yaml
@@ -0,0 +1,37 @@
+
+DATA:
+ DATASET: inat21
+ IMG_SIZE: 192
+ NUM_WORKERS: 32
+MODEL:
+ TYPE: swinv2
+ NAME: swinv2_base_window12
+ DROP_PATH_RATE: 0.2
+ SWINV2:
+ EMBED_DIM: 128
+ DEPTHS: [ 2, 2, 18, 2 ]
+ NUM_HEADS: [ 4, 8, 16, 32 ]
+ WINDOW_SIZE: 12
+TRAIN:
+ # Want a global batch size of 2048 because SwinV2 was trained on 16 V100s with batch size 128 (I think)
+ # But we are going to use a global batch size of 1024 because it's faster (throughput).
+ GLOBAL_BATCH_SIZE: 1024
+
+ # We are using limited epochs based on pre-training configs for imagenet22k
+ # Then we will pre-train on 256x256 for 30 epochs
+ EPOCHS: 90
+ WARMUP_EPOCHS: 5
+ WEIGHT_DECAY: 0.1
+
+ # Use 1/4 of original learning rates
+ BASE_LR: 1.25e-4
+ WARMUP_LR: 1.25e-7
+ MIN_LR: 1.25e-6
+
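+  # These coefficients are spaced geometrically by a factor of sqrt(2), from 2^3 down to 2^0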
+ HIERARCHICAL_COEFFS: [ 8, 5.65, 4, 2.82, 2, 1.41, 1 ]
+
+EXPERIMENT:
+ NAME: groovy-grape-192
+ WANDB_ID: 3jcq2v9b
+
+HIERARCHICAL: true
diff --git a/configs/hierarchical-vision-project/groovy-grape-256.yaml b/configs/hierarchical-vision-project/groovy-grape-256.yaml
new file mode 100644
index 00000000..d000dccd
--- /dev/null
+++ b/configs/hierarchical-vision-project/groovy-grape-256.yaml
@@ -0,0 +1,34 @@
+DATA:
+ DATASET: inat21
+ IMG_SIZE: 256
+ NUM_WORKERS: 32
+MODEL:
+ TYPE: swinv2
+ NAME: swinv2_base_window16
+ DROP_PATH_RATE: 0.2
+ SWINV2:
+ EMBED_DIM: 128
+ DEPTHS: [ 2, 2, 18, 2 ]
+ NUM_HEADS: [ 4, 8, 16, 32 ]
+ WINDOW_SIZE: 16
+ PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ]
+TRAIN:
+ GLOBAL_BATCH_SIZE: 1024
+ EPOCHS: 30
+ WARMUP_EPOCHS: 5
+
+ # Should weight decay be this low?
+ # I don't think so, but I am sticking with the default for now.
+ WEIGHT_DECAY: 1.0e-8
+
+ BASE_LR: 2.0e-05
+ WARMUP_LR: 2.0e-08
+ MIN_LR: 2.0e-07
+
+ HIERARCHICAL_COEFFS: [ 8, 5.65, 4, 2.82, 2, 1.41, 1 ]
+
+EXPERIMENT:
+ NAME: groovy-grape-256
+ WANDB_ID: 11o47wpm
+
+HIERARCHICAL: true
diff --git a/configs/hierarchical-vision-project/outrageous-orange-192.yaml b/configs/hierarchical-vision-project/outrageous-orange-192.yaml
new file mode 100644
index 00000000..15d835e7
--- /dev/null
+++ b/configs/hierarchical-vision-project/outrageous-orange-192.yaml
@@ -0,0 +1,31 @@
+DATA:
+ DATASET: inat21
+ IMG_SIZE: 192
+ NUM_WORKERS: 32
+MODEL:
+ TYPE: swinv2
+ NAME: swinv2_base_window12
+ DROP_PATH_RATE: 0.2
+ SWINV2:
+ EMBED_DIM: 128
+ DEPTHS: [ 2, 2, 18, 2 ]
+ NUM_HEADS: [ 4, 8, 16, 32 ]
+ WINDOW_SIZE: 12
+TRAIN:
+ # Want a global batch size of 2048 because SwinV2 was trained on 16 V100s with batch size 128 (I think)
+ # But we are going to use a global batch size of 1024 because it's faster (throughput).
+ GLOBAL_BATCH_SIZE: 1024
+
+ # We are using limited epochs based on pre-training configs for imagenet22k
+ # Then we will pre-train on 256x256 for 30 epochs
+ EPOCHS: 90
+ WARMUP_EPOCHS: 5
+ WEIGHT_DECAY: 0.1
+ HIERARCHICAL_COEFFS: [ 8, 5.65, 4, 2.82, 2, 1.41, 1 ]
+
+EXPERIMENT:
+ NAME: outrageous-orange-192
+ WANDB_ID: y6zzxboz
+
+HIERARCHICAL: true
+SAVE_FREQ: 4
diff --git a/configs/swinv2/swinv2_base_patch4_window12_192_inat21_hierarchical_lr1.25.yaml b/configs/swinv2/swinv2_base_patch4_window12_192_inat21_hierarchical_lr1.25.yaml
new file mode 100644
index 00000000..cc23379a
--- /dev/null
+++ b/configs/swinv2/swinv2_base_patch4_window12_192_inat21_hierarchical_lr1.25.yaml
@@ -0,0 +1,33 @@
+DATA:
+ DATASET: inat21
+ IMG_SIZE: 192
+ DATA_PATH: /mnt/10tb/data/inat21/resize-192
+ NUM_WORKERS: 32
+ BATCH_SIZE: 64
+MODEL:
+ TYPE: swinv2
+ NAME: swinv2_base_patch4_window12_192_inat21_hierarchical_lr1.25
+ DROP_PATH_RATE: 0.2
+ SWINV2:
+ EMBED_DIM: 128
+ DEPTHS: [ 2, 2, 18, 2 ]
+ NUM_HEADS: [ 4, 8, 16, 32 ]
+ WINDOW_SIZE: 12
+TRAIN:
+ # Want a global batch size of 2048 because SwinV2 was trained on 16 V100s with batch size 128 (I think)
+ # But we are going to use a global batch size of 1024 because it's faster (throughput).
+ ACCUMULATION_STEPS: 2
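+  # Assuming 8 GPUs (the 8x V100 setup): 64 per-GPU batch x 8 GPUs x 2 accumulation steps = 1024 global.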
+
+ # We are using limited epochs based on pre-training configs for imagenet22k
+ # Then we will pre-train on 256x256 for 30 epochs
+ EPOCHS: 90
+ WARMUP_EPOCHS: 5
+ WEIGHT_DECAY: 0.1
+
+ # Use 1/4 of original learning rates
+ BASE_LR: 1.25e-4
+ WARMUP_LR: 1.25e-7
+ MIN_LR: 1.25e-6
+ HIERARCHICAL_COEFFS: [ 8, 5.65, 4, 2.82, 2, 1.41, 1 ]
+
+HIERARCHICAL: true
diff --git a/configs/swinv2/swinv2_base_patch4_window12_192_inat21_hierarchical_lr2.5.yaml b/configs/swinv2/swinv2_base_patch4_window12_192_inat21_hierarchical_lr2.5.yaml
new file mode 100644
index 00000000..5bc4336f
--- /dev/null
+++ b/configs/swinv2/swinv2_base_patch4_window12_192_inat21_hierarchical_lr2.5.yaml
@@ -0,0 +1,33 @@
+DATA:
+ DATASET: inat21
+ IMG_SIZE: 192
+ DATA_PATH: /mnt/10tb/data/inat21/resize-192
+ NUM_WORKERS: 32
+ BATCH_SIZE: 64
+MODEL:
+ TYPE: swinv2
+ NAME: sweet-strawberry-192
+ DROP_PATH_RATE: 0.2
+ SWINV2:
+ EMBED_DIM: 128
+ DEPTHS: [ 2, 2, 18, 2 ]
+ NUM_HEADS: [ 4, 8, 16, 32 ]
+ WINDOW_SIZE: 12
+TRAIN:
+ # Want a global batch size of 2048 because SwinV2 was trained on 16 V100s with batch size 128 (I think)
+ # But we are going to use a global batch size of 1024 because it's faster (throughput).
+ ACCUMULATION_STEPS: 2
+
+ # We are using limited epochs based on pre-training configs for imagenet22k
+ # Then we will pre-train on 256x256 for 30 epochs
+ EPOCHS: 90
+ WARMUP_EPOCHS: 5
+ WEIGHT_DECAY: 0.1
+
+ # Use 1/2 of original learning rates
+ BASE_LR: 2.5e-4
+ WARMUP_LR: 2.5e-7
+ MIN_LR: 2.5e-6
+ HIERARCHICAL_COEFFS: [ 8, 5.65, 4, 2.82, 2, 1.41, 1 ]
+
+HIERARCHICAL: true
diff --git a/configs/swinv2/swinv2_base_patch4_window12_192_inat21_hierarchical_lr5.yaml b/configs/swinv2/swinv2_base_patch4_window12_192_inat21_hierarchical_lr5.yaml
new file mode 100644
index 00000000..2a2d422f
--- /dev/null
+++ b/configs/swinv2/swinv2_base_patch4_window12_192_inat21_hierarchical_lr5.yaml
@@ -0,0 +1,29 @@
+DATA:
+ DATASET: inat21
+ IMG_SIZE: 192
+ DATA_PATH: /research/nfs_su_809/cv_datasets/inat21/train_val_192
+ NUM_WORKERS: 32
+ BATCH_SIZE: 128
+MODEL:
+ TYPE: swinv2
+ NAME: outrageous-orange-192
+ DROP_PATH_RATE: 0.2
+ SWINV2:
+ EMBED_DIM: 128
+ DEPTHS: [ 2, 2, 18, 2 ]
+ NUM_HEADS: [ 4, 8, 16, 32 ]
+ WINDOW_SIZE: 12
+TRAIN:
+ # Want a global batch size of 2048 because SwinV2 was trained on 16 V100s with batch size 128 (I think)
+ # But we are going to use a global batch size of 1024 because it's faster (throughput).
+ ACCUMULATION_STEPS: 2
+
+ # We are using limited epochs based on pre-training configs for imagenet22k
+ # Then we will pre-train on 256x256 for 30 epochs
+ EPOCHS: 90
+ WARMUP_EPOCHS: 5
+ WEIGHT_DECAY: 0.1
+ HIERARCHICAL_COEFFS: [ 8, 5.65, 4, 2.82, 2, 1.41, 1 ]
+
+SAVE_FREQ: 4
+HIERARCHICAL: true
diff --git a/configs/swinv2/swinv2_base_patch4_window12to16_192to256_inat21_hierarchical_lr2.5_ft.yaml b/configs/swinv2/swinv2_base_patch4_window12to16_192to256_inat21_hierarchical_lr2.5_ft.yaml
new file mode 100644
index 00000000..c3f6b5df
--- /dev/null
+++ b/configs/swinv2/swinv2_base_patch4_window12to16_192to256_inat21_hierarchical_lr2.5_ft.yaml
@@ -0,0 +1,34 @@
+DATA:
+ DATASET: inat21
+ NUM_WORKERS: 32
+ BATCH_SIZE: 16
+ IMG_SIZE: 256
+ DATA_PATH: /mnt/10tb/data/inat21/resize-256
+MODEL:
+ TYPE: swinv2
+ PRETRAINED: /mnt/10tb/models/swinv2_base_patch4_window12_192_inat21_hierarchical_lr2.5_v0_epoch_89.pth
+ NAME: sweet-strawberry-256
+ DROP_PATH_RATE: 0.2
+ SWINV2:
+ EMBED_DIM: 128
+ DEPTHS: [ 2, 2, 18, 2 ]
+ NUM_HEADS: [ 4, 8, 16, 32 ]
+ WINDOW_SIZE: 16
+ PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ]
+TRAIN:
+ # Global batch size of 1024
+ ACCUMULATION_STEPS: 8
+ EPOCHS: 30
+ WARMUP_EPOCHS: 5
+
+ # Should weight decay be this low?
+ # I don't think so, but I am sticking with the default for now.
+ WEIGHT_DECAY: 1.0e-8
+
+ BASE_LR: 2.0e-05
+ WARMUP_LR: 2.0e-08
+ MIN_LR: 2.0e-07
+
+ HIERARCHICAL_COEFFS: [ 8, 5.65, 4, 2.82, 2, 1.41, 1 ]
+
+HIERARCHICAL: true
diff --git a/data/__init__.py b/data/__init__.py
index 5baad7ed..a8cfa99a 100644
--- a/data/__init__.py
+++ b/data/__init__.py
@@ -1,6 +1,6 @@
from .build import build_loader as _build_loader
-from .data_simmim_pt import build_loader_simmim
from .data_simmim_ft import build_loader_finetune
+from .data_simmim_pt import build_loader_simmim
def build_loader(config, simmim=False, is_pretrain=False):
diff --git a/data/build.py b/data/build.py
index 5799f253..601ab602 100644
--- a/data/build.py
+++ b/data/build.py
@@ -6,52 +6,71 @@
# --------------------------------------------------------
import os
-import torch
+import random
+
import numpy as np
+import torch
import torch.distributed as dist
+from timm.data import Mixup, create_transform
+from torch.utils.data import Subset
from torchvision import datasets, transforms
-from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.data import Mixup
-from timm.data import create_transform
from .cached_image_folder import CachedImageFolder
+from .constants import data_mean_std
+from .hierarchical import HierarchicalImageFolder, HierarchicalMixup
from .imagenet22k_dataset import IN22KDATASET
from .samplers import SubsetRandomSampler
try:
from torchvision.transforms import InterpolationMode
-
def _pil_interp(method):
- if method == 'bicubic':
+ if method == "bicubic":
return InterpolationMode.BICUBIC
- elif method == 'lanczos':
+ elif method == "lanczos":
return InterpolationMode.LANCZOS
- elif method == 'hamming':
+ elif method == "hamming":
return InterpolationMode.HAMMING
else:
# default bilinear, do we want to allow nearest?
return InterpolationMode.BILINEAR
-
import timm.data.transforms as timm_transforms
timm_transforms._pil_interp = _pil_interp
-except:
+except ImportError:
from timm.data.transforms import _pil_interp
def build_loader(config):
config.defrost()
- dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config)
+ dataset_train, config.MODEL.NUM_CLASSES = build_dataset(
+ is_train=True, config=config
+ )
config.freeze()
- print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build train dataset")
+ print(
+ f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build train dataset"
+ )
dataset_val, _ = build_dataset(is_train=False, config=config)
- print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build val dataset")
+ print(
+ f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build val dataset"
+ )
+
+ # Check if we are overfitting some subset of the training data for debugging
+ if config.TRAIN.OVERFIT_BATCHES > 0:
+ n_examples = config.TRAIN.OVERFIT_BATCHES * config.TRAIN.DEVICE_BATCH_SIZE
+ indices = random.sample(range(len(dataset_train)), n_examples)
+ dataset_train = Subset(dataset_train, indices)
+
+    # Check if training is in the low-data regime; if so, select a random subset of the training data.
+    if config.TRAIN.DATA_PERCENTAGE < 1:
+        n_examples = config.TRAIN.DATA_PERCENTAGE * len(dataset_train)
+ indices = random.sample(range(len(dataset_train)), int(n_examples))
+ dataset_train = Subset(dataset_train, indices)
num_tasks = dist.get_world_size()
global_rank = dist.get_rank()
- if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part':
+ if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == "part":
indices = np.arange(dist.get_rank(), len(dataset_train), dist.get_world_size())
sampler_train = SubsetRandomSampler(indices)
else:
@@ -67,57 +86,102 @@ def build_loader(config):
)
data_loader_train = torch.utils.data.DataLoader(
- dataset_train, sampler=sampler_train,
- batch_size=config.DATA.BATCH_SIZE,
+ dataset_train,
+ sampler=sampler_train,
+ batch_size=config.TRAIN.DEVICE_BATCH_SIZE,
num_workers=config.DATA.NUM_WORKERS,
pin_memory=config.DATA.PIN_MEMORY,
drop_last=True,
)
data_loader_val = torch.utils.data.DataLoader(
- dataset_val, sampler=sampler_val,
- batch_size=config.DATA.BATCH_SIZE,
+ dataset_val,
+ sampler=sampler_val,
+ batch_size=config.TRAIN.DEVICE_BATCH_SIZE,
shuffle=False,
num_workers=config.DATA.NUM_WORKERS,
pin_memory=config.DATA.PIN_MEMORY,
- drop_last=False
+ drop_last=False,
)
# setup mixup / cutmix
mixup_fn = None
- mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None
+ mixup_active = (
+ config.AUG.MIXUP > 0
+ or config.AUG.CUTMIX > 0.0
+ or config.AUG.CUTMIX_MINMAX is not None
+ )
if mixup_active:
- mixup_fn = Mixup(
- mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX,
- prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE,
- label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES)
+ mixup_args = dict(
+ mixup_alpha=config.AUG.MIXUP,
+ cutmix_alpha=config.AUG.CUTMIX,
+ cutmix_minmax=config.AUG.CUTMIX_MINMAX,
+ prob=config.AUG.MIXUP_PROB,
+ switch_prob=config.AUG.MIXUP_SWITCH_PROB,
+ mode=config.AUG.MIXUP_MODE,
+ label_smoothing=config.MODEL.LABEL_SMOOTHING,
+ num_classes=config.MODEL.NUM_CLASSES,
+ )
+ if config.HIERARCHICAL:
+ mixup_fn = HierarchicalMixup(**mixup_args)
+ else:
+ mixup_fn = Mixup(**mixup_args)
return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn
def build_dataset(is_train, config):
transform = build_transform(is_train, config)
- if config.DATA.DATASET == 'imagenet':
- prefix = 'train' if is_train else 'val'
+ if config.DATA.DATASET == "imagenet":
+ prefix = "train" if is_train else "val"
if config.DATA.ZIP_MODE:
ann_file = prefix + "_map.txt"
prefix = prefix + ".zip@/"
- dataset = CachedImageFolder(config.DATA.DATA_PATH, ann_file, prefix, transform,
- cache_mode=config.DATA.CACHE_MODE if is_train else 'part')
+ dataset = CachedImageFolder(
+ config.DATA.DATA_PATH,
+ ann_file,
+ prefix,
+ transform,
+ cache_mode=config.DATA.CACHE_MODE if is_train else "part",
+ )
else:
root = os.path.join(config.DATA.DATA_PATH, prefix)
dataset = datasets.ImageFolder(root, transform=transform)
nb_classes = 1000
- elif config.DATA.DATASET == 'imagenet22K':
- prefix = 'ILSVRC2011fall_whole'
+ elif config.DATA.DATASET == "imagenet22K":
+ prefix = "ILSVRC2011fall_whole"
if is_train:
ann_file = prefix + "_map_train.txt"
else:
ann_file = prefix + "_map_val.txt"
dataset = IN22KDATASET(config.DATA.DATA_PATH, ann_file, transform)
nb_classes = 21841
+
+ elif config.DATA.DATASET == 'nabird':
+ prefix = 'train' if is_train else 'val'
+ root = os.path.join(config.DATA.DATA_PATH, prefix)
+ dataset = datasets.ImageFolder(root, transform=transform)
+ nb_classes = 555
+
+ elif config.DATA.DATASET == 'ip102':
+ prefix = 'train' if is_train else 'val'
+ root = os.path.join(config.DATA.DATA_PATH, prefix)
+ dataset = datasets.ImageFolder(root, transform=transform)
+ nb_classes = 102
+
+ elif config.DATA.DATASET == "inat21":
+ if config.DATA.ZIP_MODE:
+ raise NotImplementedError("We do not support zipped inat21")
+ prefix = "train" if is_train else "val"
+ root = os.path.join(config.DATA.DATA_PATH, prefix)
+ if config.HIERARCHICAL:
+ dataset = HierarchicalImageFolder(root, transform=transform)
+ nb_classes = dataset.num_classes
+ else:
+ dataset = datasets.ImageFolder(root, transform=transform)
+ nb_classes = 10_000
else:
- raise NotImplementedError("We only support ImageNet Now.")
+ raise NotImplementedError("We only support ImageNet now.")
return dataset, nb_classes
@@ -129,8 +193,12 @@ def build_transform(is_train, config):
transform = create_transform(
input_size=config.DATA.IMG_SIZE,
is_training=True,
- color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None,
- auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None,
+ color_jitter=config.AUG.COLOR_JITTER
+ if config.AUG.COLOR_JITTER > 0
+ else None,
+ auto_augment=config.AUG.AUTO_AUGMENT
+ if config.AUG.AUTO_AUGMENT != "none"
+ else None,
re_prob=config.AUG.REPROB,
re_mode=config.AUG.REMODE,
re_count=config.AUG.RECOUNT,
@@ -139,7 +207,9 @@ def build_transform(is_train, config):
if not resize_im:
# replace RandomResizedCropAndInterpolation with
# RandomCrop
- transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4)
+ transform.transforms[0] = transforms.RandomCrop(
+ config.DATA.IMG_SIZE, padding=4
+ )
return transform
t = []
@@ -147,16 +217,29 @@ def build_transform(is_train, config):
if config.TEST.CROP:
size = int((256 / 224) * config.DATA.IMG_SIZE)
t.append(
- transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)),
+ transforms.Resize(
+ size, interpolation=_pil_interp(config.DATA.INTERPOLATION)
+ ),
# to maintain same ratio w.r.t. 224 images
)
t.append(transforms.CenterCrop(config.DATA.IMG_SIZE))
else:
t.append(
- transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
- interpolation=_pil_interp(config.DATA.INTERPOLATION))
+ transforms.Resize(
+ (config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
+ interpolation=_pil_interp(config.DATA.INTERPOLATION),
+ )
)
+ if config.DATA.DATA_PATH in data_mean_std:
+ mean, std = data_mean_std[config.DATA.DATA_PATH]
+ elif config.DATA.DATASET in data_mean_std:
+ mean, std = data_mean_std[config.DATA.DATASET]
+ else:
+ raise RuntimeError(
+ f"Can't find mean/std for {config.DATA.DATASET} at {config.DATA.DATA_PATH}. Please add it to data/constants.py (try using python -m data.inat normalize for iNat)."
+ )
+
t.append(transforms.ToTensor())
- t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
+ t.append(transforms.Normalize(mean, std))
return transforms.Compose(t)
diff --git a/data/cached_image_folder.py b/data/cached_image_folder.py
index 7e1883b1..732d3236 100644
--- a/data/cached_image_folder.py
+++ b/data/cached_image_folder.py
@@ -8,11 +8,12 @@
import io
import os
import time
+
import torch.distributed as dist
import torch.utils.data as data
from PIL import Image
-from .zipreader import is_zip_path, ZipReader
+from .zipreader import ZipReader, is_zip_path
def has_file_allowed_extension(filename, extensions):
@@ -56,7 +57,7 @@ def make_dataset_with_ann(ann_file, img_prefix, extensions):
with open(ann_file, "r") as f:
contents = f.readlines()
for line_str in contents:
- path_contents = [c for c in line_str.split('\t')]
+ path_contents = [c for c in line_str.split("\t")]
im_file_name = path_contents[0]
class_index = int(path_contents[1])
@@ -89,21 +90,37 @@ class DatasetFolder(data.Dataset):
samples (list): List of (sample path, class_index) tuples
"""
- def __init__(self, root, loader, extensions, ann_file='', img_prefix='', transform=None, target_transform=None,
- cache_mode="no"):
+ def __init__(
+ self,
+ root,
+ loader,
+ extensions,
+ ann_file="",
+ img_prefix="",
+ transform=None,
+ target_transform=None,
+ cache_mode="no",
+ ):
# image folder mode
- if ann_file == '':
+ if ann_file == "":
_, class_to_idx = find_classes(root)
samples = make_dataset(root, class_to_idx, extensions)
# zip mode
else:
- samples = make_dataset_with_ann(os.path.join(root, ann_file),
- os.path.join(root, img_prefix),
- extensions)
+ samples = make_dataset_with_ann(
+ os.path.join(root, ann_file), os.path.join(root, img_prefix), extensions
+ )
if len(samples) == 0:
- raise (RuntimeError("Found 0 files in subfolders of: " + root + "\n" +
- "Supported extensions are: " + ",".join(extensions)))
+ raise (
+ RuntimeError(
+ "Found 0 files in subfolders of: "
+ + root
+ + "\n"
+ + "Supported extensions are: "
+ + ",".join(extensions)
+ )
+ )
self.root = root
self.loader = loader
@@ -131,7 +148,9 @@ def init_cache(self):
for index in range(n_sample):
if index % (n_sample // 10) == 0:
t = time.time() - start_time
- print(f'global_rank {dist.get_rank()} cached {index}/{n_sample} takes {t:.2f}s per block')
+ print(
+ f"global_rank {dist.get_rank()} cached {index}/{n_sample} takes {t:.2f}s per block"
+ )
start_time = time.time()
path, target = self.samples[index]
if self.cache_mode == "full":
@@ -162,17 +181,21 @@ def __len__(self):
return len(self.samples)
def __repr__(self):
- fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
- fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
- fmt_str += ' Root Location: {}\n'.format(self.root)
- tmp = ' Transforms (if any): '
- fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
- tmp = ' Target Transforms (if any): '
- fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
+ fmt_str = "Dataset " + self.__class__.__name__ + "\n"
+ fmt_str += " Number of datapoints: {}\n".format(self.__len__())
+ fmt_str += " Root Location: {}\n".format(self.root)
+ tmp = " Transforms (if any): "
+ fmt_str += "{0}{1}\n".format(
+ tmp, self.transform.__repr__().replace("\n", "\n" + " " * len(tmp))
+ )
+ tmp = " Target Transforms (if any): "
+ fmt_str += "{0}{1}".format(
+ tmp, self.target_transform.__repr__().replace("\n", "\n" + " " * len(tmp))
+ )
return fmt_str
-IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
+IMG_EXTENSIONS = [".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif"]
def pil_loader(path):
@@ -183,14 +206,15 @@ def pil_loader(path):
data = ZipReader.read(path)
img = Image.open(io.BytesIO(data))
else:
- with open(path, 'rb') as f:
+ with open(path, "rb") as f:
img = Image.open(f)
- return img.convert('RGB')
- return img.convert('RGB')
+ return img.convert("RGB")
+ return img.convert("RGB")
def accimage_loader(path):
import accimage
+
try:
return accimage.Image(path)
except IOError:
@@ -200,7 +224,8 @@ def accimage_loader(path):
def default_img_loader(path):
from torchvision import get_image_backend
- if get_image_backend() == 'accimage':
+
+ if get_image_backend() == "accimage":
return accimage_loader(path)
else:
return pil_loader(path)
@@ -225,12 +250,26 @@ class CachedImageFolder(DatasetFolder):
imgs (list): List of (image path, class_index) tuples
"""
- def __init__(self, root, ann_file='', img_prefix='', transform=None, target_transform=None,
- loader=default_img_loader, cache_mode="no"):
- super(CachedImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
- ann_file=ann_file, img_prefix=img_prefix,
- transform=transform, target_transform=target_transform,
- cache_mode=cache_mode)
+ def __init__(
+ self,
+ root,
+ ann_file="",
+ img_prefix="",
+ transform=None,
+ target_transform=None,
+ loader=default_img_loader,
+ cache_mode="no",
+ ):
+ super(CachedImageFolder, self).__init__(
+ root,
+ loader,
+ IMG_EXTENSIONS,
+ ann_file=ann_file,
+ img_prefix=img_prefix,
+ transform=transform,
+ target_transform=target_transform,
+ cache_mode=cache_mode,
+ )
self.imgs = self.samples
def __getitem__(self, index):
diff --git a/data/constants.py b/data/constants.py
new file mode 100644
index 00000000..3cfd39b1
--- /dev/null
+++ b/data/constants.py
@@ -0,0 +1,42 @@
+import torch
+from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+
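+# Channel statistics keyed by DATA.DATA_PATH (checked first) or DATA.DATASET.
+# Each entry is a (mean, std) pair of RGB tensors; entries for new directories
+# can be generated with `python -m data.inat normalize <directory>`.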
+data_mean_std = {
+ "/mnt/10tb/data/inat21/resize-192": (
+ torch.tensor([0.4632684290409088, 0.48004600405693054, 0.37628623843193054]),
+ torch.tensor([0.23754851520061493, 0.22912880778312683, 0.24746596813201904]),
+ ),
+ "/local/scratch/cv_datasets/inat21/resize-192": (
+ torch.tensor([0.23754851520061493, 0.22912880778312683, 0.24746596813201904]),
+ torch.tensor([0.4632684290409088, 0.48004600405693054, 0.37628623843193054]),
+ ),
+ "/mnt/10tb/data/inat21/resize-224": (
+ torch.tensor([0.23762744665145874, 0.2292044311761856, 0.24757201969623566]),
+ torch.tensor([0.4632636606693268, 0.48004215955734253, 0.37622377276420593]),
+ ),
+ "/mnt/10tb/data/inat21/resize-256": (
+ torch.tensor([0.23768986761569977, 0.22925858199596405, 0.2476460039615631]),
+ torch.tensor([0.4632672071456909, 0.480050653219223, 0.37618669867515564]),
+ ),
+ "/local/scratch/cv_datasets/inat21/resize-256": (
+ torch.tensor([0.23768986761569977, 0.22925858199596405, 0.2476460039615631]),
+ torch.tensor([0.4632672071456909, 0.480050653219223, 0.37618669867515564]),
+ ),
+ "/home/ubuntu/AWS_Server/swin-transformer/datasets/nabirds/image_new/full_train_val": (
+ torch.tensor([0.49044106, 0.5076765, 0.46390218]),
+ torch.tensor([0.16689847, 0.1688618, 0.18529404]),
+ ),
+ "/home/ubuntu/AWS_Server/swin-transformer/datasets/nabirds/image_new/10P_train_val": (
+ torch.tensor([0.48588862, 0.50227299, 0.45998148]),
+ torch.tensor([0.16756063, 0.16897439, 0.18490989]),
+ ),
+ "/home/ubuntu/AWS_Server/swin-transformer/datasets/nabirds/image_new/full_train_test": (
+ torch.tensor([0.49103116, 0.5080927, 0.46408487]),
+ torch.tensor([0.16669449, 0.16859235, 0.18495317]),
+ ),
+ "/home/ubuntu/AWS_Server/swin-transformer/datasets/ip102": (
+ torch.tensor([0.51354748, 0.54016679, 0.38778601]),
+ torch.tensor([0.19195388, 0.19070604, 0.19121135]),
+ ),
+ "imagenet": (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
+}
diff --git a/data/data_simmim_ft.py b/data/data_simmim_ft.py
index 1d44ae7d..ffd4137c 100644
--- a/data/data_simmim_ft.py
+++ b/data/data_simmim_ft.py
@@ -6,18 +6,20 @@
# --------------------------------------------------------
import os
+
import torch.distributed as dist
-from torch.utils.data import DataLoader, DistributedSampler
-from torchvision import datasets, transforms
+from timm.data import Mixup, create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.data import Mixup
-from timm.data import create_transform
from timm.data.transforms import _pil_interp
+from torch.utils.data import DataLoader, DistributedSampler
+from torchvision import datasets, transforms
def build_loader_finetune(config):
config.defrost()
- dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config)
+ dataset_train, config.MODEL.NUM_CLASSES = build_dataset(
+ is_train=True, config=config
+ )
config.freeze()
dataset_val, _ = build_dataset(is_train=False, config=config)
@@ -31,7 +33,8 @@ def build_loader_finetune(config):
)
data_loader_train = DataLoader(
- dataset_train, sampler=sampler_train,
+ dataset_train,
+ sampler=sampler_train,
batch_size=config.DATA.BATCH_SIZE,
num_workers=config.DATA.NUM_WORKERS,
pin_memory=config.DATA.PIN_MEMORY,
@@ -39,7 +42,8 @@ def build_loader_finetune(config):
)
data_loader_val = DataLoader(
- dataset_val, sampler=sampler_val,
+ dataset_val,
+ sampler=sampler_val,
batch_size=config.DATA.BATCH_SIZE,
num_workers=config.DATA.NUM_WORKERS,
pin_memory=config.DATA.PIN_MEMORY,
@@ -48,21 +52,31 @@ def build_loader_finetune(config):
# setup mixup / cutmix
mixup_fn = None
- mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None
+ mixup_active = (
+ config.AUG.MIXUP > 0
+ or config.AUG.CUTMIX > 0.0
+ or config.AUG.CUTMIX_MINMAX is not None
+ )
if mixup_active:
mixup_fn = Mixup(
- mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX,
- prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE,
- label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES)
+ mixup_alpha=config.AUG.MIXUP,
+ cutmix_alpha=config.AUG.CUTMIX,
+ cutmix_minmax=config.AUG.CUTMIX_MINMAX,
+ prob=config.AUG.MIXUP_PROB,
+ switch_prob=config.AUG.MIXUP_SWITCH_PROB,
+ mode=config.AUG.MIXUP_MODE,
+ label_smoothing=config.MODEL.LABEL_SMOOTHING,
+ num_classes=config.MODEL.NUM_CLASSES,
+ )
return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn
def build_dataset(is_train, config):
transform = build_transform(is_train, config)
-
- if config.DATA.DATASET == 'imagenet':
- prefix = 'train' if is_train else 'val'
+
+ if config.DATA.DATASET == "imagenet":
+ prefix = "train" if is_train else "val"
root = os.path.join(config.DATA.DATA_PATH, prefix)
dataset = datasets.ImageFolder(root, transform=transform)
nb_classes = 1000
@@ -79,8 +93,12 @@ def build_transform(is_train, config):
transform = create_transform(
input_size=config.DATA.IMG_SIZE,
is_training=True,
- color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None,
- auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None,
+ color_jitter=config.AUG.COLOR_JITTER
+ if config.AUG.COLOR_JITTER > 0
+ else None,
+ auto_augment=config.AUG.AUTO_AUGMENT
+ if config.AUG.AUTO_AUGMENT != "none"
+ else None,
re_prob=config.AUG.REPROB,
re_mode=config.AUG.REMODE,
re_count=config.AUG.RECOUNT,
@@ -89,7 +107,9 @@ def build_transform(is_train, config):
if not resize_im:
# replace RandomResizedCropAndInterpolation with
# RandomCrop
- transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4)
+ transform.transforms[0] = transforms.RandomCrop(
+ config.DATA.IMG_SIZE, padding=4
+ )
return transform
t = []
@@ -97,14 +117,18 @@ def build_transform(is_train, config):
if config.TEST.CROP:
size = int((256 / 224) * config.DATA.IMG_SIZE)
t.append(
- transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)),
+ transforms.Resize(
+ size, interpolation=_pil_interp(config.DATA.INTERPOLATION)
+ ),
# to maintain same ratio w.r.t. 224 images
)
t.append(transforms.CenterCrop(config.DATA.IMG_SIZE))
else:
t.append(
- transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
- interpolation=_pil_interp(config.DATA.INTERPOLATION))
+ transforms.Resize(
+ (config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
+ interpolation=_pil_interp(config.DATA.INTERPOLATION),
+ )
)
t.append(transforms.ToTensor())
diff --git a/data/data_simmim_pt.py b/data/data_simmim_pt.py
index 0f2503a7..9826ce90 100644
--- a/data/data_simmim_pt.py
+++ b/data/data_simmim_pt.py
@@ -7,70 +7,81 @@
import math
import random
-import numpy as np
+import numpy as np
import torch
import torch.distributed as dist
import torchvision.transforms as T
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.data._utils.collate import default_collate
from torchvision.datasets import ImageFolder
-from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
class MaskGenerator:
- def __init__(self, input_size=192, mask_patch_size=32, model_patch_size=4, mask_ratio=0.6):
+ def __init__(
+ self, input_size=192, mask_patch_size=32, model_patch_size=4, mask_ratio=0.6
+ ):
self.input_size = input_size
self.mask_patch_size = mask_patch_size
self.model_patch_size = model_patch_size
self.mask_ratio = mask_ratio
-
+
assert self.input_size % self.mask_patch_size == 0
assert self.mask_patch_size % self.model_patch_size == 0
-
+
self.rand_size = self.input_size // self.mask_patch_size
self.scale = self.mask_patch_size // self.model_patch_size
-
- self.token_count = self.rand_size ** 2
+
+ self.token_count = self.rand_size**2
self.mask_count = int(np.ceil(self.token_count * self.mask_ratio))
-
+
def __call__(self):
- mask_idx = np.random.permutation(self.token_count)[:self.mask_count]
+ mask_idx = np.random.permutation(self.token_count)[: self.mask_count]
mask = np.zeros(self.token_count, dtype=int)
mask[mask_idx] = 1
-
+
mask = mask.reshape((self.rand_size, self.rand_size))
mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1)
-
+
return mask
class SimMIMTransform:
def __init__(self, config):
- self.transform_img = T.Compose([
- T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
- T.RandomResizedCrop(config.DATA.IMG_SIZE, scale=(0.67, 1.), ratio=(3. / 4., 4. / 3.)),
- T.RandomHorizontalFlip(),
- T.ToTensor(),
- T.Normalize(mean=torch.tensor(IMAGENET_DEFAULT_MEAN),std=torch.tensor(IMAGENET_DEFAULT_STD)),
- ])
-
- if config.MODEL.TYPE in ['swin', 'swinv2']:
- model_patch_size=config.MODEL.SWIN.PATCH_SIZE
+ self.transform_img = T.Compose(
+ [
+ T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
+ T.RandomResizedCrop(
+ config.DATA.IMG_SIZE,
+ scale=(0.67, 1.0),
+ ratio=(3.0 / 4.0, 4.0 / 3.0),
+ ),
+ T.RandomHorizontalFlip(),
+ T.ToTensor(),
+ T.Normalize(
+ mean=torch.tensor(IMAGENET_DEFAULT_MEAN),
+ std=torch.tensor(IMAGENET_DEFAULT_STD),
+ ),
+ ]
+ )
+
+ if config.MODEL.TYPE in ["swin", "swinv2"]:
+ model_patch_size = config.MODEL.SWIN.PATCH_SIZE
else:
raise NotImplementedError
-
+
self.mask_generator = MaskGenerator(
input_size=config.DATA.IMG_SIZE,
mask_patch_size=config.DATA.MASK_PATCH_SIZE,
model_patch_size=model_patch_size,
mask_ratio=config.DATA.MASK_RATIO,
)
-
+
def __call__(self, img):
img = self.transform_img(img)
mask = self.mask_generator()
-
+
return img, mask
@@ -84,7 +95,9 @@ def collate_fn(batch):
if batch[0][0][item_idx] is None:
ret.append(None)
else:
- ret.append(default_collate([batch[i][0][item_idx] for i in range(batch_num)]))
+ ret.append(
+ default_collate([batch[i][0][item_idx] for i in range(batch_num)])
+ )
ret.append(default_collate([batch[i][1] for i in range(batch_num)]))
return ret
@@ -92,8 +105,18 @@ def collate_fn(batch):
def build_loader_simmim(config):
transform = SimMIMTransform(config)
dataset = ImageFolder(config.DATA.DATA_PATH, transform)
-
- sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True)
- dataloader = DataLoader(dataset, config.DATA.BATCH_SIZE, sampler=sampler, num_workers=config.DATA.NUM_WORKERS, pin_memory=True, drop_last=True, collate_fn=collate_fn)
-
- return dataloader
\ No newline at end of file
+
+ sampler = DistributedSampler(
+ dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True
+ )
+ dataloader = DataLoader(
+ dataset,
+ config.DATA.BATCH_SIZE,
+ sampler=sampler,
+ num_workers=config.DATA.NUM_WORKERS,
+ pin_memory=True,
+ drop_last=True,
+ collate_fn=collate_fn,
+ )
+
+ return dataloader
diff --git a/data/hierarchical.py b/data/hierarchical.py
new file mode 100644
index 00000000..5bcd5057
--- /dev/null
+++ b/data/hierarchical.py
@@ -0,0 +1,94 @@
+import os
+
+import torch
+from timm.data import Mixup, mixup
+from torchvision.datasets import ImageFolder
+
+
+class HierarchicalImageFolder(ImageFolder):
+ """
+ Parses an image folder where the hierarchy is represented as follows:
+
+ 00000_top_middle_..._bottom
+ 00001_top_middle_..._other
+ ...
+ """
+
+ num_classes = None
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def find_classes(self, directory):
+ classes = sorted(
+ entry.name for entry in os.scandir(directory) if entry.is_dir()
+ )
+
+ tier_lookup = {}
+ class_to_idxs = {}
+
+ for cls in classes:
+ tiers = make_hierarchical(cls)
+
+ for tier, value in enumerate(tiers):
+ if tier not in tier_lookup:
+ tier_lookup[tier] = {}
+
+ if value not in tier_lookup[tier]:
+ tier_lookup[tier][value] = len(tier_lookup[tier])
+
+ class_to_idxs[cls] = torch.tensor(
+ [tier_lookup[tier][value] for tier, value in enumerate(tiers)]
+ )
+
+ # Set self.num_classes
+ self.num_classes = tuple(len(tier) for tier in tier_lookup.values())
+
+ return classes, class_to_idxs
+
+
+def make_hierarchical(name):
+ """
+ Sometimes the tree is not really a tree; that is, sometimes there are
+ repeated orders, for example.
+
+ Arguments:
+ name (str): the complete taxonomic name, separated by '_'
+ """
+ # index is a number
+ # top is kingdom
+ index, top, *tiers = name.split("_")
+
+ cleaned = [top]
+
+ complete = top
+ for tier in tiers:
+ complete += f"-{tier}"
+ cleaned.append(complete)
+
+ return cleaned
+
+
+class HierarchicalMixup(Mixup):
+ def __call__(self, inputs, targets):
+ assert len(inputs) % 2 == 0, "Batch size should be even when using this"
+ if self.mode == "elem":
+ lam = self._mix_elem(inputs)
+ elif self.mode == "pair":
+ lam = self._mix_pair(inputs)
+ else:
+ lam = self._mix_batch(inputs)
+
+ batch_size, *_ = inputs.shape
+ assert targets.shape == (
+ batch_size,
+ len(self.num_classes),
+ ), f"{targets.shape} != {batch_size, len(self.num_classes)}"
+
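+        # targets holds one integer label per tier, shape (batch, num_tiers);
+        # mix each tier separately so the multitask loss receives one
+        # soft-target tensor per tier.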
+ targets = [
+ mixup.mixup_target(
+ target, num_classes, lam, self.label_smoothing, inputs.device
+ )
+ for target, num_classes in zip(targets.T, self.num_classes)
+ ]
+ return inputs, targets
diff --git a/data/imagenet22k_dataset.py b/data/imagenet22k_dataset.py
index 5758060b..dc3210fe 100644
--- a/data/imagenet22k_dataset.py
+++ b/data/imagenet22k_dataset.py
@@ -1,16 +1,16 @@
-import os
import json
-import torch.utils.data as data
+import os
+import warnings
+
import numpy as np
+import torch.utils.data as data
from PIL import Image
-import warnings
-
warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
class IN22KDATASET(data.Dataset):
- def __init__(self, root, ann_file='', transform=None, target_transform=None):
+ def __init__(self, root, ann_file="", transform=None, target_transform=None):
super(IN22KDATASET, self).__init__()
self.data_path = root
@@ -40,7 +40,7 @@ def __getitem__(self, index):
idb = self.database[index]
# images
- images = self._load_image(self.data_path + '/' + idb[0]).convert('RGB')
+ images = self._load_image(self.data_path + "/" + idb[0]).convert("RGB")
if self.transform is not None:
images = self.transform(images)
diff --git a/data/inat/__init__.py b/data/inat/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/data/inat/__main__.py b/data/inat/__main__.py
new file mode 100644
index 00000000..c1c0883c
--- /dev/null
+++ b/data/inat/__main__.py
@@ -0,0 +1,70 @@
+import argparse
+
+from .. import hierarchical
+from . import datasets
+
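+# Typical invocations (paths are illustrative):
+#   python -m data.inat preprocess /path/to/inat21 train resize 192
+#   python -m data.inat normalize /path/to/inat21/resize-192/train
+#   python -m data.inat num-classes /path/to/inat21/resize-192/train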
+
+def preprocess_cli(args):
+ datasets.preprocess_dataset(args.root, args.stage, args.strategy, args.size)
+
+
+def normalize_cli(args):
+ mean, std = datasets.load_statistics(args.directory)
+ print("Add this to a constants.py file:")
+ print(
+ f"""
+"{args.directory}": (
+ torch.tensor({mean.tolist()}),
+ torch.tensor({std.tolist()}),
+),"""
+ )
+
+
+def num_classes_cli(args):
+ dataset = hierarchical.HierarchicalImageFolder(
+ args.directory,
+ )
+
+ print("Add this to your config .yml file to the model section:")
+ print(f"num_classes: {list(dataset.num_classes)}")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ subparsers = parser.add_subparsers(help="Available commands.")
+
+ # Preprocess
+ preprocess_parser = subparsers.add_parser("preprocess", help="Preprocess data.")
+ preprocess_parser.add_argument(
+ "root",
+ help="Root data folder. Should contain folders compressed/ and raw/train, raw/val, etc.",
+ )
+ preprocess_parser.add_argument(
+ "stage", choices=datasets.Inat21.stages, help="Data stage to preprocess"
+ )
+ preprocess_parser.add_argument(
+ "strategy", choices=datasets.Inat21.strategies, help="Preprocessing strategy"
+ )
+ preprocess_parser.add_argument("size", type=int, help="Image size in pixels")
+ preprocess_parser.set_defaults(func=preprocess_cli)
+
+ # Normalize
+ normalize_parser = subparsers.add_parser(
+ "normalize", help="Measure mean and std of dataset."
+ )
+ normalize_parser.add_argument("directory", help="Data folder")
+ normalize_parser.set_defaults(func=normalize_cli)
+
+ # Number of classes
+ num_classes_parser = subparsers.add_parser(
+ "num-classes", help="Measure number of classes in dataset."
+ )
+ num_classes_parser.add_argument("directory", help="Data folder")
+ num_classes_parser.set_defaults(func=num_classes_cli)
+
+ args = parser.parse_args()
+
+ if hasattr(args, "func"):
+ args.func(args)
+ else:
+ parser.print_usage()
diff --git a/data/inat/datasets.py b/data/inat/datasets.py
new file mode 100644
index 00000000..b0676896
--- /dev/null
+++ b/data/inat/datasets.py
@@ -0,0 +1,194 @@
+import concurrent.futures
+import os
+import pathlib
+import warnings
+
+import cv2
+import einops
+import timm.data
+import torch
+import torchvision
+from tqdm.auto import tqdm
+
+
+def load_statistics(directory):
+ """
+ Need to calculate mean and std for the individual channels so we can normalize the images.
+ """
+ dataset = timm.data.ImageDataset(
+ root=directory, transform=torchvision.transforms.ToTensor()
+ )
+ channels, height, width = dataset[0][0].shape
+
+ total = torch.zeros((channels,))
+ total_squared = torch.zeros((channels,))
+
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=256, num_workers=32)
+
+ for batch, _ in tqdm(dataloader):
+ total += einops.reduce(batch, "batch channel height width -> channel", "sum")
+ total_squared += einops.reduce(
+ torch.mul(batch, batch), "batch channel height width -> channel", "sum"
+ )
+
+ divisor = len(dataset) * width * height
+
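+    # Per-channel E[x] and E[x^2] over all pixels; then var = E[x^2] - E[x]^2.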
+ mean = total / divisor
+ var = total_squared / divisor - torch.mul(mean, mean)
+ std = torch.sqrt(var)
+
+ return mean, std
+
+
+class Inat21:
+ val_tar_gz_hash = "f6f6e0e242e3d4c9569ba56400938afc"
+ train_tar_gz_hash = "e0526d53c7f7b2e3167b2b43bb2690ed"
+ train_mini_tar_gz_hash = "db6ed8330e634445efc8fec83ae81442"
+ strategies = ("resize", "pad")
+ stages = ("train", "val", "train_mini")
+
+ def __init__(self, root: str, stage: str, strategy: str, size: int):
+ self.root = root
+ self.stage = stage
+ self._check_stage(stage)
+
+ self.strategy = strategy
+ self._check_strategy(strategy)
+
+ self.size = size
+ self._check_size(size)
+
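+    # Expected layout under `root` (a sketch, matching the properties below):
+    #   compressed/{stage}.tar.gz    downloaded archives
+    #   raw/{stage}/                 extracted images
+    #   {strategy}-{size}/{stage}/   preprocessed output, e.g. resize-192/train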
+ @property
+ def directory(self) -> str:
+ return os.path.join(self.root, f"{self.strategy}-{self.size}", self.stage)
+
+ @property
+ def ffcv(self) -> str:
+ return os.path.join(
+ self.root, f"{self.stage}-{self.strategy}-{self.size}.beton"
+ )
+
+ @property
+ def tar_file(self):
+ return os.path.join(self.root, "compressed", f"{self.stage}.tar.gz")
+
+ @property
+ def raw_dir(self):
+ return os.path.join(self.root, "raw", self.stage)
+
+ def _check_stage(self, stage):
+ if stage not in self.stages:
+ raise ValueError(f"Stage '{stage}' must be one of {self.stages}")
+
+ def _check_strategy(self, strategy):
+ if strategy not in self.strategies:
+ raise ValueError(f"Strategy '{strategy}' must be one of {self.strategies}")
+
+ def _check_size(self, size):
+ if not isinstance(size, int):
+ raise ValueError(f"Size {size} must be int; not {type(int)}")
+
+ def check(self):
+ # If /.finished doesn't exist, we need to preprocess.
+ if not os.path.isfile(finished_file_path(self.directory)):
+ warnings.warn(
+ f"Data not processed in {self.directory}! "
+ "You should run:\n\n"
+ f"\tpython -m src.data preprocess {self.root} {self.stage} {self.strategy} {self.size}"
+ "\n\nAnd then run this script again!"
+ )
+ raise RuntimeError(f"Data {self.directory} not pre-processed!")
+
+
+def finished_file_path(directory):
+ return os.path.join(directory, ".finished")
+
+
+def preprocess_class(cls_dir, output_dir, strategy, size):
+ cls = os.path.basename(cls_dir)
+ output_dir = os.path.join(output_dir, cls)
+ os.makedirs(output_dir, exist_ok=True)
+
+ with os.scandir(cls_dir) as entries:
+ for entry in entries:
+ if not entry.is_file():
+ continue
+
+ im = cv2.imread(entry.path)
+ im = preprocess_image(im, strategy, size)
+ output_path = os.path.join(output_dir, entry.name)
+ if not cv2.imwrite(output_path, im):
+ raise RuntimeError(output_path)
+
+
+def preprocess_image(im, strategy, size):
+ if strategy == "resize":
+ return cv2.resize(im, (size, size), interpolation=cv2.INTER_LINEAR)
+ elif strategy == "pad":
+ # https://stackoverflow.com/questions/43391205/add-padding-to-images-to-get-them-into-the-same-shape
+ raise NotImplementedError()
+ else:
+ raise NotImplementedError()
+
+
+def parent_of(path: str):
+ return pathlib.Path(path).parents[0]
+
+
+def preprocess_dataset(root: str, stage: str, strategy: str, size: int) -> None:
+ inat = Inat21(root, stage, strategy, size)
+
+ err_msg = (
+ f"Can't prepare data for stage {stage}, strategy {strategy} and size {size}."
+ )
+
+ # 1. If the directory does not exist, ask the user to fix that for us.
+ if not os.path.isdir(inat.raw_dir):
+ # Check that the tar exists
+ if not os.path.isfile(inat.tar_file):
+ warn_msg = f"Please download the appropriate .tar.gz file to {root} for stage {stage}."
+ if "raw" in root:
+ warn_msg += f"\n\nYour root path should contain a 'raw' directory; did you mean to use {parent_of(root)}?\n"
+ elif "compressed" in root:
+ warn_msg += f"\n\nYour root path should contain a 'compressed' directory; did you mean to use {parent_of(root)}?\n"
+
+ warnings.warn(warn_msg)
+
+ raise RuntimeError(err_msg)
+ else:
+ warnings.warn(
+ f"Please untar {inat.tar_file} in {root}. Probably need to run 'cd {root}; tar -xvf {inat.tar_file}"
+ )
+ raise RuntimeError(err_msg)
+
+ # 2. Now that we know the raw directory exists, we need to convert it
+ # to a processed directory
+    # inat.directory already includes root (e.g. <root>/resize-192/train).
+    out_path = inat.directory
+
+ # 3. Make sure the directory exists
+ if not os.path.isdir(out_path):
+ os.makedirs(out_path)
+
+    # 4. For all raw files, process and save to directory. We do this with a
+    # thread pool because the work is both I/O-bound (read/write) and CPU-bound
+    # (image processing), and OpenCV releases the GIL during image operations.
+ with os.scandir(inat.raw_dir) as entries:
+ directories = [entry.path for entry in entries if entry.is_dir()]
+ print(f"Found {len(directories)} directories to preprocess.")
+
+ with concurrent.futures.ThreadPoolExecutor() as executor:
+ futures = [
+ executor.submit(preprocess_class, directory, out_path, strategy, size)
+ for directory in tqdm(directories)
+ ]
+
+ print(f"Submitted {len(futures)} jobs to executor.")
+
+ for future in tqdm(
+ concurrent.futures.as_completed(futures), total=len(futures)
+ ):
+ future.result()
+
+ # 5. Save a sentinel file called .finished
+ open(finished_file_path(out_path), "w").close()
diff --git a/data/zipreader.py b/data/zipreader.py
index 060bc46a..1babf75e 100644
--- a/data/zipreader.py
+++ b/data/zipreader.py
@@ -5,23 +5,24 @@
# Written by Ze Liu
# --------------------------------------------------------
+import io
import os
import zipfile
-import io
+
import numpy as np
-from PIL import Image
-from PIL import ImageFile
+from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
def is_zip_path(img_or_path):
"""judge if this is a zip path"""
- return '.zip@' in img_or_path
+ return ".zip@" in img_or_path
class ZipReader(object):
"""A class to read zipped files"""
+
zip_bank = dict()
def __init__(self):
@@ -31,18 +32,20 @@ def __init__(self):
def get_zipfile(path):
zip_bank = ZipReader.zip_bank
if path not in zip_bank:
- zfile = zipfile.ZipFile(path, 'r')
+ zfile = zipfile.ZipFile(path, "r")
zip_bank[path] = zfile
return zip_bank[path]
@staticmethod
def split_zip_style_path(path):
- pos_at = path.index('@')
- assert pos_at != -1, "character '@' is not found from the given path '%s'" % path
-
- zip_path = path[0: pos_at]
- folder_path = path[pos_at + 1:]
- folder_path = str.strip(folder_path, '/')
+ pos_at = path.index("@")
+ assert pos_at != -1, (
+ "character '@' is not found from the given path '%s'" % path
+ )
+
+ zip_path = path[0:pos_at]
+ folder_path = path[pos_at + 1 :]
+ folder_path = str.strip(folder_path, "/")
return zip_path, folder_path
@staticmethod
@@ -52,33 +55,37 @@ def list_folder(path):
zfile = ZipReader.get_zipfile(zip_path)
folder_list = []
for file_foler_name in zfile.namelist():
- file_foler_name = str.strip(file_foler_name, '/')
- if file_foler_name.startswith(folder_path) and \
- len(os.path.splitext(file_foler_name)[-1]) == 0 and \
- file_foler_name != folder_path:
+ file_foler_name = str.strip(file_foler_name, "/")
+ if (
+ file_foler_name.startswith(folder_path)
+ and len(os.path.splitext(file_foler_name)[-1]) == 0
+ and file_foler_name != folder_path
+ ):
if len(folder_path) == 0:
folder_list.append(file_foler_name)
else:
- folder_list.append(file_foler_name[len(folder_path) + 1:])
+ folder_list.append(file_foler_name[len(folder_path) + 1 :])
return folder_list
@staticmethod
def list_files(path, extension=None):
if extension is None:
- extension = ['.*']
+ extension = [".*"]
zip_path, folder_path = ZipReader.split_zip_style_path(path)
zfile = ZipReader.get_zipfile(zip_path)
file_lists = []
for file_foler_name in zfile.namelist():
- file_foler_name = str.strip(file_foler_name, '/')
- if file_foler_name.startswith(folder_path) and \
- str.lower(os.path.splitext(file_foler_name)[-1]) in extension:
+ file_foler_name = str.strip(file_foler_name, "/")
+ if (
+ file_foler_name.startswith(folder_path)
+ and str.lower(os.path.splitext(file_foler_name)[-1]) in extension
+ ):
if len(folder_path) == 0:
file_lists.append(file_foler_name)
else:
- file_lists.append(file_foler_name[len(folder_path) + 1:])
+ file_lists.append(file_foler_name[len(folder_path) + 1 :])
return file_lists
diff --git a/docs/cli-design.md b/docs/cli-design.md
new file mode 100644
index 00000000..4520f631
--- /dev/null
+++ b/docs/cli-design.md
@@ -0,0 +1,15 @@
+# CLI Design
+
+Options that are different per run:
+
+1. Config file
+2. Device batch size (depends on image size and local GPUs)
+3. Whether to debug or not (1 process or default number of processes)
+4. Master port (with a default of 12345)
+5. Data path
+6. Number of processes (needs to match number of GPUs)
+
+Options that are different per machine:
+
+1. Output directory
+2. Virtual environment location
diff --git a/docs/experiments/bold-banana.md b/docs/experiments/bold-banana.md
new file mode 100644
index 00000000..a0c8c8ce
--- /dev/null
+++ b/docs/experiments/bold-banana.md
@@ -0,0 +1,20 @@
+# Bold Banana
+
+This experiment is the default swin-v2-base applied to iNat21 at 192x192.
+We train for 90 epochs at 192x192, then tune for 30 epochs at 256x256.
+
+```yaml
+configs:
+- configs/hierarchical-vision-project/bold-banana-192.yaml
+codename: bold-banana
+```
+
+## Log
+
+I initialized training on strawberry0 on 4x A6000 servers.
+
+I decided to use the A6000 servers for 256x256 tuning, so I am moving the latest checkpoint to S3, then cloning it back to an 8x V100 server to finish training.
+I am storing the 40th checkpoint on S3 as funky-banana-192-epoch40.pth.
+It is now running as funky-banana-192 on 8x V100.
+
+It had a bad mean/std (the normalization constants were swapped), so I am re-christening it as bold-banana on 4x A6000.
diff --git a/docs/experiments/fuzzy-fig.md b/docs/experiments/fuzzy-fig.md
new file mode 100644
index 00000000..476d2810
--- /dev/null
+++ b/docs/experiments/fuzzy-fig.md
@@ -0,0 +1,14 @@
+# Fuzzy Fig
+
+This experiment is the default swin-v2-base applied to iNat21 at 192x192 with 1/4 of the default learning rate (1.25e-4).
+We train for 90 epochs at 192x192, then tune for 30 epochs at 256x256.
+
+```yaml
+configs:
+- configs/hierarchical-vision-project/fuzzy-fig-192.yaml
+codename: fuzzy-fig
+```
+
+## Log
+
+Training is running on 8x V100 as fuzzy-fig-192.
diff --git a/docs/experiments/groovy-grape.md b/docs/experiments/groovy-grape.md
new file mode 100644
index 00000000..c794d8a7
--- /dev/null
+++ b/docs/experiments/groovy-grape.md
@@ -0,0 +1,19 @@
+# Groovy Grape
+
+This experiment trains swin-v2-base from scratch on iNat21 with the multitask objective, using 1/4 of the default learning rate, which works out to 1.25e-4.
+It also does 90 epochs at 192.
+
+```yaml
+configs:
+- configs/hierarchical-vision-project/groovy-grape-192.yaml
+codename: groovy-grape
+```
+
+## Log
+
+This model trained for the first 90 epochs on 8x V100, and did 8/30 epochs at 256 on 8x V100.
+I am storing the 8th checkpoint on S3 as groovy-grape-256-epoch8.pth.
+Now it is stored at /local/scratch/stevens.994/hierarchical-vision/groovy-grape-256/v0
+It was originally haunted-broomstick on wandb, but is now groovy-grape-256.
+
+I messed it up by switching the std/mean normalization constants, so I am restarting this run on 8x V100.
diff --git a/docs/experiments/index.md b/docs/experiments/index.md
new file mode 100644
index 00000000..e79da1f6
--- /dev/null
+++ b/docs/experiments/index.md
@@ -0,0 +1,13 @@
+# Experiment Index
+
+This file is a lightweight index of all the experiments I'm running, so you can jump to the relevant write-up if necessary.
+
+[bold-banana](bold-banana.md): Default swin-v2-base on iNat21
+
+[fuzzy-fig](fuzzy-fig.md): Default swin-v2-base on iNat21 with 1/4 LR
+
+[outrageous-orange](outrageous-orange.md): Hierarchical swin-v2-base on iNat21
+
+[sweet-strawberry](sweet-strawberry.md): Hierarchical swin-v2-base on iNat21 with 1/2 LR
+
+[groovy-grape](groovy-grape.md): Hierarchical swin-v2-base on iNat21 with 1/4 LR
diff --git a/docs/experiments/outrageous-orange.md b/docs/experiments/outrageous-orange.md
new file mode 100644
index 00000000..89fcc20c
--- /dev/null
+++ b/docs/experiments/outrageous-orange.md
@@ -0,0 +1,24 @@
+# Outrageous Orange
+
+This experiment is swin-v2-base with the hierarchical multitask objective applied to iNat21 at 192x192.
+We train for 90 epochs at 192x192.
+
+```yaml
+configs:
+- configs/hierarchical-vision-project/outrageous-orange-192.yaml
+codename: outrageous-orange
+```
+
+## Log
+
+I initialized training on strawberry0 on 4x A6000 servers.
+
+I decided to use the A6000 servers for 256x256 tuning, so I am moving the latest checkpoint to S3, then cloning it back to an 8x V100 server to finish training.
+I am storing the 36th checkpoint on S3 as `outrageous-orange-192-epoch36.pth`.
+It is now running as `outrageous-orange-192` on 8x V100.
+
+I am going to stop outrageous-orange-192 (and never run outrageous-orange-256) because it is underperforming the other hierarchical runs and I don't want to waste compute on an obviously bad run. I'll still upload the checkpoints to S3 to save them.
+The latest checkpoint is at `s3://imageomics-models/outrageous-orange-192-epoch64.pth`
+
+This is a bad run (std/mean switching issue).
+I am going to restart it on 4x A6000 because I expect it to underperform the other runs.
diff --git a/docs/experiments/sweet-strawberry.md b/docs/experiments/sweet-strawberry.md
new file mode 100644
index 00000000..7fda07e4
--- /dev/null
+++ b/docs/experiments/sweet-strawberry.md
@@ -0,0 +1,20 @@
+# Sweet Strawberry
+
+This experiment trains swin-v2-base from scratch on iNat21 with the multitask objective, using 1/2 of the default learning rate, which works out to 2.5e-4.
+It also does 90 epochs at 192.
+
+```yaml
+configs:
+- configs/hierarchical-vision-project/sweet-strawberry-192.yaml
+codename: sweet-strawberry
+```
+
+## Log
+
+This model trained for the first 90 epochs on 8x V100, and did 15/30 epochs at 256 on 8x V100.
+I am storing the 15th checkpoint on S3 as sweet-strawberry-256-epoch15.pth.
+Now it is stored at /local/scratch/stevens.994/hierarchical-vision/sweet-strawberry-256/v0
+It was originally unearthly-moon on wandb, but is now sweet-strawberry-256.
+It is running on 4x A6000.
+
+I want to check that I didn't mix up sweet-strawberry-192 and groovy-grape-192, so I am going to check their validation top 1 accuracy on iNat21 192x192 using their final checkpoints.
diff --git a/generate_configs/generate.py b/generate_configs/generate.py
new file mode 100644
index 00000000..95c047b7
--- /dev/null
+++ b/generate_configs/generate.py
@@ -0,0 +1,128 @@
+import argparse
+import os
+import pathlib
+
+import yaml
+from tqdm.auto import tqdm
+
+import config
+import logger
+import utils
+from . import templating
+
+logger = logger.init("experiments.generate")
+
+
+def parse_args() -> argparse.Namespace:
+ parser = argparse.ArgumentParser(
+ description="General .toml files from template .toml files. I kept all my templates in experiments/templates and my generated experiment configs in experiments/generated, which I then removed from version control.",
+ )
+ parser.add_argument(
+ "--strategy",
+ type=str,
+ help="Strategy to use to combine multiple lists in a template.",
+ default="grid",
+ choices=["grid", "paired", "random"],
+ )
+ parser.add_argument(
+ "--count",
+ type=int,
+ help="Number of configs to generate when using --strategy random. Required.",
+ default=-1,
+ )
+ parser.add_argument(
+ "--no-expand",
+ type=str,
+ nargs="+",
+ default=[],
+ help=".-separated fields to not expand",
+ )
+ parser.add_argument(
+ "--prefix",
+ type=str,
+ default="generated-",
+ help="Prefix to add to generated templates",
+ )
+ parser.add_argument(
+ "templates",
+ nargs="+",
+ type=str,
+ help="Template .toml files or directories containing template .toml files.",
+ )
+ parser.add_argument(
+ "output",
+ type=str,
+ help="Output directory to write the generated .toml files to.",
+ )
+ return parser.parse_args()
+
+
+def generate(args: argparse.Namespace) -> None:
+    strategy = templating.Strategy.new(args.strategy)
+    count = args.count
+
+    if strategy is templating.Strategy.random:
+        assert count > 0, "Need to include --count!"
+
+    for template_path in utils.files_with_extension(args.templates, ".yaml"):
+        with open(template_path, "rb") as template_file:
+            try:
+                template_dict = yaml.safe_load(template_file)
+            except yaml.YAMLError as err:
+                logger.warning(
+                    "Error parsing template file. [file: %s, err: %s]",
+                    template_path,
+                    err,
+                )
+                continue
+
+        template_name = pathlib.Path(template_path).stem
+        logger.info("Opened template file. [file: %s]", template_path)
+
+        experiment_dicts = templating.generate(
+            template_dict, strategy, count=count, no_expand=set(args.no_expand)
+        )
+        logger.info(
+            "Loaded experiment dictionaries. [count: %s]", len(experiment_dicts)
+        )
+
+        for i, experiment_dict in enumerate(tqdm(experiment_dicts)):
+            # Give each generated experiment a unique name: <original name>/<index>.
+            base_name = experiment_dict["EXPERIMENT"]["NAME"]
+            experiment_dict["EXPERIMENT"]["NAME"] = os.path.join(base_name, str(i))
+            filename = f"{args.prefix}{template_name}-{i}.yaml"
+            filepath = os.path.join(args.output, filename)
+            with open(filepath, "w") as file:
+                yaml.dump(experiment_dict, file)
+
+        # Verifies that the configs are correctly loaded:
+        # list(config.load_configs(filepath))
+
+
+
+def main() -> None:
+ args = parse_args()
+ generate(args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/generate_configs/templating.py b/generate_configs/templating.py
new file mode 100644
index 00000000..7d4836e8
--- /dev/null
+++ b/generate_configs/templating.py
@@ -0,0 +1,369 @@
+import copy
+import dataclasses
+import enum
+import random
+import re
+import typing
+from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
+from typing_extensions import Protocol
+
+import orjson
+import preface
+import scipy.stats
+
+import config
+import logger
+
+Primitive = Union[str, int, float, bool]
+StringDict = Dict[str, object]
+
+logger = logger.init(__name__)
+
+
+TEMPLATE_PATTERN = re.compile(r"\((\d+)\.\.\.(\d+)\)")
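+# Matches inline integer ranges such as "(1...5)" inside template string values.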
+
+CONTINUOUS_DISTRIBUTION_SUFFIX = "__random-sample-distribution"
+
+
+class Strategy(preface.SumType):
+ grid = enum.auto()
+ paired = enum.auto()
+ random = enum.auto()
+
+
+@dataclasses.dataclass(frozen=True)
+class DiscreteDistribution:
+ values: Sequence[Primitive]
+
+
+class HasRandomVariates(Protocol):
+ def rvs(self, **kwargs: Dict[str, Any]) -> float:
+ ...
+
+
+def parse_dist(
+    raw_dist: object, raw_params: object
+) -> Tuple[HasRandomVariates, Dict[str, Any]]:
+    # Maps a distribution name and its raw parameters to a scipy.stats
+    # distribution and the keyword arguments for .rvs(); used by
+    # ContinuousDistribution.parse below. Assumes config.Distribution lists
+    # the supported distribution names.
+    try:
+        # We type ignore because we catch any type errors
+        parameter_list = list(raw_params)  # type: ignore
+    except TypeError:
+        raise ValueError(f"{raw_params} should be a sequence!")
+
+    dist_choices = typing.get_args(config.Distribution)
+    if raw_dist not in dist_choices:
+        raise ValueError(f"{raw_dist} must be one of {', '.join(dist_choices)}")
+    raw_dist = typing.cast(config.Distribution, raw_dist)
+
+    params = {}
+
+    if raw_dist == "uniform":
+        dist = scipy.stats.uniform
+        assert len(parameter_list) == 2
+
+        low, high = parameter_list
+        assert low < high
+        params = {"loc": low, "scale": high - low}
+    elif raw_dist == "normal":
+        dist = scipy.stats.norm
+        assert len(parameter_list) == 2
+
+        mean, std = parameter_list
+        params = {"loc": mean, "scale": std}
+    elif raw_dist == "loguniform":
+        dist = scipy.stats.loguniform
+        assert len(parameter_list) == 2
+
+        low, high = parameter_list
+        assert low < high
+        params = {"a": low, "b": high}
+    else:
+        preface.never(raw_dist)
+
+    return dist, params
+
+
+@dataclasses.dataclass(frozen=True)
+class ContinuousDistribution:
+ fn: HasRandomVariates
+ params: Dict[str, Any]
+
+ @classmethod
+ def parse(cls, value: object) -> Optional["ContinuousDistribution"]:
+ assert isinstance(value, dict)
+
+ assert "dist" in value
+ assert "params" in value
+
+ dist, params = parse_dist(value["dist"], value["params"])
+
+ return cls(dist, params)
+
+
+@dataclasses.dataclass(frozen=True)
+class Hole:
+ path: str
+ distribution: Union[DiscreteDistribution, ContinuousDistribution]
+
+ def __post_init__(self) -> None:
+ assert isinstance(self.distribution, DiscreteDistribution) or isinstance(
+ self.distribution, ContinuousDistribution
+ ), f"self.distribution ({self.distribution}) is {type(self.distribution)}!"
+
+
+def makehole(key: str, value: object) -> Optional[Hole]:
+ if isinstance(value, list):
+ values: List[Union[str, int]] = []
+
+ for item in value:
+ if isinstance(item, str):
+ values += expand_numbers(item)
+ else:
+ values.append(item)
+
+ return Hole(key, DiscreteDistribution(values))
+
+ if isinstance(value, str):
+ numbers = expand_numbers(value)
+
+ if len(numbers) > 1:
+ return Hole(key, DiscreteDistribution(numbers))
+
+ if key.endswith(CONTINUOUS_DISTRIBUTION_SUFFIX):
+ key = key.removesuffix(CONTINUOUS_DISTRIBUTION_SUFFIX)
+
+ dist = ContinuousDistribution.parse(value)
+ assert dist
+
+ return Hole(key, dist)
+
+ return None
+
+
+def find_holes(template: StringDict) -> List[Hole]:
+ """
+ Arguments:
+ template (StringDict): Template with potential holes
+ no_expand (Set[str]): Fields to not treat as holes, even if we would otherwise.
+ """
+ holes = []
+
+ # Make it a list so we can modify template during iteration.
+ for key, value in list(template.items()):
+ # Have to check makehole first because value might be a dict, but
+ # if key ends with CONTINUOUS_DISTRIBUTION_SUFFIX, then we want to
+ # parse that dict as a continuous distribution
+ hole = makehole(key, value)
+ if hole:
+ holes.append(hole)
+ template.pop(key)
+ elif isinstance(value, dict):
+ holes.extend(
+ Hole(f"{key}.{hole.path}", hole.distribution)
+ for hole in find_holes(value)
+ )
+
+ return holes
+
+
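+# Hole-detection sketch (field names illustrative): given
+#   {"MODEL": {"DEPTHS": [2, 4, 6], "NAME": "swin"}, "SEED": "(0...3)"}
+# find_holes returns holes at "MODEL.DEPTHS" (values [2, 4, 6]) and "SEED"
+# (values [0, 1, 2]), and removes both fields from the template.
+
+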
+def sort_by_json(dicts: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ return list(
+ sorted(dicts, key=lambda d: orjson.dumps(d, option=orjson.OPT_SORT_KEYS))
+ )
+
+
+# region FILLING
+
+
+def grid_fill(filled: StringDict, holes: List[Hole]) -> List[StringDict]:
+ if not holes:
+ return [filled]
+
+ experiments = []
+ first, rest = holes[0], holes[1:]
+
+ if not isinstance(first.distribution, DiscreteDistribution):
+ raise RuntimeError(
+ f"Must sample from DiscreteDistribution with strategy grid, not {type(first.distribution)}!"
+ )
+
+ for value in first.distribution.values:
+        experiment = copy.deepcopy(filled)
+        preface.dict.set(experiment, first.path, value)
+        experiments.extend(grid_fill(experiment, rest))
+
+ return sort_by_json(experiments)
+
+
+def paired_fill(holes: List[Hole]) -> List[StringDict]:
+ experiments = []
+
+ assert all(isinstance(hole.distribution, DiscreteDistribution) for hole in holes)
+
+ # We can type ignore because we assert that all distributions are discrete
+ shortest_hole = min(holes, key=lambda h: len(h.distribution.values)) # type: ignore
+
+ assert isinstance(shortest_hole.distribution, DiscreteDistribution)
+
+ for i in range(len(shortest_hole.distribution.values)):
+ experiment: StringDict = {}
+ for hole in holes:
+ assert isinstance(hole.distribution, DiscreteDistribution)
+ preface.dict.set(experiment, hole.path, hole.distribution.values[i])
+
+ experiments.append(experiment)
+
+ return sort_by_json(experiments)
+
+
+def random_fill(holes: List[Hole], count: int) -> List[StringDict]:
+ experiments = []
+
+ for _ in range(count):
+ experiment: StringDict = {}
+ for hole in holes:
+ if isinstance(hole.distribution, DiscreteDistribution):
+ preface.dict.set(
+ experiment, hole.path, random.choice(hole.distribution.values)
+ )
+ elif isinstance(hole.distribution, ContinuousDistribution):
+ preface.dict.set(
+ experiment,
+ hole.path,
+ float(hole.distribution.fn.rvs(**hole.distribution.params)),
+ )
+ else:
+ preface.never(hole.distribution)
+
+ experiments.append(experiment)
+
+ return experiments
+
+
+# endregion
+
+
+def generate(
+ template: StringDict,
+ strategy: Strategy,
+ count: int = 0,
+ *,
+ no_expand: Optional[Set[str]] = None,
+) -> List[StringDict]:
+ """
+ Turns a template (a dictionary with lists as values) into a list of experiments (dictionaries with no lists).
+
+ If strategy is Strategy.Grid, returns an experiment for each possible combination of each value in each list. Strategy.Paired returns an experiment for sequential pair of values in each list.
+
+ An example makes this clearer. If the template had 3 lists with lengths 5, 4, 10, respectively:
+
+ Grid would return 5 x 4 x 10 = 200 experiments.
+
+ Paired would return min(5, 4, 10) = 4 experiments
+
+ Random would return experiments
+
+ Experiments are returned sorted by the JSON value.
+ """
+    ignored = {}
+    if no_expand is not None:
+        # Pull the no-expand fields out of the template so they are copied into
+        # every experiment unchanged.
+        ignored = {field: preface.dict.pop(template, field) for field in no_expand}
+
+    template = copy.deepcopy(template)
+    holes = find_holes(template)
+
+ if not holes:
+ # We can return this directly because there are no holes.
+ return [template]
+
+ logger.info("Found all holes. [count: %d]", len(holes))
+
+ experiments: List[StringDict] = []
+
+ if strategy is Strategy.grid:
+ filled = grid_fill({}, holes)
+ elif strategy is Strategy.paired:
+ filled = paired_fill(holes)
+ elif strategy is Strategy.random:
+ filled = random_fill(holes, count)
+ else:
+ preface.never(strategy)
+
+ logger.info("Filled all holes. [count: %d]", len(filled))
+
+ without_holes: StringDict = {}
+ for key, value in preface.dict.flattened(template).items():
+ if makehole(key, value):
+ continue
+
+ without_holes[key] = value
+
+ for key, value in ignored.items():
+ without_holes[key] = value
+
+ experiments = [preface.dict.merge(exp, without_holes) for exp in filled]
+
+ logger.info("Merged all experiment configs. [count: %d]", len(experiments))
+
+ return sort_by_json(experiments)
+
+
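+# Usage sketch (template keys illustrative):
+#
+#   template = {"MODEL": {"DEPTH": [2, 4]}, "TRAIN": {"LR": [0.1, 0.01]}}
+#   generate(template, Strategy.grid)             # 2 x 2 = 4 experiments
+#   generate(template, Strategy.paired)           # min(2, 2) = 2 experiments
+#   generate(template, Strategy.random, count=8)  # 8 sampled experiments
+
+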
+def expand_numbers(obj: str) -> Union[List[str], List[int]]:
+ """
+ Given a string that potentially has the digits expander:
+
+ "(0...34)"
+
+ Returns a list of strings with each value in the range (inclusive, exclusive)
+
+ ["0", "1", ..., "33"]
+ """
+ splits = re.split(TEMPLATE_PATTERN, obj, maxsplit=1)
+
+ # One split means the pattern isn't present.
+ if len(splits) == 1:
+ return [obj]
+
+ if len(splits) != 4:
+ raise ValueError("Can't parse strings with more than one ( ... )")
+
+ front, start_s, end_s, back = splits
+
+ front_list = expand_numbers(front)
+ back_list = expand_numbers(back)
+
+ start = int(start_s)
+ end = int(end_s)
+
+ if start < end:
+ spread = range(start, end)
+ else:
+ spread = range(start, end, -1)
+
+ expanded = []
+ for f in front_list:
+ for i in spread:
+ for b in back_list:
+ expanded.append(f"{f}{i}{b}")
+
+ try:
+ return [int(i) for i in expanded]
+ except ValueError:
+ return expanded
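+
+
+# Examples (illustrative):
+#   expand_numbers("seed(0...3)") -> ["seed0", "seed1", "seed2"]
+#   expand_numbers("(1...4)")     -> [1, 2, 3]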
diff --git a/hierarchical.py b/hierarchical.py
new file mode 100644
index 00000000..bb549307
--- /dev/null
+++ b/hierarchical.py
@@ -0,0 +1,79 @@
+import einops
+import torch
+
+
+def accuracy(output, target, topk=(1,), hierarchy_level=-1):
+ """
+ Computes the accuracy over the k top predictions for the specified values of k
+
+ Copied from rwightman/pytorch-image-models/timm/utils/metrics.py and modified
+ to work with hierarchical outputs as well.
+
+ When the output is hierarchical, only returns the accuracy for `hierarchy_level`
+ (default -1, which is the fine-grained level).
+ """
+    output_levels = 1
+    if isinstance(output, list):
+        output_levels = len(output)
+        output = output[hierarchy_level]
+
+    batch_size = output.size(0)
+
+    # Target might have multiple levels because of the hierarchy.
+    if target.squeeze().ndim == 2:
+        assert target.squeeze().shape == (batch_size, output_levels)
+        target = target[:, hierarchy_level]
+
+ maxk = min(max(topk), output.size(1))
+ _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
+ pred = pred.t()
+ correct = pred.eq(target.reshape(1, -1).expand_as(pred))
+ return [
+ correct[: min(k, maxk)].reshape(-1).float().sum(0) * 100.0 / batch_size
+ for k in topk
+ ]
+
+
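+# Shape sketch for a 2-tier hierarchy (shapes illustrative): output is a list
+# [coarse_logits (B, C_coarse), fine_logits (B, C_fine)] and target has shape (B, 2);
+# accuracy(output, target, topk=(1, 5)) then reports top-1 and top-5 on the last tier.
+
+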
+class FineGrainedCrossEntropyLoss(torch.nn.CrossEntropyLoss):
+ """
+ A cross-entropy used with hierarchical inputs and targets and only
+ looks at the finest-grained tier (the last level).
+ """
+
+ def forward(self, inputs, targets):
+ fine_grained_inputs = inputs[-1]
+ fine_grained_targets = targets[:, -1]
+ return super().forward(fine_grained_inputs, fine_grained_targets)
+
+
+class HierarchicalCrossEntropyLoss(torch.nn.CrossEntropyLoss):
+ def __init__(self, *args, coeffs=(1.0,), **kwargs):
+ super().__init__(*args, **kwargs)
+
+ if isinstance(coeffs, torch.Tensor):
+ coeffs = coeffs.clone().detach().type(torch.float)
+ else:
+ coeffs = torch.tensor(coeffs, dtype=torch.float)
+
+ self.register_buffer("coeffs", coeffs)
+
+ def forward(self, inputs, targets):
+ if not isinstance(targets, list):
+ targets = einops.rearrange(targets, "batch tiers -> tiers batch")
+
+ assert (
+ len(inputs) == len(targets) == len(self.coeffs)
+ ), f"{len(inputs)} != {len(targets)} != {len(self.coeffs)}"
+
+ losses = torch.stack(
+ [
+                # Need to pass explicit arguments to super() because zero-argument
+                # super() does not work inside list comprehensions/generators.
+ super(HierarchicalCrossEntropyLoss, self).forward(input, target)
+ for input, target in zip(inputs, targets)
+ ]
+ )
+
+ return torch.dot(self.coeffs, losses)
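+
+
+# Usage sketch (coefficients illustrative): for per-tier logits
+# inputs = [coarse_logits, fine_logits] and integer targets of shape (batch, 2),
+#
+#   loss_fn = HierarchicalCrossEntropyLoss(coeffs=(0.3, 1.0))
+#   loss = loss_fn(inputs, targets)
+#
+# computes 0.3 * CE(coarse_logits, targets[:, 0]) + 1.0 * CE(fine_logits, targets[:, 1]).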
diff --git a/kernels/window_process/setup.py b/kernels/window_process/setup.py
index c78526d0..c6cb5ea6 100644
--- a/kernels/window_process/setup.py
+++ b/kernels/window_process/setup.py
@@ -1,12 +1,16 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
-
-setup(name='swin_window_process',
+setup(
+ name="swin_window_process",
ext_modules=[
- CUDAExtension('swin_window_process', [
- 'swin_window_process.cpp',
- 'swin_window_process_kernel.cu',
- ])
+ CUDAExtension(
+ "swin_window_process",
+ [
+ "swin_window_process.cpp",
+ "swin_window_process_kernel.cu",
+ ],
+ )
],
- cmdclass={'build_ext': BuildExtension})
\ No newline at end of file
+ cmdclass={"build_ext": BuildExtension},
+)
diff --git a/kernels/window_process/unit_test.py b/kernels/window_process/unit_test.py
index 65dee566..743d7691 100644
--- a/kernels/window_process/unit_test.py
+++ b/kernels/window_process/unit_test.py
@@ -4,22 +4,25 @@
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
-import torch
-import swin_window_process
import random
import time
import unittest
+import swin_window_process
+import torch
+
class WindowProcess(torch.autograd.Function):
@staticmethod
def forward(ctx, input, B, H, W, C, shift_size, window_size):
- output = swin_window_process.roll_and_window_partition_forward(input, B, H, W, C, shift_size, window_size)
+ output = swin_window_process.roll_and_window_partition_forward(
+ input, B, H, W, C, shift_size, window_size
+ )
ctx.B = B
ctx.H = H
- ctx.W = W
- ctx.C = C
+ ctx.W = W
+ ctx.C = C
ctx.shift_size = shift_size
ctx.window_size = window_size
return output
@@ -28,24 +31,28 @@ def forward(ctx, input, B, H, W, C, shift_size, window_size):
def backward(ctx, grad_in):
B = ctx.B
H = ctx.H
- W = ctx.W
- C = ctx.C
+ W = ctx.W
+ C = ctx.C
shift_size = ctx.shift_size
window_size = ctx.window_size
- grad_out = swin_window_process.roll_and_window_partition_backward(grad_in, B, H, W, C, shift_size, window_size)
+ grad_out = swin_window_process.roll_and_window_partition_backward(
+ grad_in, B, H, W, C, shift_size, window_size
+ )
return grad_out, None, None, None, None, None, None, None
class WindowProcessReverse(torch.autograd.Function):
@staticmethod
def forward(ctx, input, B, H, W, C, shift_size, window_size):
- output = swin_window_process.window_merge_and_roll_forward(input, B, H, W, C, shift_size, window_size)
+ output = swin_window_process.window_merge_and_roll_forward(
+ input, B, H, W, C, shift_size, window_size
+ )
ctx.B = B
ctx.H = H
- ctx.W = W
- ctx.C = C
+ ctx.W = W
+ ctx.C = C
ctx.shift_size = shift_size
ctx.window_size = window_size
@@ -55,12 +62,14 @@ def forward(ctx, input, B, H, W, C, shift_size, window_size):
def backward(ctx, grad_in):
B = ctx.B
H = ctx.H
- W = ctx.W
- C = ctx.C
+ W = ctx.W
+ C = ctx.C
shift_size = ctx.shift_size
window_size = ctx.window_size
- grad_out = swin_window_process.window_merge_and_roll_backward(grad_in, B, H, W, C, shift_size, window_size)
+ grad_out = swin_window_process.window_merge_and_roll_backward(
+ grad_in, B, H, W, C, shift_size, window_size
+ )
return grad_out, None, None, None, None, None, None, None
@@ -74,9 +83,12 @@ def window_partition(x, window_size):
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
- windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ windows = (
+ x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ )
return windows
+
def window_reverse(windows, window_size, H, W):
"""
Args:
@@ -88,7 +100,9 @@ def window_reverse(windows, window_size, H, W):
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
- x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+ x = windows.view(
+ B, H // window_size, W // window_size, window_size, window_size, -1
+ )
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
@@ -119,6 +133,7 @@ def copy_one_tensor(input, requires_grad=True):
input1 = input.clone().detach().requires_grad_(requires_grad).cuda()
return input1
+
class Test_WindowProcess(unittest.TestCase):
def setUp(self):
self.B = 192
@@ -129,10 +144,12 @@ def setUp(self):
self.window_size = 7
self.nH = self.H // self.window_size
self.nW = self.W // self.window_size
-
+
def test_roll_and_window_partition_forward(self, dtype=torch.float32):
- input = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True).cuda()
-
+ input = torch.randn(
+ (self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True
+ ).cuda()
+
input1 = copy_one_tensor(input, True)
input2 = copy_one_tensor(input, True)
@@ -140,15 +157,28 @@ def test_roll_and_window_partition_forward(self, dtype=torch.float32):
# ori
expected = pyt_forward(input1, self.shift_size, self.window_size)
# fused kernel
- fused_output = WindowProcess.apply(input2, self.B, self.H, self.W, self.C, -self.shift_size, self.window_size)
-
+ fused_output = WindowProcess.apply(
+ input2,
+ self.B,
+ self.H,
+ self.W,
+ self.C,
+ -self.shift_size,
+ self.window_size,
+ )
+
self.assertTrue(torch.equal(expected, fused_output))
- #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08))
-
+ # self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08))
+
def test_roll_and_window_partition_backward(self, dtype=torch.float32):
- input = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True).cuda()
- d_loss_tensor = torch.randn((self.B*self.nW*self.nH, self.window_size, self.window_size, self.C), dtype=dtype).cuda()
-
+ input = torch.randn(
+ (self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True
+ ).cuda()
+ d_loss_tensor = torch.randn(
+ (self.B * self.nW * self.nH, self.window_size, self.window_size, self.C),
+ dtype=dtype,
+ ).cuda()
+
input1 = copy_one_tensor(input, True)
input2 = copy_one_tensor(input, True)
@@ -156,64 +186,105 @@ def test_roll_and_window_partition_backward(self, dtype=torch.float32):
expected = pyt_forward(input1, self.shift_size, self.window_size)
expected.backward(d_loss_tensor)
# fused kernel
- fused_output = WindowProcess.apply(input2, self.B, self.H, self.W, self.C, -self.shift_size, self.window_size)
+ fused_output = WindowProcess.apply(
+ input2, self.B, self.H, self.W, self.C, -self.shift_size, self.window_size
+ )
fused_output.backward(d_loss_tensor)
-
+
self.assertTrue(torch.equal(expected, fused_output))
- #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08))
+ # self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08))
def test_window_merge_and_roll_forward(self, dtype=torch.float32):
- input = torch.randn((self.B*self.nH*self.nW, self.window_size, self.window_size, self.C), dtype=dtype, requires_grad=True).cuda()
-
+ input = torch.randn(
+ (self.B * self.nH * self.nW, self.window_size, self.window_size, self.C),
+ dtype=dtype,
+ requires_grad=True,
+ ).cuda()
+
input1 = copy_one_tensor(input, True)
input2 = copy_one_tensor(input, True)
with torch.no_grad():
# ori
- expected = reverse_pyt_forward(input1, self.shift_size, self.window_size, self.H, self.W)
+ expected = reverse_pyt_forward(
+ input1, self.shift_size, self.window_size, self.H, self.W
+ )
# fused kernel
- fused_output = WindowProcessReverse.apply(input2, self.B, self.H, self.W, self.C, self.shift_size, self.window_size)
-
+ fused_output = WindowProcessReverse.apply(
+ input2,
+ self.B,
+ self.H,
+ self.W,
+ self.C,
+ self.shift_size,
+ self.window_size,
+ )
+
self.assertTrue(torch.equal(expected, fused_output))
- #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08))
-
+ # self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08))
def test_window_merge_and_roll_backward(self, dtype=torch.float32):
- input = torch.randn((self.B*self.nH*self.nW, self.window_size, self.window_size, self.C), dtype=dtype, requires_grad=True).cuda()
- d_loss_tensor = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True).cuda()
-
+ input = torch.randn(
+ (self.B * self.nH * self.nW, self.window_size, self.window_size, self.C),
+ dtype=dtype,
+ requires_grad=True,
+ ).cuda()
+ d_loss_tensor = torch.randn(
+ (self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True
+ ).cuda()
+
input1 = copy_one_tensor(input, True)
input2 = copy_one_tensor(input, True)
# ori
- expected = reverse_pyt_forward(input1, self.shift_size, self.window_size, self.H, self.W)
+ expected = reverse_pyt_forward(
+ input1, self.shift_size, self.window_size, self.H, self.W
+ )
expected.backward(d_loss_tensor)
# fused kernel
- fused_output = WindowProcessReverse.apply(input2, self.B, self.H, self.W, self.C, self.shift_size, self.window_size)
+ fused_output = WindowProcessReverse.apply(
+ input2, self.B, self.H, self.W, self.C, self.shift_size, self.window_size
+ )
fused_output.backward(d_loss_tensor)
-
+
self.assertTrue(torch.equal(expected, fused_output))
- #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08))
+ # self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08))
def test_forward_backward_speed(self, dtype=torch.float32, times=1000):
- input = torch.randn((self.B*self.nH*self.nW, self.window_size, self.window_size, self.C), dtype=dtype, requires_grad=True).cuda()
- d_loss_tensor = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True).cuda()
-
+ input = torch.randn(
+ (self.B * self.nH * self.nW, self.window_size, self.window_size, self.C),
+ dtype=dtype,
+ requires_grad=True,
+ ).cuda()
+ d_loss_tensor = torch.randn(
+ (self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True
+ ).cuda()
+
input1 = copy_one_tensor(input, True)
input2 = copy_one_tensor(input, True)
# SwinTransformer official
def run_pyt(t=1000):
for _ in range(t):
- expected = reverse_pyt_forward(input1, self.shift_size, self.window_size, self.H, self.W)
+ expected = reverse_pyt_forward(
+ input1, self.shift_size, self.window_size, self.H, self.W
+ )
expected.backward(d_loss_tensor)
# my op
def run_fusedop(t=1000):
for _ in range(t):
- fused_output = WindowProcessReverse.apply(input2, self.B, self.H, self.W, self.C, self.shift_size, self.window_size)
+ fused_output = WindowProcessReverse.apply(
+ input2,
+ self.B,
+ self.H,
+ self.W,
+ self.C,
+ self.shift_size,
+ self.window_size,
+ )
fused_output.backward(d_loss_tensor)
-
+
torch.cuda.synchronize()
t1 = time.time()
run_pyt(t=times)
@@ -224,10 +295,10 @@ def run_fusedop(t=1000):
t3 = time.time()
self.assertTrue((t3 - t2) < (t2 - t1))
- print('Run {} times'.format(times))
- print('Original time cost: {}'.format(t2 - t1))
- print('Fused op time cost: {}'.format(t3 - t2))
-
+ print("Run {} times".format(times))
+ print("Original time cost: {}".format(t2 - t1))
+ print("Fused op time cost: {}".format(t3 - t2))
+
def test_roll_and_window_partition_forward_fp16(self, dtype=torch.float16):
self.test_roll_and_window_partition_forward(dtype=dtype)
@@ -236,7 +307,7 @@ def test_roll_and_window_partition_backward_fp16(self, dtype=torch.float16):
def test_window_merge_and_roll_forward_fp16(self, dtype=torch.float16):
self.test_window_merge_and_roll_forward(dtype=dtype)
-
+
def test_window_merge_and_roll_backward_fp16(self, dtype=torch.float16):
self.test_window_merge_and_roll_backward(dtype=dtype)
@@ -244,7 +315,7 @@ def test_forward_backward_speed_fp16(self, dtype=torch.float16, times=1000):
self.test_forward_backward_speed(dtype=dtype, times=times)
-if __name__ == '__main__':
- print('Pass only two tensors are exactly the same (using torch.equal).\n')
+if __name__ == "__main__":
+ print("Pass only two tensors are exactly the same (using torch.equal).\n")
torch.manual_seed(0)
unittest.main(verbosity=2)
diff --git a/kernels/window_process/window_process.py b/kernels/window_process/window_process.py
index ee43e9e9..d482d4f7 100644
--- a/kernels/window_process/window_process.py
+++ b/kernels/window_process/window_process.py
@@ -4,19 +4,21 @@
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
-import torch
import swin_window_process
+import torch
class WindowProcess(torch.autograd.Function):
@staticmethod
def forward(ctx, input, B, H, W, C, shift_size, window_size):
- output = swin_window_process.roll_and_window_partition_forward(input, B, H, W, C, shift_size, window_size)
+ output = swin_window_process.roll_and_window_partition_forward(
+ input, B, H, W, C, shift_size, window_size
+ )
ctx.B = B
ctx.H = H
- ctx.W = W
- ctx.C = C
+ ctx.W = W
+ ctx.C = C
ctx.shift_size = shift_size
ctx.window_size = window_size
return output
@@ -25,24 +27,28 @@ def forward(ctx, input, B, H, W, C, shift_size, window_size):
def backward(ctx, grad_in):
B = ctx.B
H = ctx.H
- W = ctx.W
- C = ctx.C
+ W = ctx.W
+ C = ctx.C
shift_size = ctx.shift_size
window_size = ctx.window_size
- grad_out = swin_window_process.roll_and_window_partition_backward(grad_in, B, H, W, C, shift_size, window_size)
+ grad_out = swin_window_process.roll_and_window_partition_backward(
+ grad_in, B, H, W, C, shift_size, window_size
+ )
return grad_out, None, None, None, None, None, None, None
class WindowProcessReverse(torch.autograd.Function):
@staticmethod
def forward(ctx, input, B, H, W, C, shift_size, window_size):
- output = swin_window_process.window_merge_and_roll_forward(input, B, H, W, C, shift_size, window_size)
+ output = swin_window_process.window_merge_and_roll_forward(
+ input, B, H, W, C, shift_size, window_size
+ )
ctx.B = B
ctx.H = H
- ctx.W = W
- ctx.C = C
+ ctx.W = W
+ ctx.C = C
ctx.shift_size = shift_size
ctx.window_size = window_size
@@ -52,12 +58,14 @@ def forward(ctx, input, B, H, W, C, shift_size, window_size):
def backward(ctx, grad_in):
B = ctx.B
H = ctx.H
- W = ctx.W
- C = ctx.C
+ W = ctx.W
+ C = ctx.C
shift_size = ctx.shift_size
window_size = ctx.window_size
- #grad_out = ctx.saved_tensors[0]
- #grad_out = torch.zeros((B, H, W, C), dtype=dtype).cuda()
- grad_out = swin_window_process.window_merge_and_roll_backward(grad_in, B, H, W, C, shift_size, window_size)
+ # grad_out = ctx.saved_tensors[0]
+ # grad_out = torch.zeros((B, H, W, C), dtype=dtype).cuda()
+ grad_out = swin_window_process.window_merge_and_roll_backward(
+ grad_in, B, H, W, C, shift_size, window_size
+ )
return grad_out, None, None, None, None, None, None, None
diff --git a/logger.py b/logger.py
index a066e55b..14a7d6d1 100644
--- a/logger.py
+++ b/logger.py
@@ -5,37 +5,149 @@
# Written by Ze Liu
# --------------------------------------------------------
+import functools
+import logging
import os
import sys
-import logging
-import functools
+
+import torch
from termcolor import colored
+from torch.utils.tensorboard import SummaryWriter
+
+import wandb
@functools.lru_cache()
-def create_logger(output_dir, dist_rank=0, name=''):
+def create_logger(output_dir, dist_rank=0, name=""):
# create logger
logger = logging.getLogger(name)
logger.setLevel(logging.DEBUG)
logger.propagate = False
# create formatter
- fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s'
- color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \
- colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s'
+ fmt = "[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s"
+ color_fmt = (
+ colored("[%(asctime)s %(name)s]", "green")
+ + colored("(%(filename)s %(lineno)d)", "yellow")
+ + ": %(levelname)s %(message)s"
+ )
# create console handlers for master process
if dist_rank == 0:
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.DEBUG)
console_handler.setFormatter(
- logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S'))
+ logging.Formatter(fmt=color_fmt, datefmt="%Y-%m-%d %H:%M:%S")
+ )
logger.addHandler(console_handler)
# create file handlers
- file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'), mode='a')
+ file_handler = logging.FileHandler(
+ os.path.join(output_dir, f"log_rank{dist_rank}.txt"), mode="a"
+ )
file_handler.setLevel(logging.DEBUG)
- file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S'))
+ file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt="%Y-%m-%d %H:%M:%S"))
logger.addHandler(file_handler)
return logger
+
+
+class TensorboardWriter:
+ writer = None
+
+ def __init__(self, output_dir, dist_rank):
+ self.output_dir = output_dir
+
+ if dist_rank == 0:
+ self.writer = SummaryWriter(log_dir=self.output_dir)
+
+ def add_hparams(self, hparams, metrics):
+ pass
+
+ def log(self, items, step):
+ if self.writer is None:
+ return
+
+ # Copied from huggingface/accelerate/src/accelerate/tracking.py
+ for k, v in items.items():
+ if isinstance(v, (int, float)):
+ self.writer.add_scalar(k, v, global_step=step)
+ elif isinstance(v, torch.Tensor):
+ assert v.numel() == 1
+ self.writer.add_scalar(k, v.item(), global_step=step)
+ elif isinstance(v, str):
+ self.writer.add_text(k, v, global_step=step)
+ elif isinstance(v, dict):
+ self.writer.add_scalars(k, v, global_step=step)
+ else:
+ print(f"Can't log {v} because it is {type(v)}!")
+
+
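+
+# Usage sketch (keys illustrative):
+#
+#   writer = TensorboardWriter(output_dir, dist_rank)
+#   writer.log({"train/loss": 0.42, "train/learning_rate": 1e-4}, step=global_step)
+#
+# Only rank 0 creates a SummaryWriter; other ranks no-op.
+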
+class WandbWriter:
+ def __init__(self, rank):
+ self.rank = rank
+
+ def init(self, config):
+ if self.rank != 0:
+ return
+
+ kwargs = dict(
+ config=config,
+ project="hierarchical-vision",
+ name=config.EXPERIMENT.NAME,
+ )
+
+ if not config.EXPERIMENT.WANDB_ID:
+ print("Cannot resume wandb run because no id was provided!")
+ else:
+ kwargs["id"] = config.EXPERIMENT.WANDB_ID
+ kwargs["resume"] = "allow"
+
+        wandb.init(**kwargs)
+
+ # Validation metrics
+ wandb.define_metric("val/loss", step_metric="epoch", summary="max")
+ wandb.define_metric("val/acc1", step_metric="epoch", summary="max")
+ wandb.define_metric("val/acc5", step_metric="epoch", summary="max")
+
+ # Training metrics
+ wandb.define_metric("train/batch_time", step_metric="step", summary="last")
+ wandb.define_metric("train/grad_norm", step_metric="step", summary="last")
+ wandb.define_metric("train/batch_loss", step_metric="step", summary="last")
+ wandb.define_metric("train/loss_scale", step_metric="step", summary="last")
+ wandb.define_metric("train/learning_rate", step_metric="step", summary="last")
+
+ wandb.define_metric("train/epoch_time", step_metric="epoch", summary="last")
+ wandb.define_metric("train/loss", step_metric="epoch", summary="last")
+
+ # Other metrics
+ wandb.define_metric("memory_mb", summary="max")
+
+ def log(self, dct):
+ if self.rank != 0:
+ return
+
+ wandb.log(dct)
+
+ @property
+ def name(self):
+ if self.rank != 0:
+ raise RuntimeError(f"Should not get .name with rank {self.rank}.")
+
+ return wandb.run.name
+
+
+# Logging setup used by the standalone config-generation scripts.
+
+
+def init(name: str, verbose: bool = False, date=True) -> logging.Logger:
+ if date:
+ log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
+ else:
+ log_format = "[%(levelname)s] [%(name)s] %(message)s"
+
+ if not verbose:
+ logging.basicConfig(level=logging.INFO, format=log_format)
+ else:
+ logging.basicConfig(level=logging.DEBUG, format=log_format)
+
+ return logging.getLogger(name)
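+
+
+# Usage sketch (name illustrative):
+#
+#   log = init("generate_configs")
+#   log.info("Loaded experiment dictionaries. [count: %s]", len(experiment_dicts))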
diff --git a/lr_scheduler.py b/lr_scheduler.py
index a2122e5d..fa0ec4f1 100644
--- a/lr_scheduler.py
+++ b/lr_scheduler.py
@@ -9,8 +9,8 @@
import torch
from timm.scheduler.cosine_lr import CosineLRScheduler
-from timm.scheduler.step_lr import StepLRScheduler
from timm.scheduler.scheduler import Scheduler
+from timm.scheduler.step_lr import StepLRScheduler
def build_scheduler(config, optimizer, n_iter_per_epoch):
@@ -20,11 +20,13 @@ def build_scheduler(config, optimizer, n_iter_per_epoch):
multi_steps = [i * n_iter_per_epoch for i in config.TRAIN.LR_SCHEDULER.MULTISTEPS]
lr_scheduler = None
- if config.TRAIN.LR_SCHEDULER.NAME == 'cosine':
+ if config.TRAIN.LR_SCHEDULER.NAME == "cosine":
lr_scheduler = CosineLRScheduler(
optimizer,
- t_initial=(num_steps - warmup_steps) if config.TRAIN.LR_SCHEDULER.WARMUP_PREFIX else num_steps,
- t_mul=1.,
+ t_initial=(num_steps - warmup_steps)
+ if config.TRAIN.LR_SCHEDULER.WARMUP_PREFIX
+ else num_steps,
+ cycle_mul=1.0,
lr_min=config.TRAIN.MIN_LR,
warmup_lr_init=config.TRAIN.WARMUP_LR,
warmup_t=warmup_steps,
@@ -32,7 +34,7 @@ def build_scheduler(config, optimizer, n_iter_per_epoch):
t_in_epochs=False,
warmup_prefix=config.TRAIN.LR_SCHEDULER.WARMUP_PREFIX,
)
- elif config.TRAIN.LR_SCHEDULER.NAME == 'linear':
+ elif config.TRAIN.LR_SCHEDULER.NAME == "linear":
lr_scheduler = LinearLRScheduler(
optimizer,
t_initial=num_steps,
@@ -41,16 +43,16 @@ def build_scheduler(config, optimizer, n_iter_per_epoch):
warmup_t=warmup_steps,
t_in_epochs=False,
)
- elif config.TRAIN.LR_SCHEDULER.NAME == 'step':
+ elif config.TRAIN.LR_SCHEDULER.NAME == "step":
lr_scheduler = StepLRScheduler(
optimizer,
decay_t=decay_steps,
- decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE,
+ cycle_decay=config.TRAIN.LR_SCHEDULER.DECAY_RATE,
warmup_lr_init=config.TRAIN.WARMUP_LR,
warmup_t=warmup_steps,
t_in_epochs=False,
)
- elif config.TRAIN.LR_SCHEDULER.NAME == 'multistep':
+ elif config.TRAIN.LR_SCHEDULER.NAME == "multistep":
lr_scheduler = MultiStepLRScheduler(
optimizer,
milestones=multi_steps,
@@ -59,28 +61,58 @@ def build_scheduler(config, optimizer, n_iter_per_epoch):
warmup_t=warmup_steps,
t_in_epochs=False,
)
+ elif config.TRAIN.LR_SCHEDULER.NAME == "constant":
+ lr_scheduler = ConstantLRScheduler(
+ optimizer,
+ warmup_lr_init=config.TRAIN.WARMUP_LR,
+ warmup_t=warmup_steps,
+ t_in_epochs=False,
+ )
return lr_scheduler
-class LinearLRScheduler(Scheduler):
- def __init__(self,
- optimizer: torch.optim.Optimizer,
- t_initial: int,
- lr_min_rate: float,
- warmup_t=0,
- warmup_lr_init=0.,
- t_in_epochs=True,
- noise_range_t=None,
- noise_pct=0.67,
- noise_std=1.0,
- noise_seed=42,
- initialize=True,
- ) -> None:
+class TimmScheduler(Scheduler):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def get_epoch_values(self, epoch: int):
+ if self.t_in_epochs:
+ return self._get_lr(epoch)
+ else:
+ return None
+
+ def get_update_values(self, num_updates: int):
+ if not self.t_in_epochs:
+ return self._get_lr(num_updates)
+ else:
+ return None
+
+
+class LinearLRScheduler(TimmScheduler):
+ def __init__(
+ self,
+ optimizer: torch.optim.Optimizer,
+ t_initial: int,
+ lr_min_rate: float,
+ warmup_t=0,
+ warmup_lr_init=0.0,
+ t_in_epochs=True,
+ noise_range_t=None,
+ noise_pct=0.67,
+ noise_std=1.0,
+ noise_seed=42,
+ initialize=True,
+ ) -> None:
super().__init__(
- optimizer, param_group_field="lr",
- noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
- initialize=initialize)
+ optimizer,
+ param_group_field="lr",
+ noise_range_t=noise_range_t,
+ noise_pct=noise_pct,
+ noise_std=noise_std,
+ noise_seed=noise_seed,
+ initialize=initialize,
+ )
self.t_initial = t_initial
self.lr_min_rate = lr_min_rate
@@ -88,7 +120,9 @@ def __init__(self,
self.warmup_lr_init = warmup_lr_init
self.t_in_epochs = t_in_epochs
if self.warmup_t:
- self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
+ self.warmup_steps = [
+ (v - warmup_lr_init) / self.warmup_t for v in self.base_values
+ ]
super().update_groups(self.warmup_lr_init)
else:
self.warmup_steps = [1 for _ in self.base_values]
@@ -99,54 +133,75 @@ def _get_lr(self, t):
else:
t = t - self.warmup_t
total_t = self.t_initial - self.warmup_t
- lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values]
+ lrs = [
+ v - ((v - v * self.lr_min_rate) * (t / total_t))
+ for v in self.base_values
+ ]
return lrs
- def get_epoch_values(self, epoch: int):
- if self.t_in_epochs:
- return self._get_lr(epoch)
- else:
- return None
- def get_update_values(self, num_updates: int):
- if not self.t_in_epochs:
- return self._get_lr(num_updates)
- else:
- return None
-
-
-class MultiStepLRScheduler(Scheduler):
- def __init__(self, optimizer: torch.optim.Optimizer, milestones, gamma=0.1, warmup_t=0, warmup_lr_init=0, t_in_epochs=True) -> None:
+class MultiStepLRScheduler(TimmScheduler):
+ def __init__(
+ self,
+ optimizer: torch.optim.Optimizer,
+ milestones,
+ gamma=0.1,
+ warmup_t=0,
+ warmup_lr_init=0,
+ t_in_epochs=True,
+ ) -> None:
super().__init__(optimizer, param_group_field="lr")
-
+
self.milestones = milestones
self.gamma = gamma
self.warmup_t = warmup_t
self.warmup_lr_init = warmup_lr_init
self.t_in_epochs = t_in_epochs
if self.warmup_t:
- self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
+ self.warmup_steps = [
+ (v - warmup_lr_init) / self.warmup_t for v in self.base_values
+ ]
super().update_groups(self.warmup_lr_init)
else:
self.warmup_steps = [1 for _ in self.base_values]
-
+
assert self.warmup_t <= min(self.milestones)
-
+
def _get_lr(self, t):
if t < self.warmup_t:
lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
else:
- lrs = [v * (self.gamma ** bisect.bisect_right(self.milestones, t)) for v in self.base_values]
+ lrs = [
+ v * (self.gamma ** bisect.bisect_right(self.milestones, t))
+ for v in self.base_values
+ ]
return lrs
- def get_epoch_values(self, epoch: int):
- if self.t_in_epochs:
- return self._get_lr(epoch)
+
+class ConstantLRScheduler(TimmScheduler):
+ def __init__(
+ self,
+ optimizer: torch.optim.Optimizer,
+ warmup_t=0,
+ warmup_lr_init=0,
+ t_in_epochs=True,
+ ) -> None:
+ super().__init__(optimizer, param_group_field="lr")
+
+ self.warmup_t = warmup_t
+ self.warmup_lr_init = warmup_lr_init
+ self.t_in_epochs = t_in_epochs
+ if self.warmup_t:
+ self.warmup_steps = [
+ (v - warmup_lr_init) / self.warmup_t for v in self.base_values
+ ]
+ super().update_groups(self.warmup_lr_init)
else:
- return None
+ self.warmup_steps = [1 for _ in self.base_values]
- def get_update_values(self, num_updates: int):
- if not self.t_in_epochs:
- return self._get_lr(num_updates)
+ def _get_lr(self, t):
+ if t < self.warmup_t:
+ lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
else:
- return None
+ lrs = [v for v in self.base_values]
+ return lrs
diff --git a/main.py b/main.py
index 84230ea7..416d7119 100644
--- a/main.py
+++ b/main.py
@@ -5,84 +5,171 @@
# Written by Ze Liu
# --------------------------------------------------------
-import os
-import time
-import json
-import random
import argparse
import datetime
-import numpy as np
+import json
+import os
+import random
+import time
+import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
-
-from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
-from timm.utils import accuracy, AverageMeter
+from timm.loss import LabelSmoothingCrossEntropy
+from timm.utils import AverageMeter
from config import get_config
-from models import build_model
from data import build_loader
+from hierarchical import (
+ FineGrainedCrossEntropyLoss,
+ HierarchicalCrossEntropyLoss,
+ accuracy,
+)
+from logger import WandbWriter, create_logger
from lr_scheduler import build_scheduler
+from models import build_model
from optimizer import build_optimizer
-from logger import create_logger
-from utils import load_checkpoint, load_pretrained, save_checkpoint, NativeScalerWithGradNormCount, auto_resume_helper, \
- reduce_tensor
+from utils import (
+ NativeScalerWithGradNormCount,
+ auto_resume_helper,
+ batch_size_of,
+ load_checkpoint,
+ load_pretrained,
+ reduce_tensor,
+ save_checkpoint,
+ find_experiments,
+)
+
+
+class EarlyStopper:
+ def __init__(self, patience=1, min_delta=0):
+ self.patience = patience
+ self.min_delta = min_delta
+ self.counter = 0
+ self.min_validation_loss = np.inf
+
+ def early_stop(self, validation_loss):
+ if validation_loss < self.min_validation_loss:
+ self.min_validation_loss = validation_loss
+ self.counter = 0
+ elif validation_loss > (self.min_validation_loss + self.min_delta):
+ self.counter += 1
+ if self.counter >= self.patience:
+ return True
+ return False
+
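+# Usage sketch (hyperparameters illustrative): build one EarlyStopper before the epoch
+# loop and call stopper.early_stop(val_loss) after each validation pass; it returns True
+# after `patience` consecutive checks where the loss exceeds the best loss seen so far
+# by more than `min_delta`.
+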
def parse_option():
- parser = argparse.ArgumentParser('Swin Transformer training and evaluation script', add_help=False)
- parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', )
+ parser = argparse.ArgumentParser(
+ "Swin Transformer training and evaluation script", add_help=False
+ )
+ parser.add_argument(
+ "--cfg",
+ nargs="+",
+        required=True,
+        help="Paths to config.yaml files or to directories containing them; directories are searched recursively for nested config.yaml files.",
+ )
parser.add_argument(
"--opts",
help="Modify config options by adding 'KEY VALUE' pairs. ",
default=None,
- nargs='+',
+ nargs="+",
)
# easy config modification
- parser.add_argument('--batch-size', type=int, help="batch size for single GPU")
- parser.add_argument('--data-path', type=str, help='path to dataset')
- parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
- parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
- help='no: no cache, '
- 'full: cache all data, '
- 'part: sharding the dataset into nonoverlapping pieces and only cache one piece')
- parser.add_argument('--pretrained',
- help='pretrained weight from checkpoint, could be imagenet22k pretrained weight')
- parser.add_argument('--resume', help='resume from checkpoint')
- parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
- parser.add_argument('--use-checkpoint', action='store_true',
- help="whether to use gradient checkpointing to save memory")
- parser.add_argument('--disable_amp', action='store_true', help='Disable pytorch amp')
- parser.add_argument('--amp-opt-level', type=str, choices=['O0', 'O1', 'O2'],
- help='mixed precision opt level, if O0, no amp is used (deprecated!)')
- parser.add_argument('--output', default='output', type=str, metavar='PATH',
- help='root of output folder, the full path is