|
'''
Network Test Function
predicts the PCA scores using the trained networks
returns the error measures and saves the predicted and original particles for comparison
'''


def test(config_file, loader="test"):
    """
    Run inference with the trained DeepSSM network(s) on a serialized data loader.

    Loads the trained model (TL-Net variant or PCA + fine-tuned pair, depending
    on the config), predicts shape descriptors and correspondence particles for
    every sample in the loader, and writes them under
    ``<out_dir>/<model_name>/<loader>_predictions/``.

    Parameters
    ----------
    config_file : str
        Path to the JSON configuration file (paths, model_name, use_best_model,
        fine_tune and tl_net options).
    loader : str
        Basename of the torch-serialized data loader to evaluate (default "test").

    Returns
    -------
    list[str] or None
        Paths of the predicted world-space ``.particles`` files, or ``None``
        if the run was aborted via ``sw_check_abort()``.
    """
    with open(config_file) as json_file:
        parameters = json.load(json_file)
    model_dir = parameters["paths"]["out_dir"] + parameters["model_name"] + '/'
    pred_dir = model_dir + loader + '_predictions/'
    loaders.make_dir(pred_dir)
    if parameters["use_best_model"]:
        model_path = model_dir + 'best_model.torch'
    else:
        model_path = model_dir + 'final_model.torch'
    # When fine-tuning was run, the fine-tuned weights are saved alongside the
    # base weights with an "_ft" suffix; otherwise reuse the base weights.
    if parameters["fine_tune"]["enabled"]:
        model_path_ft = model_path.replace(".torch", "_ft.torch")
    else:
        model_path_ft = model_path
    loader_dir = parameters["paths"]["loader_dir"]

    # load the serialized data loader
    sw_message("Loading " + loader + " data loader...")
    test_loader = torch.load(loader_dir + loader)

    # initialization: build the network(s) and load the trained weights
    sw_message("Loading trained model...")
    if parameters['tl_net']['enabled']:
        model_tl = model.DeepSSMNet_TLNet(config_file)
        model_tl.load_state_dict(torch.load(model_path))
        device = model_tl.device
        model_tl.to(device)
        model_tl.eval()
    else:
        model_pca = model.DeepSSMNet(config_file)
        model_pca.load_state_dict(torch.load(model_path))
        device = model_pca.device
        model_pca.to(device)
        model_pca.eval()
        model_ft = model.DeepSSMNet(config_file)
        model_ft.load_state_dict(torch.load(model_path_ft))
        model_ft.to(device)
        model_ft.eval()

    # Get test names. The file stores a stringified Python list
    # (e.g. "['a', 'b']"), so strip brackets/quotes/spaces and split on commas.
    test_names_file = loader_dir + loader + '_names.txt'
    with open(test_names_file, 'r') as f:
        test_names_string = f.read()
    test_names_string = test_names_string.replace("[", "").replace("]", "").replace("'", "").replace(" ", "")
    test_names = test_names_string.split(",")
    sw_message(f"Predicting for {loader} images...")
    index = 0
    # NOTE(review): pred_scores is accumulated but never returned or saved in
    # this function — kept for parity with the original code.
    pred_scores = []

    pred_path = pred_dir + 'world_predictions/'
    loaders.make_dir(pred_path)
    pred_path_pca = pred_dir + 'pca_predictions/'
    loaders.make_dir(pred_path_pca)

    predicted_particle_files = []
    for img, _, mdl, _ in test_loader:
        if sw_check_abort():
            sw_message("Aborted")
            return
        sw_message(f"Predicting {index + 1}/{len(test_loader)}")
        sw_progress((index + 1) / len(test_loader))
        img = img.to(device)
        particle_filename = pred_path + test_names[index] + '.particles'
        if parameters['tl_net']['enabled']:
            mdl = torch.FloatTensor([1]).to(device)
            [pred_tf, pred_mdl_tl] = model_tl(mdl, img)
            pred_scores.append(pred_tf.cpu().data.numpy())
            # save the AE latent space as shape descriptors
            filename = pred_path + test_names[index] + '.npy'
            np.save(filename, pred_tf.squeeze().detach().cpu().numpy())
            np.savetxt(particle_filename, pred_mdl_tl.squeeze().detach().cpu().numpy())
        else:
            [pred, pred_mdl_pca] = model_pca(img)
            [pred, pred_mdl_ft] = model_ft(img)
            pred_scores.append(pred.cpu().data.numpy()[0])
            # FIX: pred_path_pca already ends with '/'; the extra '/' produced
            # a double-slash path ('pca_predictions//predicted_pca_...').
            filename = pred_path_pca + 'predicted_pca_' + test_names[index] + '.particles'
            np.savetxt(filename, pred_mdl_pca.squeeze().detach().cpu().numpy())
            np.savetxt(particle_filename, pred_mdl_ft.squeeze().detach().cpu().numpy())
        print("Predicted particle file: ", particle_filename)
        # FIX: return the world-space .particles files. Previously `filename`
        # was appended, which is the PCA prediction in the standard branch and
        # the .npy latent file in the TL branch — neither is a world-space
        # particle file, and the pre-refactor code returned the FT particles.
        predicted_particle_files.append(particle_filename)
        index += 1
    sw_message("Test completed.")
    return predicted_particle_files
0 commit comments