Skip to content

Commit 4fd4ab7

Browse files
committed
Working on DeepSSM command
1 parent f25a0da commit 4fd4ab7

File tree

9 files changed

+179
-59
lines changed

9 files changed

+179
-59
lines changed

Applications/shapeworks/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ target_include_directories(shapeworks_exe PUBLIC
2323

2424
target_link_libraries(shapeworks_exe
2525
Mesh ${VTK_LIBRARIES} Optimize Utils trimesh2 Particles
26-
pybind11::embed Project Image Groom Analyze
26+
pybind11::embed Project Image Groom Analyze Application
2727
)
2828

2929
message(STATUS "opt libs ${OPTIMIZE_LIBRARIES}")

Applications/shapeworks/Command.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,15 @@ class ParticleSystemCommand : public Command
103103
private:
104104
};
105105

106+
class DeepSSMCommandGroup : public Command
107+
{
108+
public:
109+
const std::string type() override { return "DeepSSM"; }
110+
111+
private:
112+
};
113+
114+
106115
class ShapeworksCommand : public Command
107116
{
108117
public:

Applications/shapeworks/Commands.cpp

Lines changed: 108 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
11
#include "Commands.h"
22

33
#include <Analyze/Analyze.h>
4+
#include <Application/DeepSSM/DeepSSMJob.h>
45
#include <Groom/Groom.h>
56
#include <Logging.h>
67
#include <Optimize/Optimize.h>
78
#include <Optimize/OptimizeParameterFile.h>
89
#include <Optimize/OptimizeParameters.h>
10+
#include <Profiling.h>
911
#include <ShapeworksUtils.h>
1012
#include <Utils/StringUtils.h>
1113

1214
#include <boost/filesystem.hpp>
1315

14-
#include <Profiling.h>
15-
1616
namespace shapeworks {
1717

1818
// boilerplate for a command. Copy this to start a new command
@@ -43,8 +43,6 @@ bool Example::execute(const optparse::Values &options, SharedCommandData &shared
4343
}
4444
#endif
4545

46-
47-
4846
///////////////////////////////////////////////////////////////////////////////
4947
// Seed
5048
///////////////////////////////////////////////////////////////////////////////
@@ -331,4 +329,110 @@ bool ConvertProjectCommand::execute(const optparse::Values& options, SharedComma
331329
return false;
332330
}
333331
}
332+
333+
///////////////////////////////////////////////////////////////////////////////
334+
// DeepSSM
335+
///////////////////////////////////////////////////////////////////////////////
336+
void DeepSSMCommand::buildParser() {
337+
const std::string prog = "deepssm";
338+
const std::string desc = "run deepssm steps";
339+
parser.prog(prog).description(desc);
340+
341+
parser.add_option("--name").action("store").type("string").set_default("").help(
342+
"Path to input project file (xlsx or swproj).");
343+
344+
// Create a vector of choices first
345+
std::vector<std::string> prep_choices = {"all", "groom_training", "optimize_training", "optimize_validation",
346+
"groom_images"};
347+
348+
// --prep option with choices
349+
parser.add_option("--prep")
350+
.action("store")
351+
.type("choice")
352+
.choices(prep_choices.begin(), prep_choices.end())
353+
.set_default("all")
354+
.help("Preparation step to run");
355+
356+
// Boolean flag options
357+
parser.add_option("--augment").action("store_true").help("Run data augmentation");
358+
359+
parser.add_option("--train").action("store_true").help("Run training");
360+
361+
parser.add_option("--test").action("store_true").help("Run testing");
362+
363+
parser.add_option("--all").action("store_true").help("Run all steps");
364+
365+
Command::buildParser();
366+
}
367+
368+
bool DeepSSMCommand::execute(const optparse::Values& options, SharedCommandData& sharedData) {
369+
// Handle project file: either from --name or first positional argument
370+
std::string project_file;
371+
if (options.is_set_by_user("name")) {
372+
// User explicitly provided --name
373+
project_file = options["name"];
374+
} else if (!parser.args().empty()) {
375+
// Use first positional argument
376+
project_file = parser.args()[0];
377+
} else {
378+
// No project file provided at all
379+
parser.error("Project file must be provided either as --name or as a positional argument");
380+
}
381+
382+
std::cout << "DeepSSM: Using project file: " << project_file << std::endl;
383+
384+
bool do_prep = options.is_set("prep") || options.is_set("all");
385+
std::string prep_step = options["prep"];
386+
bool do_augment = options.is_set("augment") || options.is_set("all");
387+
bool do_train = options.is_set("train") || options.is_set("all");
388+
bool do_test = options.is_set("test") || options.is_set("all");
389+
if (!do_prep && !do_augment && !do_train && !do_test) {
390+
do_prep = true;
391+
do_augment = true;
392+
do_train = true;
393+
do_test = true;
394+
}
395+
396+
ProjectHandle project = std::make_shared<Project>();
397+
project->load(project_file);
398+
399+
DeepSSMJob job(project, DeepSSMJob::JobType::DeepSSM_PrepType);
400+
401+
if (do_prep) {
402+
if (prep_step == "all") {
403+
job.set_prep_step(DeepSSMJob::PrepStep::NOT_STARTED);
404+
} else if (prep_step == "groom_training") {
405+
job.set_prep_step(DeepSSMJob::PrepStep::GROOM_TRAINING);
406+
} else if (prep_step == "optimize_training") {
407+
job.set_prep_step(DeepSSMJob::PrepStep::OPTIMIZE_TRAINING);
408+
} else if (prep_step == "optimize_validation") {
409+
job.set_prep_step(DeepSSMJob::PrepStep::OPTIMIZE_VALIDATION);
410+
} else if (prep_step == "groom_images") {
411+
job.set_prep_step(DeepSSMJob::PrepStep::GROOM_IMAGES);
412+
} else {
413+
SW_ERROR("Unknown prep step: {}", prep_step);
414+
return false;
415+
}
416+
std::cout << "Running DeepSSM preparation step...\n";
417+
job.run_prep();
418+
}
419+
if (do_augment) {
420+
std::cout << "Running DeepSSM data augmentation...\n";
421+
job.run_augmentation();
422+
}
423+
if (do_train) {
424+
std::cout << "Running DeepSSM training...\n";
425+
job.run_training();
426+
}
427+
if (do_test) {
428+
std::cout << "Running DeepSSM testing...\n";
429+
job.run_testing();
430+
}
431+
432+
project->save();
433+
434+
SW_ERROR("DeepSSM command is not implemented yet.");
435+
return false;
436+
}
437+
334438
} // namespace shapeworks

Applications/shapeworks/Commands.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,5 +101,6 @@ COMMAND_DECLARE(OptimizeCommand, OptimizeCommandGroup);
101101
COMMAND_DECLARE(GroomCommand, GroomCommandGroup);
102102
COMMAND_DECLARE(AnalyzeCommand, AnalyzeCommandGroup);
103103
COMMAND_DECLARE(ConvertProjectCommand, ProjectCommandGroup);
104+
COMMAND_DECLARE(DeepSSMCommand, DeepSSMCommandGroup);
104105

105106
} // shapeworks

Applications/shapeworks/shapeworks.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ int main(int argc, char *argv[])
110110
shapeworks.addCommand(GroomCommand::getCommand());
111111
shapeworks.addCommand(AnalyzeCommand::getCommand());
112112
shapeworks.addCommand(ConvertProjectCommand::getCommand());
113+
shapeworks.addCommand(DeepSSMCommand::getCommand());
113114

114115
try {
115116
TIME_START("shapeworks");

Libs/Application/DeepSSM/DeepSSMJob.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,26 +22,26 @@ using namespace pybind11::literals; // to bring in the `_a` literal
2222
namespace shapeworks {
2323

2424
//---------------------------------------------------------------------------
25-
DeepSSMJob::DeepSSMJob(std::shared_ptr<Project> project, DeepSSMJob::ToolMode tool_mode, DeepSSMJob::PrepStep prep_step)
26-
: project_(project), tool_mode_(tool_mode), prep_step_(prep_step) {}
25+
DeepSSMJob::DeepSSMJob(std::shared_ptr<Project> project, DeepSSMJob::JobType tool_mode, DeepSSMJob::PrepStep prep_step)
26+
: project_(project), job_type_(tool_mode), prep_step_(prep_step) {}
2727

2828
//---------------------------------------------------------------------------
2929
DeepSSMJob::~DeepSSMJob() {}
3030

3131
//---------------------------------------------------------------------------
3232
void DeepSSMJob::run() {
3333
try {
34-
switch (tool_mode_) {
35-
case DeepSSMJob::ToolMode::DeepSSM_PrepType:
34+
switch (job_type_) {
35+
case DeepSSMJob::JobType::DeepSSM_PrepType:
3636
run_prep();
3737
break;
38-
case DeepSSMJob::ToolMode::DeepSSM_AugmentationType:
38+
case DeepSSMJob::JobType::DeepSSM_AugmentationType:
3939
run_augmentation();
4040
break;
41-
case DeepSSMJob::ToolMode::DeepSSM_TrainingType:
41+
case DeepSSMJob::JobType::DeepSSM_TrainingType:
4242
run_training();
4343
break;
44-
case DeepSSMJob::ToolMode::DeepSSM_TestingType:
44+
case DeepSSMJob::JobType::DeepSSM_TestingType:
4545
run_testing();
4646
break;
4747
}
@@ -52,17 +52,17 @@ void DeepSSMJob::run() {
5252

5353
//---------------------------------------------------------------------------
5454
QString DeepSSMJob::name() {
55-
switch (tool_mode_) {
56-
case DeepSSMJob::ToolMode::DeepSSM_PrepType:
55+
switch (job_type_) {
56+
case DeepSSMJob::JobType::DeepSSM_PrepType:
5757
return "DeepSSM: Prep";
5858
break;
59-
case DeepSSMJob::ToolMode::DeepSSM_AugmentationType:
59+
case DeepSSMJob::JobType::DeepSSM_AugmentationType:
6060
return "DeepSSM: Augmentation";
6161
break;
62-
case DeepSSMJob::ToolMode::DeepSSM_TrainingType:
62+
case DeepSSMJob::JobType::DeepSSM_TrainingType:
6363
return "DeepSSM: Training";
6464
break;
65-
case DeepSSMJob::ToolMode::DeepSSM_TestingType:
65+
case DeepSSMJob::JobType::DeepSSM_TestingType:
6666
return "DeepSSM: Testing";
6767
break;
6868
}

Libs/Application/DeepSSM/DeepSSMJob.h

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class DeepSSMJob : public Job {
1616
Q_OBJECT;
1717

1818
public:
19-
enum class ToolMode {
19+
enum class JobType {
2020
DeepSSM_PrepType = 0,
2121
DeepSSM_AugmentationType = 1,
2222
DeepSSM_TrainingType = 2,
@@ -34,7 +34,7 @@ class DeepSSMJob : public Job {
3434

3535
enum class SplitType { TRAIN, VAL, TEST };
3636

37-
DeepSSMJob(std::shared_ptr<Project> project, DeepSSMJob::ToolMode tool_mode,
37+
DeepSSMJob(std::shared_ptr<Project> project, DeepSSMJob::JobType tool_mode,
3838
DeepSSMJob::PrepStep prep_step = DeepSSMJob::NOT_STARTED);
3939
~DeepSSMJob();
4040

@@ -51,13 +51,18 @@ class DeepSSMJob : public Job {
5151

5252
static std::vector<int> get_split(ProjectHandle project, DeepSSMJob::SplitType split_type);
5353

54+
void set_prep_step(DeepSSMJob::PrepStep step) {
55+
std::lock_guard<std::mutex> lock(mutex_);
56+
prep_step_ = step;
57+
}
58+
5459
private:
5560
void update_prep_stage(DeepSSMJob::PrepStep step);
5661
void process_test_results();
5762

5863
std::shared_ptr<Project> project_;
5964

60-
DeepSSMJob::ToolMode tool_mode_;
65+
DeepSSMJob::JobType job_type_;
6166

6267
QString prep_message_;
6368
DeepSSMJob::PrepStep prep_step_{DeepSSMJob::NOT_STARTED};

0 commit comments

Comments
 (0)