-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun.py
More file actions
90 lines (67 loc) · 3.72 KB
/
run.py
File metadata and controls
90 lines (67 loc) · 3.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os
import time
import torch
import numpy as np
from utils.assets import get_args, fix_seed, tunable_params_stastic
from utils.read_data import get_dataset
from logs.exp_logger import MultiLogger
def main(args, logger):
    """Train and evaluate one experiment, then optionally save both models.

    Args:
        args: Parsed command-line namespace (from ``get_args``). This function
            reads ``method``, ``dataset``, ``save_model``, ``debug`` and
            ``save_model_dir``; the whole namespace is also handed to
            ``get_dataset`` and to the selected ``Trainer``.
        logger: Logger exposing ``log_args`` and ``log_print``; it is also
            attached to the trainer as ``trainer.logger``.

    Raises:
        NotImplementedError: if ``args.method`` is not ``"finetune"``,
            ``"pure_kd"`` or ``"rd"``.
    """
    start_time = time.time()

    # Import only the selected trainer so the other trainers' modules are
    # never loaded.
    if args.method == "finetune":
        from trainers.finetune import Trainer
    elif args.method == "pure_kd":
        from trainers.pure_kd import Trainer
    elif args.method == "rd":
        from trainers.rd import Trainer
    else:
        raise NotImplementedError(f"Unknown method: {args.method!r}")

    train_datasets, val_datasets, test_datasets = get_dataset(args)
    logger.log_args(args)

    trainer = Trainer(args, train_datasets.num_classes)
    trainer.logger = logger
    logger.log_print(f"Dataset: {args.dataset}, Val_datasets is None?: {val_datasets is None}")

    # When the dataset has no validation split, the test split is passed to
    # run() in its place.
    # NOTE(review): if run() does model selection / early stopping on this
    # split, the final test numbers are optimistic - confirm in the trainers.
    eval_datasets = test_datasets if val_datasets is None else val_datasets
    trainer.run(train_datasets, eval_datasets)
    trainer.evaluating_for_datsets(test_datasets)

    # ".2f" = two decimal places (the old "{:2f}" was a field width of 2
    # with the default six decimals).
    elapsed_hours = (time.time() - start_time) / (60 * 60)
    logger.log_print("Total time: {:.2f} h".format(elapsed_hours))

    # Checkpoints are only written for non-debug runs that asked for them.
    if args.save_model and not args.debug:
        logger.log_print("Model save...")
        torch.save(trainer.foundation_model.state_dict(),
                   os.path.join(args.save_model_dir, "final_teacher_model.pt"))
        torch.save(trainer.student_model.state_dict(),
                   os.path.join(args.save_model_dir, "final_student_model.pt"))
if __name__ == '__main__':
args = get_args()
if args.debug:
args.save_log_dir = '/mnt/workspace/zhouyuhang/CODE/MKF-ours/logs/debug/' + args.dataset+'/'+args.method+'/'+args.foundation_model_name_or_path\
+'_'+args.downstream_model_name_or_path + '/'
args.save_model_dir = '/mnt/workspace/zhouyuhang/CODE/MKF-ours/checkpoint/debug/' + args.dataset+'/'+args.method+'/'+args.foundation_model_name_or_path\
+'_'+args.downstream_model_name_or_path + '/'
args.exp_name = 'debug_' + 'bz'+str(args.batch_size)+'_lr'+str(args.lr)+'_optim'+args.optim+'_epoch'+str(args.epochs)+'_kd'+args.kd_type+'_atte'+args.attention_type\
+'_kdw'+str(args.kd_weight) + '_cew' + str(args.ce_weight) + '/'
else:
args.save_log_dir = '/mnt/workspace/zhouyuhang/CODE/MKF-ours/logs/' + args.dataset+'/'+args.method+'/'+args.foundation_model_name_or_path\
+'_'+args.downstream_model_name_or_path + '/'
args.save_model_dir = '/mnt/workspace/zhouyuhang/CODE/MKF-ours/checkpoint/' + args.dataset+'/'+args.method+'/'+args.foundation_model_name_or_path\
+'_'+args.downstream_model_name_or_path + '/'
args.exp_name = 'bz'+str(args.batch_size)+'_lr'+str(args.lr)+'_optim'+args.optim+'_epoch'+str(args.epochs)+'_kd'+args.kd_type+'_atte'+args.attention_type\
+'_kdw'+str(args.kd_weight) + '_cew' + str(args.ce_weight) +'_ckaw'+str(args.cka_weight) + '_stucew' + str(args.stu_ce_weight) + '_patient' + str(args.patient) + '/'
args.save_log_dir = args.save_log_dir + args.exp_name
args.save_model_dir = args.save_model_dir + args.exp_name
if not os.path.exists(args.save_log_dir):
os.makedirs(args.save_log_dir)
if not os.path.exists(args.save_model_dir):
os.makedirs(args.save_model_dir)
logger = MultiLogger(args.save_log_dir, args.exp_name, loggers=['disk'])
if args.debug:
args.epochs = 1
args.epochs_stage2 = 1
args.batch_size = 16
logger.log_print('=' * 108)
logger.log_print('Arguments =')
for arg in np.sort(list(vars(args).keys())):
logger.log_print('\t' + str(arg) + ': ' + str(getattr(args, arg)))
logger.log_print('=' * 108)
fix_seed(args.seed)
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
main(args, logger)