3434import function
3535
3636
37- args = cfg .parse_args ()
38- if args .dataset == 'refuge' or args .dataset == 'refuge2' :
39- args .data_path = '../dataset'
40-
41- GPUdevice = torch .device ('cuda' , args .gpu_device )
42-
43- net = get_network (args , args .net , use_gpu = args .gpu , gpu_device = GPUdevice , distribution = args .distributed )
44-
45- '''load pretrained model'''
46- assert args .weights != 0
47- print (f'=> resuming from { args .weights } ' )
48- assert os .path .exists (args .weights )
49- checkpoint_file = os .path .join (args .weights )
50- assert os .path .exists (checkpoint_file )
51- loc = 'cuda:{}' .format (args .gpu_device )
52- checkpoint = torch .load (checkpoint_file , map_location = loc )
53- start_epoch = checkpoint ['epoch' ]
54- best_tol = checkpoint ['best_tol' ]
55-
56- state_dict = checkpoint ['state_dict' ]
57- if args .distributed != 'none' :
58- from collections import OrderedDict
59- new_state_dict = OrderedDict ()
60- for k , v in state_dict .items ():
61- # name = k[7:] # remove `module.`
62- name = 'module.' + k
63- new_state_dict [name ] = v
64- # load params
65- else :
66- new_state_dict = state_dict
67-
68- net .load_state_dict (new_state_dict )
69-
70- # args.path_helper = checkpoint['path_helper']
71- # logger = create_logger(args.path_helper['log_path'])
72- # print(f'=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})')
73-
74- # args.path_helper = set_log_dir('logs', args.exp_name)
75- # logger = create_logger(args.path_helper['log_path'])
76- # logger.info(args)
77-
78- args .path_helper = set_log_dir ('logs' , args .exp_name )
79- logger = create_logger (args .path_helper ['log_path' ])
80- logger .info (args )
81-
82- '''segmentation data'''
83- nice_train_loader , nice_test_loader = get_dataloader (args )
84-
85- '''begain valuation'''
86- best_acc = 0.0
87- best_tol = 1e4
88-
89- if args .mod == 'sam_adpt' :
90- net .eval ()
91-
92- if args .dataset != 'REFUGE' :
93- tol , (eiou , edice ) = function .validation_sam (args , nice_test_loader , start_epoch , net )
94- logger .info (f'Total score: { tol } , IOU: { eiou } , DICE: { edice } || @ epoch { start_epoch } .' )
37+ def main ():
38+ args = cfg .parse_args ()
39+ if args .dataset == 'refuge' or args .dataset == 'refuge2' :
40+ args .data_path = '../dataset'
41+
42+ GPUdevice = torch .device ('cuda' , args .gpu_device )
43+
44+ net = get_network (args , args .net , use_gpu = args .gpu , gpu_device = GPUdevice , distribution = args .distributed )
45+
46+ '''load pretrained model'''
47+ assert args .weights != 0
48+ print (f'=> resuming from { args .weights } ' )
49+ assert os .path .exists (args .weights )
50+ checkpoint_file = os .path .join (args .weights )
51+ assert os .path .exists (checkpoint_file )
52+ loc = 'cuda:{}' .format (args .gpu_device )
53+ checkpoint = torch .load (checkpoint_file , map_location = loc )
54+ start_epoch = checkpoint ['epoch' ]
55+ best_tol = checkpoint ['best_tol' ]
56+
57+ state_dict = checkpoint ['state_dict' ]
58+ if args .distributed != 'none' :
59+ from collections import OrderedDict
60+ new_state_dict = OrderedDict ()
61+ for k , v in state_dict .items ():
62+ # name = k[7:] # remove `module.`
63+ name = 'module.' + k
64+ new_state_dict [name ] = v
65+ # load params
9566 else :
96- tol , (eiou_cup , eiou_disc , edice_cup , edice_disc ) = function .validation_sam (args , nice_test_loader , start_epoch , net )
97- logger .info (f'Total score: { tol } , IOU_CUP: { eiou_cup } , IOU_DISC: { eiou_disc } , DICE_CUP: { edice_cup } , DICE_DISC: { edice_disc } || @ epoch { start_epoch } .' )
67+ new_state_dict = state_dict
9868
99-
69+ net .load_state_dict (new_state_dict )
70+
71+ # args.path_helper = checkpoint['path_helper']
72+ # logger = create_logger(args.path_helper['log_path'])
73+ # print(f'=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})')
74+
75+ # args.path_helper = set_log_dir('logs', args.exp_name)
76+ # logger = create_logger(args.path_helper['log_path'])
77+ # logger.info(args)
78+
79+ args .path_helper = set_log_dir ('logs' , args .exp_name )
80+ logger = create_logger (args .path_helper ['log_path' ])
81+ logger .info (args )
82+
83+ '''segmentation data'''
84+ nice_train_loader , nice_test_loader = get_dataloader (args )
85+
86+ '''begain valuation'''
87+ best_acc = 0.0
88+ best_tol = 1e4
89+
90+ if args .mod == 'sam_adpt' :
91+ net .eval ()
92+
93+ if args .dataset != 'REFUGE' :
94+ tol , (eiou , edice ) = function .validation_sam (args , nice_test_loader , start_epoch , net )
95+ logger .info (f'Total score: { tol } , IOU: { eiou } , DICE: { edice } || @ epoch { start_epoch } .' )
96+ else :
97+ tol , (eiou_cup , eiou_disc , edice_cup , edice_disc ) = function .validation_sam (args , nice_test_loader , start_epoch , net )
98+ logger .info (f'Total score: { tol } , IOU_CUP: { eiou_cup } , IOU_DISC: { eiou_disc } , DICE_CUP: { edice_cup } , DICE_DISC: { edice_disc } || @ epoch { start_epoch } .' )
99+
100+
101+ if __name__ == '__main__' :
102+ main ()
0 commit comments