Skip to content

Commit bad4500

Browse files
authored
Merge pull request #115 from dzenanz/fixValidationCrash
Fix multiprocessing crash -> allow validation to be invoked standalone
2 parents ca28f30 + 24aff64 commit bad4500

File tree

1 file changed

+64
-61
lines changed

1 file changed

+64
-61
lines changed

val.py

Lines changed: 64 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -34,66 +34,69 @@
3434
import function
3535

3636

37-
args = cfg.parse_args()
38-
if args.dataset == 'refuge' or args.dataset == 'refuge2':
39-
args.data_path = '../dataset'
40-
41-
GPUdevice = torch.device('cuda', args.gpu_device)
42-
43-
net = get_network(args, args.net, use_gpu=args.gpu, gpu_device=GPUdevice, distribution = args.distributed)
44-
45-
'''load pretrained model'''
46-
assert args.weights != 0
47-
print(f'=> resuming from {args.weights}')
48-
assert os.path.exists(args.weights)
49-
checkpoint_file = os.path.join(args.weights)
50-
assert os.path.exists(checkpoint_file)
51-
loc = 'cuda:{}'.format(args.gpu_device)
52-
checkpoint = torch.load(checkpoint_file, map_location=loc)
53-
start_epoch = checkpoint['epoch']
54-
best_tol = checkpoint['best_tol']
55-
56-
state_dict = checkpoint['state_dict']
57-
if args.distributed != 'none':
58-
from collections import OrderedDict
59-
new_state_dict = OrderedDict()
60-
for k, v in state_dict.items():
61-
# name = k[7:] # remove `module.`
62-
name = 'module.' + k
63-
new_state_dict[name] = v
64-
# load params
65-
else:
66-
new_state_dict = state_dict
67-
68-
net.load_state_dict(new_state_dict)
69-
70-
# args.path_helper = checkpoint['path_helper']
71-
# logger = create_logger(args.path_helper['log_path'])
72-
# print(f'=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})')
73-
74-
# args.path_helper = set_log_dir('logs', args.exp_name)
75-
# logger = create_logger(args.path_helper['log_path'])
76-
# logger.info(args)
77-
78-
args.path_helper = set_log_dir('logs', args.exp_name)
79-
logger = create_logger(args.path_helper['log_path'])
80-
logger.info(args)
81-
82-
'''segmentation data'''
83-
nice_train_loader, nice_test_loader = get_dataloader(args)
84-
85-
'''begain valuation'''
86-
best_acc = 0.0
87-
best_tol = 1e4
88-
89-
if args.mod == 'sam_adpt':
90-
net.eval()
91-
92-
if args.dataset != 'REFUGE':
93-
tol, (eiou, edice) = function.validation_sam(args, nice_test_loader, start_epoch, net)
94-
logger.info(f'Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {start_epoch}.')
def main():
    """Run standalone validation of a checkpoint on the configured dataset.

    Parses CLI args, builds the network, loads the pretrained weights,
    then runs `function.validation_sam` on the test loader and logs the
    resulting scores. Raises ValueError / FileNotFoundError on a missing
    or nonexistent --weights path.
    """
    args = cfg.parse_args()
    if args.dataset in ('refuge', 'refuge2'):
        args.data_path = '../dataset'

    GPUdevice = torch.device('cuda', args.gpu_device)

    net = get_network(args, args.net, use_gpu=args.gpu, gpu_device=GPUdevice, distribution = args.distributed)

    # Load the pretrained model. Explicit raises instead of `assert` so the
    # checks survive `python -O`; the original double-checked the same path
    # twice (os.path.join of a single argument is the argument itself).
    if not args.weights:
        raise ValueError('--weights must point to a checkpoint file')
    print(f'=> resuming from {args.weights}')
    checkpoint_file = os.path.join(args.weights)
    if not os.path.exists(checkpoint_file):
        raise FileNotFoundError(f'checkpoint not found: {checkpoint_file}')
    loc = 'cuda:{}'.format(args.gpu_device)
    checkpoint = torch.load(checkpoint_file, map_location=loc)
    start_epoch = checkpoint['epoch']

    state_dict = checkpoint['state_dict']
    if args.distributed != 'none':
        # Distributed wrappers expect parameter keys prefixed with 'module.'.
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            new_state_dict['module.' + k] = v
    else:
        new_state_dict = state_dict

    net.load_state_dict(new_state_dict)

    args.path_helper = set_log_dir('logs', args.exp_name)
    logger = create_logger(args.path_helper['log_path'])
    logger.info(args)

    # Segmentation data; only the test loader is used for validation below.
    nice_train_loader, nice_test_loader = get_dataloader(args)

    # Begin validation. (Dead locals best_acc/best_tol from the original
    # script-level code were removed: they were never read, and best_tol
    # silently clobbered the value loaded from the checkpoint.)
    if args.mod == 'sam_adpt':
        net.eval()

        if args.dataset != 'REFUGE':
            tol, (eiou, edice) = function.validation_sam(args, nice_test_loader, start_epoch, net)
            logger.info(f'Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {start_epoch}.')
        else:
            # REFUGE reports separate metrics for optic cup and optic disc.
            tol, (eiou_cup, eiou_disc, edice_cup, edice_disc) = function.validation_sam(args, nice_test_loader, start_epoch, net)
            logger.info(f'Total score: {tol}, IOU_CUP: {eiou_cup}, IOU_DISC: {eiou_disc}, DICE_CUP: {edice_cup}, DICE_DISC: {edice_disc} || @ epoch {start_epoch}.')


if __name__ == '__main__':
    main()

0 commit comments

Comments
 (0)