11import argparse
22import yaml
3+ from easydict import EasyDict
4+
5+ from dataset .taskonomy_constants import TASKS_GROUP_NAMES , TASKS_GROUP_TEST
36
47
58def str2bool (v ):
@@ -14,109 +17,110 @@ def str2bool(v):
1417# argument parser
1518parser = argparse .ArgumentParser ()
1619
17- # environment arguments
18- parser .add_argument ('--seed' , type = int , default = 0 )
19- parser .add_argument ('--precision' , '-prc' , type = str , default = 'bf16' , choices = ['fp32' , 'fp16' , 'bf16' ])
20- parser .add_argument ('--strategy' , '-str' , type = str , default = 'ddp' , choices = ['none' , 'ddp' ])
20+ # necessary arguments
2121parser .add_argument ('--debug_mode' , '-debug' , default = False , action = 'store_true' )
2222parser .add_argument ('--continue_mode' , '-cont' , default = False , action = 'store_true' )
2323parser .add_argument ('--skip_mode' , '-skip' , default = False , action = 'store_true' )
2424parser .add_argument ('--no_eval' , '-ne' , default = False , action = 'store_true' )
2525parser .add_argument ('--no_save' , '-ns' , default = False , action = 'store_true' )
2626parser .add_argument ('--reset_mode' , '-reset' , default = False , action = 'store_true' )
27- parser .add_argument ('--profile_mode' , '-prof' , default = False , action = 'store_true' )
28- parser .add_argument ('--sanity_check' , '-sc' , default = False , action = 'store_true' )
29-
30- # data arguments
31- parser .add_argument ('--dataset' , type = str , default = 'taskonomy' , choices = ['taskonomy' ])
32- parser .add_argument ('--task' , type = str , default = '' , choices = ['' , 'all' ])
33- parser .add_argument ('--task_fold' , '-fold' , type = int , default = 0 , choices = [0 , 1 , 2 , 3 , 4 ])
34-
35- parser .add_argument ('--num_workers' , '-nw' , type = int , default = 8 )
36- parser .add_argument ('--global_batch_size' , '-gbs' , type = int , default = 8 )
37- parser .add_argument ('--max_channels' , '-mc' , type = int , default = 5 )
38- parser .add_argument ('--shot' , type = int , default = 4 )
39- parser .add_argument ('--domains_per_batch' , '-dpb' , type = int , default = 2 )
40- parser .add_argument ('--eval_batch_size' , '-ebs' , type = int , default = 8 )
41- parser .add_argument ('--n_eval_batches' , '-neb' , type = int , default = 10 )
42-
43- parser .add_argument ('--img_size' , type = int , default = 224 , choices = [224 ])
44- parser .add_argument ('--image_augmentation' , '-ia' , type = str2bool , default = True )
45- parser .add_argument ('--unary_augmentation' , '-ua' , type = str2bool , default = True )
46- parser .add_argument ('--binary_augmentation' , '-ba' , type = str2bool , default = True )
47- parser .add_argument ('--mixed_augmentation' , '-ma' , type = str2bool , default = True )
48-
49- # model arguments
50- parser .add_argument ('--model' , type = str , default = 'VTM' , choices = ['VTM' ])
51- parser .add_argument ('--image_backbone' , '-ib' , type = str , default = 'beit_base_patch16_224_in22k' )
52- parser .add_argument ('--label_backbone' , '-lb' , type = str , default = 'vit_base_patch16_224' )
53- parser .add_argument ('--image_encoder_weights' , '-iew' , type = str , default = 'imagenet' , choices = ['none' , 'imagenet' ])
54- parser .add_argument ('--label_encoder_weights' , '-lew' , type = str , default = 'none' , choices = ['none' , 'imagenet' ])
55- parser .add_argument ('--n_attn_heads' , '-nah' , type = int , default = 4 )
56- parser .add_argument ('--n_attn_layers' , '-nal' , type = int , default = 1 )
57- parser .add_argument ('--attn_residual' , '-ar' , type = str2bool , default = True )
58- parser .add_argument ('--out_activation' , '-oa' , type = str , default = 'sigmoid' , choices = ['sigmoid' , 'clip' , 'none' ])
59- parser .add_argument ('--drop_rate' , '-dr' , type = float , default = 0.0 )
60- parser .add_argument ('--drop_path_rate' , '-dpr' , type = float , default = 0.1 )
61- parser .add_argument ('--bitfit' , '-bf' , type = str2bool , default = True )
62- parser .add_argument ('--semseg_threshold' , '-th' , type = float , default = 0.2 )
63-
64- # training arguments
65- parser .add_argument ('--n_steps' , '-nst' , type = int , default = 300000 )
66- parser .add_argument ('--optimizer' , '-opt' , type = str , default = 'adam' , choices = ['sgd' , 'adam' , 'adamw' , 'fadam' , 'dsadam' ])
67- parser .add_argument ('--lr' , type = float , default = 1e-4 )
68- parser .add_argument ('--lr_pretrained' , '-lrp' , type = float , default = 1e-5 )
69- parser .add_argument ('--lr_schedule' , '-lrs' , type = str , default = 'poly' , choices = ['constant' , 'sqroot' , 'cos' , 'poly' ])
70- parser .add_argument ('--lr_warmup' , '-lrw' , type = int , default = 5000 )
71- parser .add_argument ('--lr_warmup_scale' , '-lrws' , type = float , default = 0. )
72- parser .add_argument ('--weight_decay' , '-wd' , type = float , default = 0. )
73- parser .add_argument ('--lr_decay_degree' , '-ldd' , type = float , default = 0.9 )
74- parser .add_argument ('--temperature' , '-temp' , type = float , default = - 1. )
75- parser .add_argument ('--reg_coef' , '-rgc' , type = float , default = 1. )
76- parser .add_argument ('--mask_value' , '-mv' , type = float , default = - 1. )
77-
78- # logging arguments
79- parser .add_argument ('--log_dir' , type = str , default = 'TRAIN' )
80- parser .add_argument ('--save_dir' , type = str , default = '' )
81- parser .add_argument ('--load_dir' , type = str , default = '' )
27+
28+ parser .add_argument ('--stage' , type = int , default = 0 , choices = [0 , 1 , 2 ])
29+ parser .add_argument ('--task' , type = str , default = '' , choices = ['' , 'all' ] + TASKS_GROUP_NAMES )
30+ parser .add_argument ('--task_fold' , '-fold' , type = int , default = None , choices = [0 , 1 , 2 , 3 , 4 ])
8231parser .add_argument ('--exp_name' , type = str , default = '' )
32+ parser .add_argument ('--exp_subname' , type = str , default = '' )
8333parser .add_argument ('--name_postfix' , '-ptf' , type = str , default = '' )
84- parser .add_argument ('--log_iter' , '-li' , type = int , default = 100 )
85- parser .add_argument ('--val_iter' , '-vi' , type = int , default = 10000 )
86- parser .add_argument ('--save_iter' , '-si' , type = int , default = 10000 )
34+ parser .add_argument ('--save_postfix' , '-sptf' , type = str , default = '' )
35+ parser .add_argument ('--result_postfix' , '-rptf' , type = str , default = '' )
8736parser .add_argument ('--load_step' , '-ls' , type = int , default = - 1 )
8837
89- config = parser .parse_args ()
38+ # optional arguments
39+ parser .add_argument ('--model' , type = str , default = None , choices = ['VTM' ])
40+ parser .add_argument ('--seed' , type = int , default = None )
41+ parser .add_argument ('--strategy' , '-str' , type = str , default = None )
42+ parser .add_argument ('--num_workers' , '-nw' , type = int , default = None )
43+ parser .add_argument ('--global_batch_size' , '-gbs' , type = int , default = None )
44+ parser .add_argument ('--eval_batch_size' , '-ebs' , type = int , default = None )
45+ parser .add_argument ('--n_eval_batches' , '-neb' , type = int , default = None )
46+ parser .add_argument ('--shot' , type = int , default = None )
47+ parser .add_argument ('--max_channels' , '-mc' , type = int , default = None )
48+ parser .add_argument ('--support_idx' , '-sid' , type = int , default = None )
49+ parser .add_argument ('--channel_idx' , '-cid' , type = int , default = None )
50+ parser .add_argument ('--test_split' , '-split' , type = str , default = None )
51+ parser .add_argument ('--semseg_threshold' , '-sth' , type = float , default = None )
52+
53+ parser .add_argument ('--image_augmentation' , '-ia' , type = str2bool , default = None )
54+ parser .add_argument ('--unary_augmentation' , '-ua' , type = str2bool , default = None )
55+ parser .add_argument ('--binary_augmentation' , '-ba' , type = str2bool , default = None )
56+ parser .add_argument ('--mixed_augmentation' , '-ma' , type = str2bool , default = None )
57+ parser .add_argument ('--image_backbone' , '-ib' , type = str , default = None )
58+ parser .add_argument ('--label_backbone' , '-lb' , type = str , default = None )
59+ parser .add_argument ('--n_attn_heads' , '-nah' , type = int , default = None )
60+
61+ parser .add_argument ('--n_steps' , '-nst' , type = int , default = None )
62+ parser .add_argument ('--optimizer' , '-opt' , type = str , default = None , choices = ['sgd' , 'adam' , 'adamw' ])
63+ parser .add_argument ('--lr' , type = float , default = None )
64+ parser .add_argument ('--lr_pretrained' , '-lrp' , type = float , default = None )
65+ parser .add_argument ('--lr_schedule' , '-lrs' , type = str , default = None , choices = ['constant' , 'sqroot' , 'cos' , 'poly' ])
66+ parser .add_argument ('--early_stopping_patience' , '-esp' , type = int , default = None )
67+
68+ parser .add_argument ('--log_dir' , type = str , default = None )
69+ parser .add_argument ('--save_dir' , type = str , default = None )
70+ parser .add_argument ('--load_dir' , type = str , default = None )
71+ parser .add_argument ('--val_iter' , '-viter' , type = int , default = None )
72+ parser .add_argument ('--save_iter' , '-siter' , type = int , default = None )
73+
74+ args = parser .parse_args ()
9075
9176
77+ # load config file
78+ if args .stage == 0 :
79+ config_path = 'configs/train_config.yaml'
80+ elif args .stage == 1 :
81+ config_path = 'configs/finetune_config.yaml'
82+ elif args .stage == 2 :
83+ config_path = 'configs/test_config.yaml'
84+
85+ with open (config_path , 'r' ) as f :
86+ config = yaml .safe_load (f )
87+ config = EasyDict (config )
88+
89+ # copy parsed arguments
90+ for key in args .__dir__ ():
91+ if key [:2 ] != '__' and getattr (args , key ) is not None :
92+ setattr (config , key , getattr (args , key ))
93+
9294# retrieve data root
9395with open ('data_paths.yaml' , 'r' ) as f :
9496 path_dict = yaml .safe_load (f )
9597 config .root_dir = path_dict [config .dataset ]
96- if config .save_dir == '' :
97- config .save_dir = config .log_dir
98- if config .load_dir == '' :
99- config .load_dir = config .log_dir
10098
10199# for debugging
102100if config .debug_mode :
103101 config .n_steps = 10
104102 config .log_iter = 1
105103 config .val_iter = 5
106104 config .save_iter = 5
107- config .n_eval_batches = 4
105+ if config .stage == 2 :
106+ config .n_eval_batches = 2
108107 config .log_dir += '_debugging'
109- config .save_dir += '_debugging'
110- config .load_dir += '_debugging'
111-
108+ if config .stage == 0 :
109+ config .load_dir += '_debugging'
110+ if config .stage <= 1 :
111+ config .save_dir += '_debugging'
112112
113- # model-specific hyper-parameters
114- config .n_levels = 4
115-
116- # adjust backbone names
117- if config .image_backbone in ['beit_base' , 'beit_large' ]:
118- config .image_backbone += '_patch16_224_in22k'
119- if config .image_backbone in ['vit_tiny' , 'vit_small' , 'vit_base' , 'vit_large' ]:
120- config .image_backbone += '_patch16_224'
121- if config .label_backbone in ['vit_tiny' , 'vit_small' , 'vit_base' , 'vit_large' ]:
122- config .label_backbone += '_patch16_224'
113+ # create experiment name
114+ if config .exp_name == '' :
115+ if config .stage == 0 :
116+ if config .task == '' :
117+ config .exp_name = f'{ config .model } _fold:{ config .task_fold } { config .name_postfix } '
118+ else :
119+ config .exp_name = f'{ config .model } _task:{ config .task } { config .name_postfix } '
120+ else :
121+ fold_dict = {}
122+ for fold in TASKS_GROUP_TEST :
123+ for task in TASKS_GROUP_TEST [fold ]:
124+ fold_dict [task ] = fold
125+ task_fold = fold_dict [config .task ]
126+ config .exp_name = f'{ config .model } _fold:{ task_fold } { config .name_postfix } '
0 commit comments