Skip to content

Commit 077c588

Browse files
authored
DeepSSM input verification (#2196)
* Throw exception when number of original domains doesn't match. * When images or original files are not present, disallow deepssm mode * Throw useful exception when image size will be too small. * Add timing to steps of prep. * Add one more timer * Set Eigen threads to the same as TBB. This came about due to a bug or unexpected interaction with calling DeepSSM commands from Studio. After the image registration stuff runs, the OpenMP thread count is set to a crazy high level (e.g. 384 on a 16-core machine), which results in poor eigen performance per this page: https://eigen.tuxfamily.org/dox/TopicMultiThreading.html "Warning: On most OS it is very important to limit the number of threads to the number of physical cores, otherwise significant slowdowns are expected, especially for operations involving dense matrices." * Add Fine Tuning plot, combine csv * Fix logger being closed for FT * Allow decay LR to be off * Fix formatting * Reinitialze optimizer for fine tuning, use fine tuning learning rate * Parse table doubles and limit precision for display. * Fix table display (digits) * Update parameter names and tooltips * Reorganize DeepSSM Prep dialog a bit. Change percent variability to be consistent Add read-only training percent to show the user the amount that will be used. * Update screenshot. * Improve test mesh loading * Fix TL-DeepSSM when Decay Learning is off. * Replace shapeworks executable usage with Python API of MeshWarper. * Fix typos * Simplify and improve get_mesh_distance, also write back distance field to prediction mesh. * Add more room for image spacing. * Add ability to set reconstructed meshes directly. * Shift test mesh distance calc and results into Python from Studio. * Fix compile * Fix image name for train/test when image is not already selected. * Fix problem with boost create_directories
1 parent acb2f40 commit 077c588

File tree

20 files changed

+1337
-1013
lines changed

20 files changed

+1337
-1013
lines changed

Examples/Python/deep_ssm.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ def Run_Pipeline(args):
345345
"model_name": model_name,
346346
"num_latent_dim": int(embedded_dim),
347347
"paths": {
348-
"out_dir": output_directory,
348+
"out_dir": deepssm_dir,
349349
"loader_dir": loader_dir,
350350
"aug_dir": aug_dir
351351
},
@@ -446,7 +446,7 @@ def Run_Pipeline(args):
446446
predicted_val_local_particles = []
447447
for particle_file, transform in zip(predicted_val_world_particles, val_transforms):
448448
particles = np.loadtxt(particle_file)
449-
local_particle_file = particle_file.replace("FT_Predictions/", "local_predictions/")
449+
local_particle_file = particle_file.replace("world_predictions/", "local_predictions/")
450450
local_particles = sw.utils.transformParticles(particles, transform, inverse=True)
451451
np.savetxt(local_particle_file, local_particles)
452452
predicted_val_local_particles.append(local_particle_file)
@@ -468,8 +468,6 @@ def Run_Pipeline(args):
468468

469469
print("Validation mean mesh surface-to-surface distance: " + str(mean_dist))
470470

471-
# If tiny test or verify, check results and exit
472-
check_results(args, mean_dist)
473471
open(status_dir + "step_11.txt", 'w').close()
474472

475473
######################################################################################
@@ -512,7 +510,7 @@ def Run_Pipeline(args):
512510
predicted_test_local_particles = []
513511
for particle_file, transform in zip(predicted_test_world_particles, test_transforms):
514512
particles = np.loadtxt(particle_file)
515-
local_particle_file = particle_file.replace("FT_Predictions/", "local_predictions/")
513+
local_particle_file = particle_file.replace("world_predictions/", "local_predictions/")
516514
local_particles = sw.utils.transformParticles(particles, transform, inverse=True)
517515
np.savetxt(local_particle_file, local_particles)
518516
predicted_test_local_particles.append(local_particle_file)
@@ -530,6 +528,9 @@ def Run_Pipeline(args):
530528
template_particles, template_mesh, test_out_dir,
531529
planes=test_planes)
532530
print("Test mean mesh surface-to-surface distance: " + str(mean_dist))
531+
532+
# If tiny test or verify, check results and exit
533+
check_results(args, mean_dist)
533534
open(status_dir + "step_12.txt", 'w').close()
534535

535536
print("All steps complete")

Libs/Analyze/Shape.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,9 @@ MeshGroup Shape::get_reconstructed_meshes(bool wait) {
134134
return reconstructed_meshes_;
135135
}
136136

137+
//---------------------------------------------------------------------------
138+
void Shape::set_reconstructed_meshes(MeshGroup meshes) { reconstructed_meshes_ = meshes; }
139+
137140
//---------------------------------------------------------------------------
138141
void Shape::reset_groomed_mesh() { groomed_meshes_ = MeshGroup(subject_->get_number_of_domains()); }
139142

@@ -572,7 +575,7 @@ std::shared_ptr<Image> Shape::get_image_volume(std::string image_volume_name) {
572575
std::shared_ptr<Image> image = std::make_shared<Image>(filename);
573576
image_volume_ = image;
574577
image_volume_filename_ = filename;
575-
} catch (std::exception &ex) {
578+
} catch (std::exception& ex) {
576579
SW_ERROR("Unable to open file: {}", filename);
577580
}
578581
}

Libs/Analyze/Shape.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ class Shape {
7474
//! Retrieve the reconstructed meshes
7575
MeshGroup get_reconstructed_meshes(bool wait = false);
7676

77+
//! Set the reconstructed meshes
78+
void set_reconstructed_meshes(MeshGroup meshes);
79+
7780
//! Reset the groomed mesh so that it will be re-created
7881
void reset_groomed_mesh();
7982

Libs/Common/ShapeworksUtils.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ void ShapeWorksUtils::setup_threads() {
8585
num_threads = std::max(1, atoi(num_threads_env));
8686
}
8787
SW_DEBUG("TBB using {} threads", num_threads);
88+
Eigen::setNbThreads(num_threads);
8889
tbb::global_control c(tbb::global_control::max_allowed_parallelism, num_threads);
8990
}
9091

Libs/Groom/Groom.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,10 @@ int Groom::get_total_ops() {
453453
for (int i = 0; i < domains.size(); i++) {
454454
auto params = GroomParameters(project_, domains[i]);
455455

456+
if (project_->get_original_domain_types().size() <= i) {
457+
throw std::runtime_error("invalid domain, number of original file types does not match number of domains");
458+
}
459+
456460
if (project_->get_original_domain_types()[i] == DomainType::Image) {
457461
num_tools += params.get_isolate_tool() ? 1 : 0;
458462
num_tools += params.get_fill_holes_tool() ? 1 : 0;
@@ -526,7 +530,6 @@ bool Groom::run_alignment() {
526530
std::vector<Mesh> reference_meshes;
527531
std::vector<Mesh> meshes;
528532
for (size_t i = 0; i < subjects.size(); i++) {
529-
530533
if (!subjects[i]->is_excluded()) {
531534
Mesh mesh = get_mesh(i, domain, true);
532535
// if fixed subjects are present, only add the fixed subjects
@@ -711,7 +714,9 @@ std::string Groom::get_output_filename(std::string input, DomainType domain_type
711714
path = base + "/" + prefix;
712715

713716
try {
714-
boost::filesystem::create_directories(path);
717+
if (!boost::filesystem::exists(path)) {
718+
boost::filesystem::create_directories(path);
719+
}
715720
} catch (std::exception& e) {
716721
throw std::runtime_error("Unable to create groom output directory: \"" + path + "\"");
717722
}

Libs/Image/Image.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,9 @@ Image& Image::write(const std::string& filename, bool compressed) {
323323

324324
// if the directory doesn't exist, create it
325325
boost::filesystem::path dir(filename);
326-
boost::filesystem::create_directories(dir.parent_path());
326+
if (dir.has_parent_path() && !boost::filesystem::exists(dir.parent_path())) {
327+
boost::filesystem::create_directories(dir.parent_path());
328+
}
327329

328330
using WriterType = itk::ImageFileWriter<ImageType>;
329331
WriterType::Pointer writer = WriterType::New();

Python/DeepSSMUtilsPackage/DeepSSMUtils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from .run_utils import create_split, groom_training_shapes, groom_training_images, get_reference_index, \
1111
run_data_augmentation, groom_val_test_images, prep_project_for_val_particles, groom_validation_shapes, \
12-
prepare_data_loaders, get_deepssm_dir, get_split_indices, optimize_training_particles
12+
prepare_data_loaders, get_deepssm_dir, get_split_indices, optimize_training_particles, process_test_predictions
1313

1414
from .config_file import prepare_config_file
1515

Python/DeepSSMUtilsPackage/DeepSSMUtils/eval.py

Lines changed: 83 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -13,92 +13,91 @@
1313
'''
1414
Network Test Function
1515
predicts the PCA scores using the trained networks
16-
returns the error measures and saves the predicted and poriginal particles for comparison
16+
returns the error measures and saves the predicted and original particles for comparison
1717
'''
18+
19+
1820
def test(config_file, loader="test"):
19-
with open(config_file) as json_file:
20-
parameters = json.load(json_file)
21-
model_dir = parameters["paths"]["out_dir"] + parameters["model_name"]+ '/'
22-
pred_dir = model_dir + loader + '_predictions/'
23-
loaders.make_dir(pred_dir)
24-
if parameters["use_best_model"]:
25-
model_path = model_dir + 'best_model.torch'
26-
else:
27-
model_path = model_dir + 'final_model.torch'
28-
if parameters["fine_tune"]["enabled"]:
29-
model_path_ft = model_path.replace(".torch", "_ft.torch")
30-
else:
31-
model_path_ft = model_path
32-
loader_dir = parameters["paths"]["loader_dir"]
21+
with open(config_file) as json_file:
22+
parameters = json.load(json_file)
23+
model_dir = parameters["paths"]["out_dir"] + parameters["model_name"] + '/'
24+
pred_dir = model_dir + loader + '_predictions/'
25+
loaders.make_dir(pred_dir)
26+
if parameters["use_best_model"]:
27+
model_path = model_dir + 'best_model.torch'
28+
else:
29+
model_path = model_dir + 'final_model.torch'
30+
if parameters["fine_tune"]["enabled"]:
31+
model_path_ft = model_path.replace(".torch", "_ft.torch")
32+
else:
33+
model_path_ft = model_path
34+
loader_dir = parameters["paths"]["loader_dir"]
35+
36+
# load the loaders
37+
sw_message("Loading " + loader + " data loader...")
38+
test_loader = torch.load(loader_dir + loader)
39+
40+
# initialization
41+
sw_message("Loading trained model...")
42+
if parameters['tl_net']['enabled']:
43+
model_tl = model.DeepSSMNet_TLNet(config_file)
44+
model_tl.load_state_dict(torch.load(model_path))
45+
device = model_tl.device
46+
model_tl.to(device)
47+
model_tl.eval()
48+
else:
49+
model_pca = model.DeepSSMNet(config_file)
50+
model_pca.load_state_dict(torch.load(model_path))
51+
device = model_pca.device
52+
model_pca.to(device)
53+
model_pca.eval()
54+
model_ft = model.DeepSSMNet(config_file)
55+
model_ft.load_state_dict(torch.load(model_path_ft))
56+
model_ft.to(device)
57+
model_ft.eval()
3358

34-
# load the loaders
35-
sw_message("Loading "+ loader + " data loader...")
36-
test_loader = torch.load(loader_dir + loader)
37-
print("Done.\n")
38-
# initalizations
39-
sw_message("Loading trained model...")
40-
if parameters['tl_net']['enabled']:
41-
model_tl = model.DeepSSMNet_TLNet(config_file)
42-
model_tl.load_state_dict(torch.load(model_path))
43-
device = model_tl.device
44-
model_tl.to(device)
45-
model_tl.eval()
46-
else:
47-
model_pca = model.DeepSSMNet(config_file)
48-
model_pca.load_state_dict(torch.load(model_path))
49-
device = model_pca.device
50-
model_pca.to(device)
51-
model_pca.eval()
52-
model_ft = model.DeepSSMNet(config_file)
53-
model_ft.load_state_dict(torch.load(model_path_ft))
54-
model_ft.to(device)
55-
model_ft.eval()
59+
# Get test names
60+
test_names_file = loader_dir + loader + '_names.txt'
61+
f = open(test_names_file, 'r')
62+
test_names_string = f.read()
63+
f.close()
64+
test_names_string = test_names_string.replace("[", "").replace("]", "").replace("'", "").replace(" ", "")
65+
test_names = test_names_string.split(",")
66+
sw_message(f"Predicting for {loader} images...")
67+
index = 0
68+
pred_scores = []
5669

57-
# Get test names
58-
test_names_file = loader_dir + loader + '_names.txt'
59-
f = open(test_names_file, 'r')
60-
test_names_string = f.read()
61-
f.close()
62-
test_names_string = test_names_string.replace("[","").replace("]","").replace("'","").replace(" ","")
63-
test_names = test_names_string.split(",")
64-
sw_message(f"Predicting for {loader} images...")
65-
index = 0
66-
pred_scores = []
70+
pred_path = pred_dir + 'world_predictions/'
71+
loaders.make_dir(pred_path)
72+
pred_path_pca = pred_dir + 'pca_predictions/'
73+
loaders.make_dir(pred_path_pca)
6774

68-
if parameters['tl_net']['enabled']:
69-
predPath_tl = pred_dir + '/TL_Predictions'
70-
loaders.make_dir(predPath_tl)
71-
else:
72-
predPath_ft = pred_dir + 'FT_Predictions/'
73-
predPath_pca = pred_dir + 'PCA_Predictions/'
74-
loaders.make_dir(predPath_ft)
75-
loaders.make_dir(predPath_pca)
76-
predicted_particle_files = []
77-
for img, _, mdl, _ in test_loader:
78-
if sw_check_abort():
79-
sw_message("Aborted")
80-
return
81-
sw_message(f"Predicting {index+1}/{len(test_loader)}")
82-
sw_progress((index+1) / len(test_loader))
83-
img = img.to(device)
84-
if parameters['tl_net']['enabled']:
85-
mdl = torch.FloatTensor([1]).to(device)
86-
[pred_tf, pred_mdl_tl] = model_tl(mdl, img)
87-
pred_scores.append(pred_tf.cpu().data.numpy())
88-
# save the AE latent space as shape descriptors
89-
nmpred = predPath_tl + '/' + test_names[index] + '.npy'
90-
np.save(nmpred, pred_tf.squeeze().detach().cpu().numpy())
91-
nmpred = predPath_tl + '/' + test_names[index] + '.particles'
92-
np.savetxt(nmpred, pred_mdl_tl.squeeze().detach().cpu().numpy())
93-
else:
94-
[pred, pred_mdl_pca] = model_pca(img)
95-
[pred, pred_mdl_ft] = model_ft(img)
96-
pred_scores.append(pred.cpu().data.numpy()[0])
97-
nmpred = predPath_pca + '/predicted_pca_' + test_names[index] + '.particles'
98-
np.savetxt(nmpred, pred_mdl_pca.squeeze().detach().cpu().numpy())
99-
nmpred = predPath_ft + '/predicted_ft_' + test_names[index] + '.particles'
100-
np.savetxt(nmpred, pred_mdl_ft.squeeze().detach().cpu().numpy())
101-
predicted_particle_files.append(nmpred)
102-
index += 1
103-
sw_message("Test completed.")
104-
return predicted_particle_files
75+
predicted_particle_files = []
76+
for img, _, mdl, _ in test_loader:
77+
if sw_check_abort():
78+
sw_message("Aborted")
79+
return
80+
sw_message(f"Predicting {index + 1}/{len(test_loader)}")
81+
sw_progress((index + 1) / len(test_loader))
82+
img = img.to(device)
83+
particle_filename = pred_path + test_names[index] + '.particles'
84+
if parameters['tl_net']['enabled']:
85+
mdl = torch.FloatTensor([1]).to(device)
86+
[pred_tf, pred_mdl_tl] = model_tl(mdl, img)
87+
pred_scores.append(pred_tf.cpu().data.numpy())
88+
# save the AE latent space as shape descriptors
89+
filename = pred_path + test_names[index] + '.npy'
90+
np.save(filename, pred_tf.squeeze().detach().cpu().numpy())
91+
np.savetxt(particle_filename, pred_mdl_tl.squeeze().detach().cpu().numpy())
92+
else:
93+
[pred, pred_mdl_pca] = model_pca(img)
94+
[pred, pred_mdl_ft] = model_ft(img)
95+
pred_scores.append(pred.cpu().data.numpy()[0])
96+
filename = pred_path_pca + '/predicted_pca_' + test_names[index] + '.particles'
97+
np.savetxt(filename, pred_mdl_pca.squeeze().detach().cpu().numpy())
98+
np.savetxt(particle_filename, pred_mdl_ft.squeeze().detach().cpu().numpy())
99+
print("Predicted particle file: ", particle_filename)
100+
predicted_particle_files.append(filename)
101+
index += 1
102+
sw_message("Test completed.")
103+
return predicted_particle_files

0 commit comments

Comments
 (0)