Skip to content

Commit 0066dd9

Browse files
committed
Isolate DeepSSMJob from Studio
1 parent 3e1087e commit 0066dd9

File tree

4 files changed

+126
-128
lines changed

4 files changed

+126
-128
lines changed

Studio/DeepSSM/DeepSSMJob.cpp

Lines changed: 55 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,8 @@ using namespace pybind11::literals; // to bring in the `_a` literal
2222
namespace shapeworks {
2323

2424
//---------------------------------------------------------------------------
25-
DeepSSMJob::DeepSSMJob(std::shared_ptr<Project> project, DeepSSMTool::ToolMode tool_mode,
26-
DeepSSMTool::PrepStep prep_step)
27-
: project_(project), tool_mode_(tool_mode), prep_step_(prep_step) {
28-
}
25+
DeepSSMJob::DeepSSMJob(std::shared_ptr<Project> project, DeepSSMJob::ToolMode tool_mode, DeepSSMJob::PrepStep prep_step)
26+
: project_(project), tool_mode_(tool_mode), prep_step_(prep_step) {}
2927

3028
//---------------------------------------------------------------------------
3129
DeepSSMJob::~DeepSSMJob() {}
@@ -34,16 +32,16 @@ DeepSSMJob::~DeepSSMJob() {}
3432
void DeepSSMJob::run() {
3533
try {
3634
switch (tool_mode_) {
37-
case DeepSSMTool::ToolMode::DeepSSM_PrepType:
35+
case DeepSSMJob::ToolMode::DeepSSM_PrepType:
3836
run_prep();
3937
break;
40-
case DeepSSMTool::ToolMode::DeepSSM_AugmentationType:
38+
case DeepSSMJob::ToolMode::DeepSSM_AugmentationType:
4139
run_augmentation();
4240
break;
43-
case DeepSSMTool::ToolMode::DeepSSM_TrainingType:
41+
case DeepSSMJob::ToolMode::DeepSSM_TrainingType:
4442
run_training();
4543
break;
46-
case DeepSSMTool::ToolMode::DeepSSM_TestingType:
44+
case DeepSSMJob::ToolMode::DeepSSM_TestingType:
4745
run_testing();
4846
break;
4947
}
@@ -55,16 +53,16 @@ void DeepSSMJob::run() {
5553
//---------------------------------------------------------------------------
5654
QString DeepSSMJob::name() {
5755
switch (tool_mode_) {
58-
case DeepSSMTool::ToolMode::DeepSSM_PrepType:
56+
case DeepSSMJob::ToolMode::DeepSSM_PrepType:
5957
return "DeepSSM: Prep";
6058
break;
61-
case DeepSSMTool::ToolMode::DeepSSM_AugmentationType:
59+
case DeepSSMJob::ToolMode::DeepSSM_AugmentationType:
6260
return "DeepSSM: Augmentation";
6361
break;
64-
case DeepSSMTool::ToolMode::DeepSSM_TrainingType:
62+
case DeepSSMJob::ToolMode::DeepSSM_TrainingType:
6563
return "DeepSSM: Training";
6664
break;
67-
case DeepSSMTool::ToolMode::DeepSSM_TestingType:
65+
case DeepSSMJob::ToolMode::DeepSSM_TestingType:
6866
return "DeepSSM: Testing";
6967
break;
7068
}
@@ -86,7 +84,7 @@ void DeepSSMJob::run_prep() {
8684
params.set_training_step_complete(false);
8785
params.save_to_project();
8886

89-
if (prep_step_ == DeepSSMTool::PrepStep::NOT_STARTED || prep_step_ == DeepSSMTool::PrepStep::GROOM_TRAINING) {
87+
if (prep_step_ == DeepSSMJob::PrepStep::NOT_STARTED || prep_step_ == DeepSSMJob::PrepStep::GROOM_TRAINING) {
9088
SW_LOG("Creating Split...");
9189
/////////////////////////////////////////////////////////
9290
/// Step 1. Create Split
@@ -97,9 +95,9 @@ void DeepSSMJob::run_prep() {
9795
py::object create_split = py_deep_ssm_utils.attr("create_split");
9896
create_split(project_, train_split, val_split, test_split);
9997

100-
int num_train = DeepSSMTool::get_split(project_, DeepSSMTool::SplitType::TRAIN).size();
101-
int num_val = DeepSSMTool::get_split(project_, DeepSSMTool::SplitType::VAL).size();
102-
int num_test = DeepSSMTool::get_split(project_, DeepSSMTool::SplitType::TEST).size();
98+
int num_train = get_split(project_, SplitType::TRAIN).size();
99+
int num_val = get_split(project_, SplitType::VAL).size();
100+
int num_test = get_split(project_, SplitType::TEST).size();
103101
if (num_train == 0 || num_val == 0) {
104102
SW_ERROR("DeepSSM: Not enough subjects in training and validation. Please check split.");
105103
abort();
@@ -112,7 +110,7 @@ void DeepSSMJob::run_prep() {
112110
/////////////////////////////////////////////////////////
113111
/// Step 2. Groom Training Shapes
114112
/////////////////////////////////////////////////////////
115-
update_prep_stage(DeepSSMTool::PrepStep::GROOM_TRAINING);
113+
update_prep_stage(DeepSSMJob::PrepStep::GROOM_TRAINING);
116114
py::object groom_training_shapes = py_deep_ssm_utils.attr("groom_training_shapes");
117115

118116
QElapsedTimer timer;
@@ -142,11 +140,11 @@ void DeepSSMJob::run_prep() {
142140
}
143141
}
144142

145-
if (prep_step_ == DeepSSMTool::PrepStep::NOT_STARTED || prep_step_ == DeepSSMTool::PrepStep::OPTIMIZE_TRAINING) {
143+
if (prep_step_ == DeepSSMJob::PrepStep::NOT_STARTED || prep_step_ == DeepSSMJob::PrepStep::OPTIMIZE_TRAINING) {
146144
/////////////////////////////////////////////////////////
147145
/// Step 3. Optimize Training Particles
148146
/////////////////////////////////////////////////////////
149-
update_prep_stage(DeepSSMTool::PrepStep::OPTIMIZE_TRAINING);
147+
update_prep_stage(DeepSSMJob::PrepStep::OPTIMIZE_TRAINING);
150148
QElapsedTimer timer;
151149
timer.start();
152150
py::object optimize_training_particles = py_deep_ssm_utils.attr("optimize_training_particles");
@@ -160,11 +158,11 @@ void DeepSSMJob::run_prep() {
160158
}
161159
}
162160

163-
if (prep_step_ == DeepSSMTool::PrepStep::NOT_STARTED || prep_step_ == DeepSSMTool::PrepStep::OPTIMIZE_VALIDATION) {
161+
if (prep_step_ == DeepSSMJob::PrepStep::NOT_STARTED || prep_step_ == DeepSSMJob::PrepStep::OPTIMIZE_VALIDATION) {
164162
/////////////////////////////////////////////////////////
165163
/// Step 6. Optimize Validation Particles with Fixed Domains
166164
/////////////////////////////////////////////////////////
167-
update_prep_stage(DeepSSMTool::PrepStep::OPTIMIZE_VALIDATION);
165+
update_prep_stage(DeepSSMJob::PrepStep::OPTIMIZE_VALIDATION);
168166
py::object prep_project_for_val_particles = py_deep_ssm_utils.attr("prep_project_for_val_particles");
169167
prep_project_for_val_particles(project_);
170168

@@ -188,12 +186,12 @@ void DeepSSMJob::run_prep() {
188186
SW_LOG("DeepSSM: Optimize Validation Particles complete. Duration: {} seconds", duration.toStdString());
189187
}
190188

191-
if (prep_step_ == DeepSSMTool::PrepStep::NOT_STARTED || prep_step_ == DeepSSMTool::PrepStep::GROOM_IMAGES) {
189+
if (prep_step_ == DeepSSMJob::PrepStep::NOT_STARTED || prep_step_ == DeepSSMJob::PrepStep::GROOM_IMAGES) {
192190
/////////////////////////////////////////////////////////
193191
/// Step 4. Groom Training Images
194192
/////////////////////////////////////////////////////////
195193

196-
update_prep_stage(DeepSSMTool::PrepStep::GROOM_IMAGES);
194+
update_prep_stage(DeepSSMJob::PrepStep::GROOM_IMAGES);
197195
QElapsedTimer timer;
198196
timer.start();
199197
py::object groom_training_images = py_deep_ssm_utils.attr("groom_training_images");
@@ -210,7 +208,7 @@ void DeepSSMJob::run_prep() {
210208
/////////////////////////////////////////////////////////
211209
timer.start();
212210
py::object groom_val_test_images = py_deep_ssm_utils.attr("groom_val_test_images");
213-
groom_val_test_images(project_, DeepSSMTool::get_split(project_, DeepSSMTool::SplitType::VAL));
211+
groom_val_test_images(project_, get_split(project_, SplitType::VAL));
214212
project_->save();
215213
duration = QString::number(timer.elapsed() / 1000.0, 'f', 1);
216214
SW_LOG("DeepSSM: Groom Validation Images complete. Duration: {} seconds", duration.toStdString());
@@ -221,7 +219,7 @@ void DeepSSMJob::run_prep() {
221219
}
222220

223221
/////////////////////////////////////////////////////////
224-
update_prep_stage(DeepSSMTool::PrepStep::DONE);
222+
update_prep_stage(DeepSSMJob::PrepStep::DONE);
225223
params.set_prep_step_complete(true);
226224
params.set_aug_step_complete(false);
227225
params.set_training_step_complete(false);
@@ -317,7 +315,7 @@ void DeepSSMJob::run_testing() {
317315

318316
py::module py_deep_ssm_utils = py::module::import("DeepSSMUtils");
319317

320-
std::vector<int> test_indices = DeepSSMTool::get_split(project_, DeepSSMTool::SplitType::TEST);
318+
std::vector<int> test_indices = get_split(project_, SplitType::TEST);
321319

322320
// Groom Test Images
323321
SW_MESSAGE("Grooming Test Images");
@@ -360,7 +358,37 @@ void DeepSSMJob::run_testing() {
360358
void DeepSSMJob::python_message(std::string str) { SW_LOG(str); }
361359

362360
//---------------------------------------------------------------------------
363-
void DeepSSMJob::update_prep_stage(DeepSSMTool::PrepStep step) {
361+
std::vector<int> DeepSSMJob::get_split(ProjectHandle project, SplitType split_type) {
362+
auto subjects = project->get_subjects();
363+
364+
std::vector<int> list;
365+
366+
for (int id = 0; id < subjects.size(); id++) {
367+
auto extra_values = subjects[id]->get_extra_values();
368+
369+
std::string split = extra_values["split"];
370+
371+
if (split_type == DeepSSMJob::SplitType::TRAIN) {
372+
if (split != "train") {
373+
continue;
374+
}
375+
} else if (split_type == DeepSSMJob::SplitType::VAL) {
376+
if (split != "val") {
377+
continue;
378+
}
379+
} else if (split_type == DeepSSMJob::SplitType::TEST) {
380+
if (split != "test") {
381+
continue;
382+
}
383+
}
384+
385+
list.push_back(id);
386+
}
387+
return list;
388+
}
389+
390+
//---------------------------------------------------------------------------
391+
void DeepSSMJob::update_prep_stage(PrepStep step) {
364392
/*
365393
std::lock_guard<std::mutex> lock(mutex_);
366394

Studio/DeepSSM/DeepSSMJob.h

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#pragma once
22

3-
#include <DeepSSM/DeepSSMTool.h>
43
#include <Job/Job.h>
54
#include <Project/Project.h>
65

@@ -17,8 +16,26 @@ class DeepSSMJob : public Job {
1716
Q_OBJECT;
1817

1918
public:
20-
DeepSSMJob(std::shared_ptr<Project> project, DeepSSMTool::ToolMode tool_mode,
21-
DeepSSMTool::PrepStep prep_step = DeepSSMTool::NOT_STARTED);
19+
enum class ToolMode {
20+
DeepSSM_PrepType = 0,
21+
DeepSSM_AugmentationType = 1,
22+
DeepSSM_TrainingType = 2,
23+
DeepSSM_TestingType = 3
24+
};
25+
26+
enum PrepStep {
27+
NOT_STARTED = 0,
28+
GROOM_TRAINING = 1,
29+
OPTIMIZE_TRAINING = 2,
30+
OPTIMIZE_VALIDATION = 3,
31+
GROOM_IMAGES = 4,
32+
DONE = 5
33+
};
34+
35+
enum class SplitType { TRAIN, VAL, TEST };
36+
37+
DeepSSMJob(std::shared_ptr<Project> project, DeepSSMJob::ToolMode tool_mode,
38+
DeepSSMJob::PrepStep prep_step = DeepSSMJob::NOT_STARTED);
2239
~DeepSSMJob();
2340

2441
void run() override;
@@ -32,17 +49,18 @@ class DeepSSMJob : public Job {
3249

3350
void python_message(std::string str);
3451

52+
static std::vector<int> get_split(ProjectHandle project, DeepSSMJob::SplitType split_type);
53+
3554
private:
36-
void update_prep_stage(DeepSSMTool::PrepStep step);
55+
void update_prep_stage(DeepSSMJob::PrepStep step);
3756
void process_test_results();
3857

39-
QSharedPointer<Session> session_;
4058
std::shared_ptr<Project> project_;
4159

42-
DeepSSMTool::ToolMode tool_mode_;
60+
DeepSSMJob::ToolMode tool_mode_;
4361

4462
QString prep_message_;
45-
DeepSSMTool::PrepStep prep_step_{DeepSSMTool::NOT_STARTED};
63+
DeepSSMJob::PrepStep prep_step_{DeepSSMJob::NOT_STARTED};
4664

4765
// mutex
4866
std::mutex mutex_;

0 commit comments

Comments
 (0)