@@ -22,10 +22,8 @@ using namespace pybind11::literals; // to bring in the `_a` literal
2222namespace shapeworks {
2323
2424// ---------------------------------------------------------------------------
25- DeepSSMJob::DeepSSMJob (std::shared_ptr<Project> project, DeepSSMTool::ToolMode tool_mode,
26- DeepSSMTool::PrepStep prep_step)
27- : project_(project), tool_mode_(tool_mode), prep_step_(prep_step) {
28- }
25+ DeepSSMJob::DeepSSMJob (std::shared_ptr<Project> project, DeepSSMJob::ToolMode tool_mode, DeepSSMJob::PrepStep prep_step)
26+ : project_(project), tool_mode_(tool_mode), prep_step_(prep_step) {}
2927
3028// ---------------------------------------------------------------------------
3129DeepSSMJob::~DeepSSMJob () {}
@@ -34,16 +32,16 @@ DeepSSMJob::~DeepSSMJob() {}
3432void DeepSSMJob::run () {
3533 try {
3634 switch (tool_mode_) {
37- case DeepSSMTool ::ToolMode::DeepSSM_PrepType:
35+ case DeepSSMJob ::ToolMode::DeepSSM_PrepType:
3836 run_prep ();
3937 break ;
40- case DeepSSMTool ::ToolMode::DeepSSM_AugmentationType:
38+ case DeepSSMJob ::ToolMode::DeepSSM_AugmentationType:
4139 run_augmentation ();
4240 break ;
43- case DeepSSMTool ::ToolMode::DeepSSM_TrainingType:
41+ case DeepSSMJob ::ToolMode::DeepSSM_TrainingType:
4442 run_training ();
4543 break ;
46- case DeepSSMTool ::ToolMode::DeepSSM_TestingType:
44+ case DeepSSMJob ::ToolMode::DeepSSM_TestingType:
4745 run_testing ();
4846 break ;
4947 }
@@ -55,16 +53,16 @@ void DeepSSMJob::run() {
5553// ---------------------------------------------------------------------------
5654QString DeepSSMJob::name () {
5755 switch (tool_mode_) {
58- case DeepSSMTool ::ToolMode::DeepSSM_PrepType:
56+ case DeepSSMJob ::ToolMode::DeepSSM_PrepType:
5957 return " DeepSSM: Prep" ;
6058 break ;
61- case DeepSSMTool ::ToolMode::DeepSSM_AugmentationType:
59+ case DeepSSMJob ::ToolMode::DeepSSM_AugmentationType:
6260 return " DeepSSM: Augmentation" ;
6361 break ;
64- case DeepSSMTool ::ToolMode::DeepSSM_TrainingType:
62+ case DeepSSMJob ::ToolMode::DeepSSM_TrainingType:
6563 return " DeepSSM: Training" ;
6664 break ;
67- case DeepSSMTool ::ToolMode::DeepSSM_TestingType:
65+ case DeepSSMJob ::ToolMode::DeepSSM_TestingType:
6866 return " DeepSSM: Testing" ;
6967 break ;
7068 }
@@ -86,7 +84,7 @@ void DeepSSMJob::run_prep() {
8684 params.set_training_step_complete (false );
8785 params.save_to_project ();
8886
89- if (prep_step_ == DeepSSMTool ::PrepStep::NOT_STARTED || prep_step_ == DeepSSMTool ::PrepStep::GROOM_TRAINING) {
87+ if (prep_step_ == DeepSSMJob ::PrepStep::NOT_STARTED || prep_step_ == DeepSSMJob ::PrepStep::GROOM_TRAINING) {
9088 SW_LOG (" Creating Split..." );
9189 // ///////////////////////////////////////////////////////
9290 // / Step 1. Create Split
@@ -97,9 +95,9 @@ void DeepSSMJob::run_prep() {
9795 py::object create_split = py_deep_ssm_utils.attr (" create_split" );
9896 create_split (project_, train_split, val_split, test_split);
9997
100- int num_train = DeepSSMTool:: get_split (project_, DeepSSMTool:: SplitType::TRAIN).size ();
101- int num_val = DeepSSMTool:: get_split (project_, DeepSSMTool:: SplitType::VAL).size ();
102- int num_test = DeepSSMTool:: get_split (project_, DeepSSMTool:: SplitType::TEST).size ();
98+ int num_train = get_split (project_, SplitType::TRAIN).size ();
99+ int num_val = get_split (project_, SplitType::VAL).size ();
100+ int num_test = get_split (project_, SplitType::TEST).size ();
103101 if (num_train == 0 || num_val == 0 ) {
104102 SW_ERROR (" DeepSSM: Not enough subjects in training and validation. Please check split." );
105103 abort ();
@@ -112,7 +110,7 @@ void DeepSSMJob::run_prep() {
112110 // ///////////////////////////////////////////////////////
113111 // / Step 2. Groom Training Shapes
114112 // ///////////////////////////////////////////////////////
115- update_prep_stage (DeepSSMTool ::PrepStep::GROOM_TRAINING);
113+ update_prep_stage (DeepSSMJob ::PrepStep::GROOM_TRAINING);
116114 py::object groom_training_shapes = py_deep_ssm_utils.attr (" groom_training_shapes" );
117115
118116 QElapsedTimer timer;
@@ -142,11 +140,11 @@ void DeepSSMJob::run_prep() {
142140 }
143141 }
144142
145- if (prep_step_ == DeepSSMTool ::PrepStep::NOT_STARTED || prep_step_ == DeepSSMTool ::PrepStep::OPTIMIZE_TRAINING) {
143+ if (prep_step_ == DeepSSMJob ::PrepStep::NOT_STARTED || prep_step_ == DeepSSMJob ::PrepStep::OPTIMIZE_TRAINING) {
146144 // ///////////////////////////////////////////////////////
147145 // / Step 3. Optimize Training Particles
148146 // ///////////////////////////////////////////////////////
149- update_prep_stage (DeepSSMTool ::PrepStep::OPTIMIZE_TRAINING);
147+ update_prep_stage (DeepSSMJob ::PrepStep::OPTIMIZE_TRAINING);
150148 QElapsedTimer timer;
151149 timer.start ();
152150 py::object optimize_training_particles = py_deep_ssm_utils.attr (" optimize_training_particles" );
@@ -160,11 +158,11 @@ void DeepSSMJob::run_prep() {
160158 }
161159 }
162160
163- if (prep_step_ == DeepSSMTool ::PrepStep::NOT_STARTED || prep_step_ == DeepSSMTool ::PrepStep::OPTIMIZE_VALIDATION) {
161+ if (prep_step_ == DeepSSMJob ::PrepStep::NOT_STARTED || prep_step_ == DeepSSMJob ::PrepStep::OPTIMIZE_VALIDATION) {
164162 // ///////////////////////////////////////////////////////
165163 // / Step 6. Optimize Validation Particles with Fixed Domains
166164 // ///////////////////////////////////////////////////////
167- update_prep_stage (DeepSSMTool ::PrepStep::OPTIMIZE_VALIDATION);
165+ update_prep_stage (DeepSSMJob ::PrepStep::OPTIMIZE_VALIDATION);
168166 py::object prep_project_for_val_particles = py_deep_ssm_utils.attr (" prep_project_for_val_particles" );
169167 prep_project_for_val_particles (project_);
170168
@@ -188,12 +186,12 @@ void DeepSSMJob::run_prep() {
188186 SW_LOG (" DeepSSM: Optimize Validation Particles complete. Duration: {} seconds" , duration.toStdString ());
189187 }
190188
191- if (prep_step_ == DeepSSMTool ::PrepStep::NOT_STARTED || prep_step_ == DeepSSMTool ::PrepStep::GROOM_IMAGES) {
189+ if (prep_step_ == DeepSSMJob ::PrepStep::NOT_STARTED || prep_step_ == DeepSSMJob ::PrepStep::GROOM_IMAGES) {
192190 // ///////////////////////////////////////////////////////
193191 // / Step 4. Groom Training Images
194192 // ///////////////////////////////////////////////////////
195193
196- update_prep_stage (DeepSSMTool ::PrepStep::GROOM_IMAGES);
194+ update_prep_stage (DeepSSMJob ::PrepStep::GROOM_IMAGES);
197195 QElapsedTimer timer;
198196 timer.start ();
199197 py::object groom_training_images = py_deep_ssm_utils.attr (" groom_training_images" );
@@ -210,7 +208,7 @@ void DeepSSMJob::run_prep() {
210208 // ///////////////////////////////////////////////////////
211209 timer.start ();
212210 py::object groom_val_test_images = py_deep_ssm_utils.attr (" groom_val_test_images" );
213- groom_val_test_images (project_, DeepSSMTool:: get_split (project_, DeepSSMTool:: SplitType::VAL));
211+ groom_val_test_images (project_, get_split (project_, SplitType::VAL));
214212 project_->save ();
215213 duration = QString::number (timer.elapsed () / 1000.0 , ' f' , 1 );
216214 SW_LOG (" DeepSSM: Groom Validation Images complete. Duration: {} seconds" , duration.toStdString ());
@@ -221,7 +219,7 @@ void DeepSSMJob::run_prep() {
221219 }
222220
223221 // ///////////////////////////////////////////////////////
224- update_prep_stage (DeepSSMTool ::PrepStep::DONE);
222+ update_prep_stage (DeepSSMJob ::PrepStep::DONE);
225223 params.set_prep_step_complete (true );
226224 params.set_aug_step_complete (false );
227225 params.set_training_step_complete (false );
@@ -317,7 +315,7 @@ void DeepSSMJob::run_testing() {
317315
318316 py::module py_deep_ssm_utils = py::module::import (" DeepSSMUtils" );
319317
320- std::vector<int > test_indices = DeepSSMTool:: get_split (project_, DeepSSMTool:: SplitType::TEST);
318+ std::vector<int > test_indices = get_split (project_, SplitType::TEST);
321319
322320 // Groom Test Images
323321 SW_MESSAGE (" Grooming Test Images" );
@@ -360,7 +358,37 @@ void DeepSSMJob::run_testing() {
360358void DeepSSMJob::python_message (std::string str) { SW_LOG (str); }
361359
362360// ---------------------------------------------------------------------------
363- void DeepSSMJob::update_prep_stage (DeepSSMTool::PrepStep step) {
361+ std::vector<int > DeepSSMJob::get_split (ProjectHandle project, SplitType split_type) {
362+ auto subjects = project->get_subjects ();
363+
364+ std::vector<int > list;
365+
366+ for (int id = 0 ; id < subjects.size (); id++) {
367+ auto extra_values = subjects[id]->get_extra_values ();
368+
369+ std::string split = extra_values[" split" ];
370+
371+ if (split_type == DeepSSMJob::SplitType::TRAIN) {
372+ if (split != " train" ) {
373+ continue ;
374+ }
375+ } else if (split_type == DeepSSMJob::SplitType::VAL) {
376+ if (split != " val" ) {
377+ continue ;
378+ }
379+ } else if (split_type == DeepSSMJob::SplitType::TEST) {
380+ if (split != " test" ) {
381+ continue ;
382+ }
383+ }
384+
385+ list.push_back (id);
386+ }
387+ return list;
388+ }
389+
390+ // ---------------------------------------------------------------------------
391+ void DeepSSMJob::update_prep_stage (PrepStep step) {
364392 /*
365393 std::lock_guard<std::mutex> lock(mutex_);
366394
0 commit comments