Skip to content

Commit 186db89

Browse files
committed
added finetune and test codes
1 parent bd0dd35 commit 186db89

29 files changed

+968
-271
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ experiments*
33
model/pretrained_checkpoints*
44
support_data.pth
55
data_paths.yaml
6+
*ipynb*

args.py

Lines changed: 85 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import argparse
22
import yaml
3+
from easydict import EasyDict
4+
5+
from dataset.taskonomy_constants import TASKS_GROUP_NAMES, TASKS_GROUP_TEST
36

47

58
def str2bool(v):
@@ -14,109 +17,110 @@ def str2bool(v):
1417
# argument parser
1518
parser = argparse.ArgumentParser()
1619

17-
# environment arguments
18-
parser.add_argument('--seed', type=int, default=0)
19-
parser.add_argument('--precision', '-prc', type=str, default='bf16', choices=['fp32', 'fp16', 'bf16'])
20-
parser.add_argument('--strategy', '-str', type=str, default='ddp', choices=['none', 'ddp'])
20+
# necessary arguments
2121
parser.add_argument('--debug_mode', '-debug', default=False, action='store_true')
2222
parser.add_argument('--continue_mode', '-cont', default=False, action='store_true')
2323
parser.add_argument('--skip_mode', '-skip', default=False, action='store_true')
2424
parser.add_argument('--no_eval', '-ne', default=False, action='store_true')
2525
parser.add_argument('--no_save', '-ns', default=False, action='store_true')
2626
parser.add_argument('--reset_mode', '-reset', default=False, action='store_true')
27-
parser.add_argument('--profile_mode', '-prof', default=False, action='store_true')
28-
parser.add_argument('--sanity_check', '-sc', default=False, action='store_true')
29-
30-
# data arguments
31-
parser.add_argument('--dataset', type=str, default='taskonomy', choices=['taskonomy'])
32-
parser.add_argument('--task', type=str, default='', choices=['', 'all'])
33-
parser.add_argument('--task_fold', '-fold', type=int, default=0, choices=[0, 1, 2, 3, 4])
34-
35-
parser.add_argument('--num_workers', '-nw', type=int, default=8)
36-
parser.add_argument('--global_batch_size', '-gbs', type=int, default=8)
37-
parser.add_argument('--max_channels', '-mc', type=int, default=5)
38-
parser.add_argument('--shot', type=int, default=4)
39-
parser.add_argument('--domains_per_batch', '-dpb', type=int, default=2)
40-
parser.add_argument('--eval_batch_size', '-ebs', type=int, default=8)
41-
parser.add_argument('--n_eval_batches', '-neb', type=int, default=10)
42-
43-
parser.add_argument('--img_size', type=int, default=224, choices=[224])
44-
parser.add_argument('--image_augmentation', '-ia', type=str2bool, default=True)
45-
parser.add_argument('--unary_augmentation', '-ua', type=str2bool, default=True)
46-
parser.add_argument('--binary_augmentation', '-ba', type=str2bool, default=True)
47-
parser.add_argument('--mixed_augmentation', '-ma', type=str2bool, default=True)
48-
49-
# model arguments
50-
parser.add_argument('--model', type=str, default='VTM', choices=['VTM'])
51-
parser.add_argument('--image_backbone', '-ib', type=str, default='beit_base_patch16_224_in22k')
52-
parser.add_argument('--label_backbone', '-lb', type=str, default='vit_base_patch16_224')
53-
parser.add_argument('--image_encoder_weights', '-iew', type=str, default='imagenet', choices=['none', 'imagenet'])
54-
parser.add_argument('--label_encoder_weights', '-lew', type=str, default='none', choices=['none', 'imagenet'])
55-
parser.add_argument('--n_attn_heads', '-nah', type=int, default=4)
56-
parser.add_argument('--n_attn_layers', '-nal', type=int, default=1)
57-
parser.add_argument('--attn_residual', '-ar', type=str2bool, default=True)
58-
parser.add_argument('--out_activation', '-oa', type=str, default='sigmoid', choices=['sigmoid', 'clip', 'none'])
59-
parser.add_argument('--drop_rate', '-dr', type=float, default=0.0)
60-
parser.add_argument('--drop_path_rate', '-dpr', type=float, default=0.1)
61-
parser.add_argument('--bitfit', '-bf', type=str2bool, default=True)
62-
parser.add_argument('--semseg_threshold', '-th', type=float, default=0.2)
63-
64-
# training arguments
65-
parser.add_argument('--n_steps', '-nst', type=int, default=300000)
66-
parser.add_argument('--optimizer', '-opt', type=str, default='adam', choices=['sgd', 'adam', 'adamw', 'fadam', 'dsadam'])
67-
parser.add_argument('--lr', type=float, default=1e-4)
68-
parser.add_argument('--lr_pretrained', '-lrp', type=float, default=1e-5)
69-
parser.add_argument('--lr_schedule', '-lrs', type=str, default='poly', choices=['constant', 'sqroot', 'cos', 'poly'])
70-
parser.add_argument('--lr_warmup', '-lrw', type=int, default=5000)
71-
parser.add_argument('--lr_warmup_scale', '-lrws', type=float, default=0.)
72-
parser.add_argument('--weight_decay', '-wd', type=float, default=0.)
73-
parser.add_argument('--lr_decay_degree', '-ldd', type=float, default=0.9)
74-
parser.add_argument('--temperature', '-temp', type=float, default=-1.)
75-
parser.add_argument('--reg_coef', '-rgc', type=float, default=1.)
76-
parser.add_argument('--mask_value', '-mv', type=float, default=-1.)
77-
78-
# logging arguments
79-
parser.add_argument('--log_dir', type=str, default='TRAIN')
80-
parser.add_argument('--save_dir', type=str, default='')
81-
parser.add_argument('--load_dir', type=str, default='')
27+
28+
parser.add_argument('--stage', type=int, default=0, choices=[0, 1, 2])
29+
parser.add_argument('--task', type=str, default='', choices=['', 'all'] + TASKS_GROUP_NAMES)
30+
parser.add_argument('--task_fold', '-fold', type=int, default=None, choices=[0, 1, 2, 3, 4])
8231
parser.add_argument('--exp_name', type=str, default='')
32+
parser.add_argument('--exp_subname', type=str, default='')
8333
parser.add_argument('--name_postfix', '-ptf', type=str, default='')
84-
parser.add_argument('--log_iter', '-li', type=int, default=100)
85-
parser.add_argument('--val_iter', '-vi', type=int, default=10000)
86-
parser.add_argument('--save_iter', '-si', type=int, default=10000)
34+
parser.add_argument('--save_postfix', '-sptf', type=str, default='')
35+
parser.add_argument('--result_postfix', '-rptf', type=str, default='')
8736
parser.add_argument('--load_step', '-ls', type=int, default=-1)
8837

89-
config = parser.parse_args()
38+
# optional arguments
39+
parser.add_argument('--model', type=str, default=None, choices=['VTM'])
40+
parser.add_argument('--seed', type=int, default=None)
41+
parser.add_argument('--strategy', '-str', type=str, default=None)
42+
parser.add_argument('--num_workers', '-nw', type=int, default=None)
43+
parser.add_argument('--global_batch_size', '-gbs', type=int, default=None)
44+
parser.add_argument('--eval_batch_size', '-ebs', type=int, default=None)
45+
parser.add_argument('--n_eval_batches', '-neb', type=int, default=None)
46+
parser.add_argument('--shot', type=int, default=None)
47+
parser.add_argument('--max_channels', '-mc', type=int, default=None)
48+
parser.add_argument('--support_idx', '-sid', type=int, default=None)
49+
parser.add_argument('--channel_idx', '-cid', type=int, default=None)
50+
parser.add_argument('--test_split', '-split', type=str, default=None)
51+
parser.add_argument('--semseg_threshold', '-sth', type=float, default=None)
52+
53+
parser.add_argument('--image_augmentation', '-ia', type=str2bool, default=None)
54+
parser.add_argument('--unary_augmentation', '-ua', type=str2bool, default=None)
55+
parser.add_argument('--binary_augmentation', '-ba', type=str2bool, default=None)
56+
parser.add_argument('--mixed_augmentation', '-ma', type=str2bool, default=None)
57+
parser.add_argument('--image_backbone', '-ib', type=str, default=None)
58+
parser.add_argument('--label_backbone', '-lb', type=str, default=None)
59+
parser.add_argument('--n_attn_heads', '-nah', type=int, default=None)
60+
61+
parser.add_argument('--n_steps', '-nst', type=int, default=None)
62+
parser.add_argument('--optimizer', '-opt', type=str, default=None, choices=['sgd', 'adam', 'adamw'])
63+
parser.add_argument('--lr', type=float, default=None)
64+
parser.add_argument('--lr_pretrained', '-lrp', type=float, default=None)
65+
parser.add_argument('--lr_schedule', '-lrs', type=str, default=None, choices=['constant', 'sqroot', 'cos', 'poly'])
66+
parser.add_argument('--early_stopping_patience', '-esp', type=int, default=None)
67+
68+
parser.add_argument('--log_dir', type=str, default=None)
69+
parser.add_argument('--save_dir', type=str, default=None)
70+
parser.add_argument('--load_dir', type=str, default=None)
71+
parser.add_argument('--val_iter', '-viter', type=int, default=None)
72+
parser.add_argument('--save_iter', '-siter', type=int, default=None)
73+
74+
args = parser.parse_args()
9075

9176

77+
# load config file
78+
if args.stage == 0:
79+
config_path = 'configs/train_config.yaml'
80+
elif args.stage == 1:
81+
config_path = 'configs/finetune_config.yaml'
82+
elif args.stage == 2:
83+
config_path = 'configs/test_config.yaml'
84+
85+
with open(config_path, 'r') as f:
86+
config = yaml.safe_load(f)
87+
config = EasyDict(config)
88+
89+
# copy parsed arguments
90+
for key in args.__dir__():
91+
if key[:2] != '__' and getattr(args, key) is not None:
92+
setattr(config, key, getattr(args, key))
93+
9294
# retrieve data root
9395
with open('data_paths.yaml', 'r') as f:
9496
path_dict = yaml.safe_load(f)
9597
config.root_dir = path_dict[config.dataset]
96-
if config.save_dir == '':
97-
config.save_dir = config.log_dir
98-
if config.load_dir == '':
99-
config.load_dir = config.log_dir
10098

10199
# for debugging
102100
if config.debug_mode:
103101
config.n_steps = 10
104102
config.log_iter = 1
105103
config.val_iter = 5
106104
config.save_iter = 5
107-
config.n_eval_batches = 4
105+
if config.stage == 2:
106+
config.n_eval_batches = 2
108107
config.log_dir += '_debugging'
109-
config.save_dir += '_debugging'
110-
config.load_dir += '_debugging'
111-
108+
if config.stage == 0:
109+
config.load_dir += '_debugging'
110+
if config.stage <= 1:
111+
config.save_dir += '_debugging'
112112

113-
# model-specific hyper-parameters
114-
config.n_levels = 4
115-
116-
# adjust backbone names
117-
if config.image_backbone in ['beit_base', 'beit_large']:
118-
config.image_backbone += '_patch16_224_in22k'
119-
if config.image_backbone in ['vit_tiny', 'vit_small', 'vit_base', 'vit_large']:
120-
config.image_backbone += '_patch16_224'
121-
if config.label_backbone in ['vit_tiny', 'vit_small', 'vit_base', 'vit_large']:
122-
config.label_backbone += '_patch16_224'
113+
# create experiment name
114+
if config.exp_name == '':
115+
if config.stage == 0:
116+
if config.task == '':
117+
config.exp_name = f'{config.model}_fold:{config.task_fold}{config.name_postfix}'
118+
else:
119+
config.exp_name = f'{config.model}_task:{config.task}{config.name_postfix}'
120+
else:
121+
fold_dict = {}
122+
for fold in TASKS_GROUP_TEST:
123+
for task in TASKS_GROUP_TEST[fold]:
124+
fold_dict[task] = fold
125+
task_fold = fold_dict[config.task]
126+
config.exp_name = f'{config.model}_fold:{task_fold}{config.name_postfix}'

configs/finetune_config.yaml

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# environment settings
2+
seed: 0
3+
precision: bf16
4+
strategy: ddp
5+
6+
# data arguments
7+
dataset: taskonomy
8+
num_workers: 4
9+
global_batch_size: 1
10+
shot: 10
11+
eval_batch_size: 5
12+
n_eval_batches: 2
13+
img_size: 224
14+
support_idx: 0
15+
channel_idx: -1
16+
17+
# model arguments
18+
model: VTM
19+
semseg_threshold: 0.2
20+
attn_dropout: 0.5
21+
22+
# training arguments
23+
n_steps: 20000
24+
n_schedule_steps: 20000
25+
optimizer: adam
26+
lr: 0.005
27+
lr_schedule: constant
28+
lr_warmup: 0
29+
lr_warmup_scale: 0.
30+
schedule_from: 0
31+
weight_decay: 0.
32+
lr_decay_degree: 0.9
33+
mask_value: -1.
34+
early_stopping_patience: 5
35+
36+
# logging arguments
37+
log_dir: FINETUNE
38+
save_dir: FINETUNE
39+
load_dir: TRAIN
40+
log_iter: 100
41+
val_iter: 100
42+
save_iter: 100
43+
load_step: 0

configs/test_config.yaml

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# environment settings
2+
seed: 0
3+
precision: bf16
4+
strategy: ddp
5+
6+
# data arguments
7+
dataset: taskonomy
8+
test_split: muleshoe
9+
num_workers: 4
10+
shot: 10
11+
eval_batch_size: 8
12+
n_eval_batches: -1
13+
img_size: 224
14+
support_idx: 0
15+
channel_idx: -1
16+
17+
# model arguments
18+
model: VTM
19+
semseg_threshold: 0.2
20+
21+
# logging arguments
22+
log_dir: TEST
23+
save_dir: FINETUNE
24+
load_dir: TRAIN
25+
load_step: 0

configs/train_config.yaml

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# environment settings
2+
seed: 0
3+
precision: bf16
4+
strategy: ddp
5+
6+
# data arguments
7+
dataset: taskonomy
8+
task_fold: 0
9+
num_workers: 4
10+
global_batch_size: 8
11+
max_channels: 5
12+
shot: 4
13+
n_buildings: -1
14+
domains_per_batch: 2
15+
eval_batch_size: 8
16+
n_eval_batches: 10
17+
img_size: 224
18+
image_augmentation: True
19+
unary_augmentation: True
20+
binary_augmentation: True
21+
mixed_augmentation: True
22+
channel_idx: -1
23+
24+
# model arguments
25+
model: VTM
26+
image_backbone: beit_base_patch16_224_in22k
27+
label_backbone: vit_base_patch16_224
28+
image_encoder_weights: imagenet
29+
drop_rate: 0.
30+
drop_path_rate: 0.1
31+
attn_drop_rate: 0.
32+
n_attn_heads: 4
33+
semseg_threshold: 0.2
34+
channel_idx: -1
35+
n_levels: 4
36+
bitfit: True
37+
38+
# training arguments
39+
n_steps: 300000
40+
optimizer: adam
41+
lr: 0.0001
42+
lr_pretrained: 0.00001
43+
lr_schedule: poly
44+
lr_warmup: 5000
45+
lr_warmup_scale: 0.
46+
schedule_from: 0
47+
weight_decay: 0.
48+
lr_decay_degree: 0.9
49+
mask_value: -1.
50+
early_stopping_patience: -1
51+
52+
# logging arguments
53+
log_dir: TRAIN
54+
save_dir: TRAIN
55+
load_dir: TRAIN
56+
log_iter: 100
57+
val_iter: 20000
58+
save_iter: 20000
59+
load_step: -1

0 commit comments

Comments
 (0)