We recommend following the instructions below to set up the environment:
# create dusa environment
conda create -n dusa -y python=3.9.18
conda activate dusa
# install torch
pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu121
# install open-mmlab
pip install openmim==0.3.9
mim install "mmcv==2.1.0"
mim install "mmengine==0.10.2"
# install other requirements
conda env update -n dusa -f env.yml
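Once the install finishes, the version pins can be sanity-checked from inside the activated environment. The snippet below is only a convenience sketch of ours (the `mismatched` helper is not part of DUSA); the pinned versions are copied from the commands above.

```python
from importlib import metadata

# Version pins taken from the install commands above.
REQUIRED = {"torch": "2.1.2", "mmcv": "2.1.0", "mmengine": "0.10.2"}

def mismatched(required, installed):
    """Return package names whose installed version is missing or differs."""
    return [name for name, want in required.items() if installed.get(name) != want]

if __name__ == "__main__":
    installed = {}
    for name in REQUIRED:
        try:
            installed[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            pass  # left out of `installed`, so it will be reported as mismatched
    print("version mismatches:", mismatched(REQUIRED, installed))
```

An empty mismatch list means the environment matches the pins above.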
Important
The final tree of files should be:
DUSA/classification/pretrained_models/
├── B_16-i1k-300ep-lr_0.001-aug_strong2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz
├── convnext-large_3rdparty_64xb64_in1k_20220124-f8a0ded0.pth
├── DiT-XL-2-256x256.pt
└── resnet50_gn_a1h2-8fe6c4d0.pth
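Before running experiments, it may help to confirm all four checkpoints landed in place. This helper is a sketch of ours (not part of the repository); the filenames are copied from the tree above.

```python
from pathlib import Path

# Checkpoint filenames from the tree above.
WEIGHTS = [
    "B_16-i1k-300ep-lr_0.001-aug_strong2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz",
    "convnext-large_3rdparty_64xb64_in1k_20220124-f8a0ded0.pth",
    "DiT-XL-2-256x256.pt",
    "resnet50_gn_a1h2-8fe6c4d0.pth",
]

def missing_weights(root):
    """Return the expected checkpoint files that are absent under root."""
    root = Path(root)
    return [name for name in WEIGHTS if not (root / name).is_file()]

if __name__ == "__main__":
    print("missing:", missing_weights("pretrained_models"))
```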
We use pre-trained weights from the following sources:
- ResNet50-GN
- ViT-B/16
- ConvNeXt-Large
- DiT-XL/2
- stabilityai/sd-vae-ft-ema (the VAE used in DiT-XL/2)
We assume the weights are placed in the ./pretrained_models directory. To prepare them from the command line, refer to the following:
# create folder if not exists
mkdir -p pretrained_models && cd pretrained_models
# download weights for
# ResNet50-GN
wget https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_gn_a1h2-8fe6c4d0.pth
# ViT-B/16
wget https://storage.googleapis.com/vit_models/augreg/B_16-i1k-300ep-lr_0.001-aug_strong2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz
# ConvNeXt-Large
wget https://download.openmmlab.com/mmclassification/v0/convnext/convnext-large_3rdparty_64xb64_in1k_20220124-f8a0ded0.pth
# DiT-XL/2
wget https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-256x256.pt

The VAE used in DiT-XL/2 should be downloaded automatically while running experiments.
Alternatively, we can download it manually:
huggingface-cli download --resume-download stabilityai/sd-vae-ft-ema config.json diffusion_pytorch_model.safetensors

Tip
In case of network issues, give https://hf-mirror.com/ a try:
HF_ENDPOINT=https://hf-mirror.com huggingface-cli download --resume-download stabilityai/sd-vae-ft-ema config.json diffusion_pytorch_model.safetensors

Important
The final tree of files should be:
DUSA/classification/data
├── ImageNet -> /path/to/imagenet
└── imagenet-c -> /path/to/imagenet-c
The ImageNet-C dataset can be downloaded from Zenodo. Refer to the commands below:
wget https://zenodo.org/records/2235448/files/blur.tar?download=1 -c -O blur.tar
wget https://zenodo.org/records/2235448/files/digital.tar?download=1 -c -O digital.tar
wget https://zenodo.org/records/2235448/files/extra.tar?download=1 -c -O extra.tar
wget https://zenodo.org/records/2235448/files/noise.tar?download=1 -c -O noise.tar
wget https://zenodo.org/records/2235448/files/weather.tar?download=1 -c -O weather.tar

Extract the dataset to your selected /path/to/imagenet-c and symlink the dataset:
mkdir -p data
ln -s /path/to/imagenet-c data/imagenet-c

Note
Only ImageNet-C is required for reproducing DUSA.
To reproduce the EATA baseline, the validation set of ImageNet is also required.
The ImageNet validation set can be officially accessed at image-net.org.
Or download it from the command line:
wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar --no-check-certificate

Extract the val split to your selected /path/to/imagenet and symlink the dataset:
mkdir -p data
ln -s /path/to/imagenet data/ImageNet

Important
Make sure the validation set files are in data/ImageNet/val/.
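The extracted ImageNet-C layout can be verified with a small script. This is a sketch of ours under one assumption: the dataset follows the standard ImageNet-C layout of corruption/severity subdirectories, with the 15 standard corruptions listed below (the extra.tar corruptions are omitted here), each holding severity levels 1-5.

```python
from pathlib import Path

# The 15 standard ImageNet-C corruptions (extra.tar adds more, omitted here);
# each corruption directory should hold severity subdirectories 1-5.
CORRUPTIONS = [
    "gaussian_noise", "shot_noise", "impulse_noise",
    "defocus_blur", "glass_blur", "motion_blur", "zoom_blur",
    "snow", "frost", "fog", "brightness",
    "contrast", "elastic_transform", "pixelate", "jpeg_compression",
]

def missing_severities(root):
    """Return 'corruption/severity' paths absent under the ImageNet-C root."""
    root = Path(root)
    return [f"{c}/{s}" for c in CORRUPTIONS for s in range(1, 6)
            if not (root / c / str(s)).is_dir()]

if __name__ == "__main__":
    print("missing:", missing_severities("data/imagenet-c"))
```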
The scripts are available in the ./sh directory. Change CUDA_VISIBLE_DEVICES if needed.
ConvNeXt-Large on ImageNet-C:
- DUSA: bash sh/convnext-l/dusa_convnext_in-c.sh
- DUSA-U: bash sh/convnext-l/dusa-u_convnext_in-c.sh
- Ablation on Noise: bash sh/convnext-l/dusa_ablation_convnext.sh
- Baselines: bash sh/convnext-l/baselines_convnext_in-c.sh

ConvNeXt-Large on ImageNet-C (continual):
- DUSA: bash sh/convnext-l/dusa_continual_convnext_in-c.sh
- Baselines: bash sh/convnext-l/baselines_continual_convnext_in-c.sh

ViT-B/16 on ImageNet-C:
- DUSA: bash sh/vit-b/dusa_vit_in-c.sh
- DUSA-U: bash sh/vit-b/dusa-u_vit_in-c.sh
- Baselines: bash sh/vit-b/baselines_vit_in-c.sh

ResNet50-GN on ImageNet-C:
- DUSA: bash sh/res50-gn/dusa_res50_in-c.sh
- DUSA-U: bash sh/res50-gn/dusa-u_res50_in-c.sh
- Ablation on Noise: bash sh/res50-gn/dusa_ablation_res50.sh
- Baselines: bash sh/res50-gn/baselines_res50_in-c.sh
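To pin one of the scripts above to specific GPUs without editing it, the device list can be injected through the environment at launch time. The `run_script` helper below is a hypothetical convenience of ours, not part of the repository; it simply wraps a bash invocation with CUDA_VISIBLE_DEVICES set.

```python
import os
import subprocess

def run_script(script, gpus="0"):
    """Launch a shell script with CUDA_VISIBLE_DEVICES set to `gpus`."""
    env = {**os.environ, "CUDA_VISIBLE_DEVICES": gpus}
    return subprocess.run(["bash", script], env=env, check=True)

# Example (assumes the repository's script layout):
# run_script("sh/convnext-l/dusa_convnext_in-c.sh", gpus="0,1")
```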
This implementation is based on MMPreTrain and inspired by Diffusion-TTA. The baseline code is borrowed from the official implementations of Tent, CoTTA, EATA, SAR, and RoTTA. We thank their authors for making the source code publicly available.