
Commit bd0dd35

added training codes (initial commit with 0 parents)

37 files changed: 4449 additions & 0 deletions

.gitignore

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
*__pycache__*
experiments*
model/pretrained_checkpoints*
support_data.pth
data_paths.yaml

README.md

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
# Visual Token Matching

This repository contains the official code for [Universal Few-shot Learning of Dense Prediction Tasks with Visual Token Matching](https://openreview.net/forum?id=88nT0j5jAn) (ICLR 2023 oral).

Currently, we only include the code for our model architecture and episodic training.
We will release the remaining code (fine-tuning and evaluation) soon.

## Setup
1. Download the Taskonomy dataset (tiny split) from the official GitHub page https://github.com/StanfordVL/taskonomy/tree/master/data.
    * Download the data for `depth_euclidean`, `depth_zbuffer`, `keypoints2d`, `keypoints3d`, `normal`, `principal_curvature`, `reshading`, `segment_semantic`, and `rgb`.
    * (Optional) Resize the images and labels to (256, 256) resolution.
    * To reduce the I/O bottleneck of the dataloader, we store the data from all buildings in a single directory per task. The directory structure looks like the following (a minimal preprocessing sketch is given right after it):
```
<root>
|--<task1>
|   |--<building1>_<file_name1>
|   |   ...
|   |--<building2>_<file_name1>
|   |...
|
|--<task2>
|   |--<building1>_<file_name1>
|   |   ...
|   |--<building2>_<file_name1>
|   |...
|
|...
```
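This commit does not include a preprocessing script, so the following is only a minimal sketch of the flattening (and optional resizing) described above. It assumes the downloaded archives were extracted as `<download_root>/<task>/<building>/<file_name>`; that layout, the directory names, the `.png` glob, and the resampling choice are illustrative assumptions rather than something defined by this repository.

```
from pathlib import Path

from PIL import Image

# Assumed (not part of this commit): raw data extracted as <download_root>/<task>/<building>/<file>.
DOWNLOAD_ROOT = Path('taskonomy_raw')   # adjust to your extraction
OUT_ROOT = Path('taskonomy_flat')       # this becomes <root> in data_paths.yaml
TASKS = ['depth_euclidean', 'depth_zbuffer', 'keypoints2d', 'keypoints3d', 'normal',
         'principal_curvature', 'reshading', 'segment_semantic', 'rgb']

for task in TASKS:
    out_dir = OUT_ROOT / task
    out_dir.mkdir(parents=True, exist_ok=True)
    for building_dir in sorted((DOWNLOAD_ROOT / task).iterdir()):
        if not building_dir.is_dir():
            continue
        for src in sorted(building_dir.glob('*.png')):
            # flatten to <root>/<task>/<building>_<file_name>
            dst = out_dir / f'{building_dir.name}_{src.name}'
            img = Image.open(src)
            # optional resize; NEAREST is the safe choice for label maps,
            # while BILINEAR is fine for the rgb domain
            img = img.resize((256, 256), Image.NEAREST)
            img.save(dst)
```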

2. Create a `data_paths.yaml` file and specify the root directory path (`<root>` in the structure above) with the line `taskonomy: PATH_TO_YOUR_TASKONOMY`.

3. Install the prerequisites with `pip install -r requirements.txt`.

4. Create a `model/pretrained_checkpoints` directory and download the [BEiT pre-trained checkpoints](https://github.com/microsoft/unilm/tree/master/beit) into it.
    * We used the `beit_base_patch16_224_pt22k` checkpoint for our experiments (a quick sanity check for the downloaded file is sketched below).
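The model code that consumes this checkpoint is part of the commit but not shown in this section, so the snippet below is only an illustrative sanity check, not the repository's loading logic; the checkpoint file name is an assumption and should match whatever you downloaded.

```
import torch

# Assumed file name; use the name of the checkpoint you downloaded.
ckpt = torch.load('model/pretrained_checkpoints/beit_base_patch16_224_pt22k.pth',
                  map_location='cpu')
# BEiT releases typically nest the weights under a 'model' key; fall back to the raw dict.
state_dict = ckpt.get('model', ckpt)
print(f'{len(state_dict)} parameter tensors, e.g. {list(state_dict)[:3]}')
```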

## Usage
```
python main.py --task_fold [0/1/2/3/4]
```
Additional hyperparameters (batch size, number of shots, learning rates, etc.) are defined in `args.py`.

## Citation
If you find this work useful, please consider citing:
```
@inproceedings{
    kim2023universal,
    title={Universal Few-shot Learning of Dense Prediction Tasks with Visual Token Matching},
    author={Donggyun Kim and Jinwoo Kim and Seongwoong Cho and Chong Luo and Seunghoon Hong},
    booktitle={International Conference on Learning Representations},
    year={2023},
    url={https://openreview.net/forum?id=88nT0j5jAn}
}
```

args.py

Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
import argparse
import yaml


def str2bool(v):
    if v == 'True' or v == 'true':
        return True
    elif v == 'False' or v == 'false':
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


# argument parser
parser = argparse.ArgumentParser()

# environment arguments
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--precision', '-prc', type=str, default='bf16', choices=['fp32', 'fp16', 'bf16'])
parser.add_argument('--strategy', '-str', type=str, default='ddp', choices=['none', 'ddp'])
parser.add_argument('--debug_mode', '-debug', default=False, action='store_true')
parser.add_argument('--continue_mode', '-cont', default=False, action='store_true')
parser.add_argument('--skip_mode', '-skip', default=False, action='store_true')
parser.add_argument('--no_eval', '-ne', default=False, action='store_true')
parser.add_argument('--no_save', '-ns', default=False, action='store_true')
parser.add_argument('--reset_mode', '-reset', default=False, action='store_true')
parser.add_argument('--profile_mode', '-prof', default=False, action='store_true')
parser.add_argument('--sanity_check', '-sc', default=False, action='store_true')

# data arguments
parser.add_argument('--dataset', type=str, default='taskonomy', choices=['taskonomy'])
parser.add_argument('--task', type=str, default='', choices=['', 'all'])
parser.add_argument('--task_fold', '-fold', type=int, default=0, choices=[0, 1, 2, 3, 4])

parser.add_argument('--num_workers', '-nw', type=int, default=8)
parser.add_argument('--global_batch_size', '-gbs', type=int, default=8)
parser.add_argument('--max_channels', '-mc', type=int, default=5)
parser.add_argument('--shot', type=int, default=4)
parser.add_argument('--domains_per_batch', '-dpb', type=int, default=2)
parser.add_argument('--eval_batch_size', '-ebs', type=int, default=8)
parser.add_argument('--n_eval_batches', '-neb', type=int, default=10)

parser.add_argument('--img_size', type=int, default=224, choices=[224])
parser.add_argument('--image_augmentation', '-ia', type=str2bool, default=True)
parser.add_argument('--unary_augmentation', '-ua', type=str2bool, default=True)
parser.add_argument('--binary_augmentation', '-ba', type=str2bool, default=True)
parser.add_argument('--mixed_augmentation', '-ma', type=str2bool, default=True)

# model arguments
parser.add_argument('--model', type=str, default='VTM', choices=['VTM'])
parser.add_argument('--image_backbone', '-ib', type=str, default='beit_base_patch16_224_in22k')
parser.add_argument('--label_backbone', '-lb', type=str, default='vit_base_patch16_224')
parser.add_argument('--image_encoder_weights', '-iew', type=str, default='imagenet', choices=['none', 'imagenet'])
parser.add_argument('--label_encoder_weights', '-lew', type=str, default='none', choices=['none', 'imagenet'])
parser.add_argument('--n_attn_heads', '-nah', type=int, default=4)
parser.add_argument('--n_attn_layers', '-nal', type=int, default=1)
parser.add_argument('--attn_residual', '-ar', type=str2bool, default=True)
parser.add_argument('--out_activation', '-oa', type=str, default='sigmoid', choices=['sigmoid', 'clip', 'none'])
parser.add_argument('--drop_rate', '-dr', type=float, default=0.0)
parser.add_argument('--drop_path_rate', '-dpr', type=float, default=0.1)
parser.add_argument('--bitfit', '-bf', type=str2bool, default=True)
parser.add_argument('--semseg_threshold', '-th', type=float, default=0.2)

# training arguments
parser.add_argument('--n_steps', '-nst', type=int, default=300000)
parser.add_argument('--optimizer', '-opt', type=str, default='adam', choices=['sgd', 'adam', 'adamw', 'fadam', 'dsadam'])
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--lr_pretrained', '-lrp', type=float, default=1e-5)
parser.add_argument('--lr_schedule', '-lrs', type=str, default='poly', choices=['constant', 'sqroot', 'cos', 'poly'])
parser.add_argument('--lr_warmup', '-lrw', type=int, default=5000)
parser.add_argument('--lr_warmup_scale', '-lrws', type=float, default=0.)
parser.add_argument('--weight_decay', '-wd', type=float, default=0.)
parser.add_argument('--lr_decay_degree', '-ldd', type=float, default=0.9)
parser.add_argument('--temperature', '-temp', type=float, default=-1.)
parser.add_argument('--reg_coef', '-rgc', type=float, default=1.)
parser.add_argument('--mask_value', '-mv', type=float, default=-1.)

# logging arguments
parser.add_argument('--log_dir', type=str, default='TRAIN')
parser.add_argument('--save_dir', type=str, default='')
parser.add_argument('--load_dir', type=str, default='')
parser.add_argument('--exp_name', type=str, default='')
parser.add_argument('--name_postfix', '-ptf', type=str, default='')
parser.add_argument('--log_iter', '-li', type=int, default=100)
parser.add_argument('--val_iter', '-vi', type=int, default=10000)
parser.add_argument('--save_iter', '-si', type=int, default=10000)
parser.add_argument('--load_step', '-ls', type=int, default=-1)

config = parser.parse_args()


# retrieve data root
with open('data_paths.yaml', 'r') as f:
    path_dict = yaml.safe_load(f)
    config.root_dir = path_dict[config.dataset]
if config.save_dir == '':
    config.save_dir = config.log_dir
if config.load_dir == '':
    config.load_dir = config.log_dir

# for debugging
if config.debug_mode:
    config.n_steps = 10
    config.log_iter = 1
    config.val_iter = 5
    config.save_iter = 5
    config.n_eval_batches = 4
    config.log_dir += '_debugging'
    config.save_dir += '_debugging'
    config.load_dir += '_debugging'


# model-specific hyper-parameters
config.n_levels = 4

# adjust backbone names
if config.image_backbone in ['beit_base', 'beit_large']:
    config.image_backbone += '_patch16_224_in22k'
if config.image_backbone in ['vit_tiny', 'vit_small', 'vit_base', 'vit_large']:
    config.image_backbone += '_patch16_224'
if config.label_backbone in ['vit_tiny', 'vit_small', 'vit_base', 'vit_large']:
    config.label_backbone += '_patch16_224'
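Note that `config` is built at import time: importing `args` parses the command line and immediately reads `data_paths.yaml`, so that file must exist in the working directory. The training entry point `main.py` is part of the commit but not shown in this section; the snippet below is only an illustrative sketch of how the module-level `config` can be consumed, not the actual entry point.

```
# Illustrative only (not the repository's main.py): consuming the parsed config.
from args import config  # argument parsing and the data_paths.yaml lookup happen on import


def describe_run():
    # a few of the fields defined in args.py above
    print(f'dataset root : {config.root_dir}')
    print(f'task fold    : {config.task_fold}')
    print(f'backbones    : {config.image_backbone} / {config.label_backbone}')
    print(f'training     : {config.n_steps} steps, lr={config.lr} (pretrained lr={config.lr_pretrained})')


if __name__ == '__main__':
    describe_run()
```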

dataset/__init__.py

Whitespace-only changes.
