Skip to content

Commit c2f00c9

Browse files
authored
Merge pull request #2410 from SCIInstitute/deepssm_command
Deepssm command
2 parents 39fc1b4 + 46c68c9 commit c2f00c9

32 files changed

+480
-243
lines changed

Applications/shapeworks/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ target_include_directories(shapeworks_exe PUBLIC
2323

2424
target_link_libraries(shapeworks_exe
2525
Mesh ${VTK_LIBRARIES} Optimize Utils trimesh2 Particles
26-
pybind11::embed Project Image Groom Analyze
26+
pybind11::embed Project Image Groom Analyze Application
2727
)
2828

2929
message(STATUS "opt libs ${OPTIMIZE_LIBRARIES}")

Applications/shapeworks/Command.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,15 @@ class ParticleSystemCommand : public Command
103103
private:
104104
};
105105

106+
class DeepSSMCommandGroup : public Command
107+
{
108+
public:
109+
const std::string type() override { return "DeepSSM"; }
110+
111+
private:
112+
};
113+
114+
106115
class ShapeworksCommand : public Command
107116
{
108117
public:

Applications/shapeworks/Commands.cpp

Lines changed: 164 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
11
#include "Commands.h"
22

33
#include <Analyze/Analyze.h>
4+
#include <Application/DeepSSM/DeepSSMJob.h>
5+
#include <Application/Job/PythonWorker.h>
46
#include <Groom/Groom.h>
57
#include <Logging.h>
68
#include <Optimize/Optimize.h>
79
#include <Optimize/OptimizeParameterFile.h>
810
#include <Optimize/OptimizeParameters.h>
11+
#include <Profiling.h>
912
#include <ShapeworksUtils.h>
1013
#include <Utils/StringUtils.h>
1114

15+
#include <QApplication>
1216
#include <boost/filesystem.hpp>
1317

14-
#include <Profiling.h>
15-
1618
namespace shapeworks {
1719

1820
// boilerplate for a command. Copy this to start a new command
@@ -43,8 +45,6 @@ bool Example::execute(const optparse::Values &options, SharedCommandData &shared
4345
}
4446
#endif
4547

46-
47-
4848
///////////////////////////////////////////////////////////////////////////////
4949
// Seed
5050
///////////////////////////////////////////////////////////////////////////////
@@ -331,4 +331,164 @@ bool ConvertProjectCommand::execute(const optparse::Values& options, SharedComma
331331
return false;
332332
}
333333
}
334+
335+
///////////////////////////////////////////////////////////////////////////////
336+
// DeepSSM
337+
///////////////////////////////////////////////////////////////////////////////
338+
void DeepSSMCommand::buildParser() {
339+
const std::string prog = "deepssm";
340+
const std::string desc = "run deepssm steps";
341+
parser.prog(prog).description(desc);
342+
343+
parser.add_option("--name").action("store").type("string").set_default("").help(
344+
"Path to input project file (xlsx or swproj).");
345+
346+
// Create a vector of choices first
347+
std::vector<std::string> prep_choices = {"all", "groom_training", "optimize_training", "optimize_validation",
348+
"groom_images"};
349+
350+
// --prep option with choices
351+
parser.add_option("--prep")
352+
.action("store")
353+
.type("choice")
354+
.choices(prep_choices.begin(), prep_choices.end())
355+
//.set_default("all")
356+
.help("Preparation step to run");
357+
358+
// Boolean flag options
359+
parser.add_option("--augment").action("store_true").help("Run data augmentation");
360+
361+
parser.add_option("--train").action("store_true").help("Run training");
362+
363+
parser.add_option("--test").action("store_true").help("Run testing");
364+
365+
parser.add_option("--all").action("store_true").help("Run all steps");
366+
367+
Command::buildParser();
368+
}
369+
370+
bool DeepSSMCommand::execute(const optparse::Values& options, SharedCommandData& sharedData) {
371+
// Create a non-gui QApplication instance
372+
int argc = 3;
373+
char* argv[3];
374+
argv[0] = const_cast<char*>("shapeworks");
375+
argv[1] = const_cast<char*>("-platform");
376+
argv[2] = const_cast<char*>("offscreen");
377+
378+
QApplication app(argc, argv);
379+
380+
// Handle project file: either from --name or first positional argument
381+
std::string project_file;
382+
if (options.is_set_by_user("name")) {
383+
// User explicitly provided --name
384+
project_file = options["name"];
385+
} else if (!parser.args().empty()) {
386+
// Use first positional argument
387+
project_file = parser.args()[0];
388+
} else {
389+
// No project file provided at all
390+
parser.error("Project file must be provided either as --name or as a positional argument");
391+
}
392+
393+
// Handle prep option with manual default
394+
std::string prep_step;
395+
if (options.is_set_by_user("prep")) {
396+
prep_step = options["prep"];
397+
} else {
398+
prep_step = "all"; // Manual default
399+
}
400+
401+
std::cout << "DeepSSM: Using project file: " << project_file << std::endl;
402+
403+
bool do_prep = options.is_set("prep") || options.is_set("all");
404+
bool do_augment = options.is_set("augment") || options.is_set("all");
405+
bool do_train = options.is_set("train") || options.is_set("all");
406+
bool do_test = options.is_set("test") || options.is_set("all");
407+
408+
std::cout << "Prep step: " << (do_prep ? "on" : "off") << "\n";
409+
std::cout << "Augment step: " << (do_augment ? "on" : "off") << "\n";
410+
std::cout << "Train step: " << (do_train ? "on" : "off") << "\n";
411+
std::cout << "Test step: " << (do_test ? "on" : "off") << "\n";
412+
413+
if (!do_prep && !do_augment && !do_train && !do_test) {
414+
do_prep = true;
415+
do_augment = true;
416+
do_train = true;
417+
do_test = true;
418+
}
419+
420+
ProjectHandle project = std::make_shared<Project>();
421+
project->load(project_file);
422+
423+
PythonWorker python_worker;
424+
python_worker.set_cli_mode(true);
425+
426+
auto wait_for_job = [&](auto job) {
427+
// This lambda will block until the job is complete
428+
while (!job->is_complete()) {
429+
QCoreApplication::processEvents();
430+
std::this_thread::sleep_for(std::chrono::milliseconds(100));
431+
if (job->is_aborted()) {
432+
return false;
433+
}
434+
}
435+
return true;
436+
};
437+
438+
if (do_prep) {
439+
auto job = QSharedPointer<DeepSSMJob>::create(project, DeepSSMJob::JobType::DeepSSM_PrepType);
440+
if (prep_step == "all") {
441+
job->set_prep_step(DeepSSMJob::PrepStep::NOT_STARTED);
442+
} else if (prep_step == "groom_training") {
443+
job->set_prep_step(DeepSSMJob::PrepStep::GROOM_TRAINING);
444+
} else if (prep_step == "optimize_training") {
445+
job->set_prep_step(DeepSSMJob::PrepStep::OPTIMIZE_TRAINING);
446+
} else if (prep_step == "optimize_validation") {
447+
job->set_prep_step(DeepSSMJob::PrepStep::OPTIMIZE_VALIDATION);
448+
} else if (prep_step == "groom_images") {
449+
job->set_prep_step(DeepSSMJob::PrepStep::GROOM_IMAGES);
450+
} else {
451+
SW_ERROR("Unknown prep step: {}", prep_step);
452+
return false;
453+
}
454+
std::cout << "Running DeepSSM preparation step...\n";
455+
python_worker.run_job(job);
456+
if (!wait_for_job(job)) {
457+
return false;
458+
}
459+
std::cout << "DeepSSM preparation step completed.\n";
460+
}
461+
if (do_augment) {
462+
std::cout << "Running DeepSSM data augmentation...\n";
463+
auto job = QSharedPointer<DeepSSMJob>::create(project, DeepSSMJob::JobType::DeepSSM_AugmentationType);
464+
python_worker.run_job(job);
465+
if (!wait_for_job(job)) {
466+
return false;
467+
}
468+
std::cout << "DeepSSM data augmentation completed.\n";
469+
}
470+
if (do_train) {
471+
std::cout << "Running DeepSSM training...\n";
472+
auto job = QSharedPointer<DeepSSMJob>::create(project, DeepSSMJob::JobType::DeepSSM_TrainingType);
473+
python_worker.run_job(job);
474+
if (!wait_for_job(job)) {
475+
return false;
476+
}
477+
std::cout << "DeepSSM training completed.\n";
478+
}
479+
if (do_test) {
480+
std::cout << "Running DeepSSM testing...\n";
481+
auto job = QSharedPointer<DeepSSMJob>::create(project, DeepSSMJob::JobType::DeepSSM_TestingType);
482+
python_worker.run_job(job);
483+
if (!wait_for_job(job)) {
484+
return false;
485+
}
486+
std::cout << "DeepSSM testing completed.\n";
487+
}
488+
489+
project->save();
490+
491+
return false;
492+
}
493+
334494
} // namespace shapeworks

Applications/shapeworks/Commands.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,5 +101,6 @@ COMMAND_DECLARE(OptimizeCommand, OptimizeCommandGroup);
101101
COMMAND_DECLARE(GroomCommand, GroomCommandGroup);
102102
COMMAND_DECLARE(AnalyzeCommand, AnalyzeCommandGroup);
103103
COMMAND_DECLARE(ConvertProjectCommand, ProjectCommandGroup);
104+
COMMAND_DECLARE(DeepSSMCommand, DeepSSMCommandGroup);
104105

105106
} // shapeworks

Applications/shapeworks/shapeworks.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ int main(int argc, char *argv[])
110110
shapeworks.addCommand(GroomCommand::getCommand());
111111
shapeworks.addCommand(AnalyzeCommand::getCommand());
112112
shapeworks.addCommand(ConvertProjectCommand::getCommand());
113+
shapeworks.addCommand(DeepSSMCommand::getCommand());
113114

114115
try {
115116
TIME_START("shapeworks");

Libs/Application/CMakeLists.txt

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
SET(APPLICATION_MOC_HDRS
2+
DeepSSM/DeepSSMJob.h
3+
Job/Job.h
4+
Job/PythonWorker.h
5+
ShapeWorksVtkOutputWindow.h
6+
)
7+
8+
qt5_wrap_cpp( APPLICATION_MOC_SRCS ${APPLICATION_MOC_HDRS} )
9+
10+
SET(Application_headers
11+
)
12+
13+
add_library(Application STATIC
14+
DeepSSM/DeepSSMJob.cpp
15+
Job/Job.cpp
16+
Job/PythonWorker.cpp
17+
ShapeWorksVtkOutputWindow.cpp
18+
${APPLICATION_MOC_SRCS}
19+
)
20+
21+
target_include_directories(Application PUBLIC
22+
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>
23+
$<INSTALL_INTERFACE:include>)
24+
25+
set(SW_PYTHON_LIBS pybind11::embed)
26+
27+
if (APPLE)
28+
include_directories(${_Python3_INCLUDE_DIR})
29+
set(SW_PYTHON_LIBS "")
30+
endif(APPLE)
31+
32+
target_link_libraries(Application PUBLIC
33+
Groom
34+
Mesh
35+
Utils
36+
Particles
37+
Project
38+
${SW_PYTHON_LIBS}
39+
)
40+
41+
# set
42+
set_target_properties(Application PROPERTIES PUBLIC_HEADER
43+
"${Application_headers}")
44+
45+
install(TARGETS Application EXPORT ShapeWorksTargets
46+
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
47+
RUNTIME DESTINATION ${CMAKE_INSTALL_LIBDIR}
48+
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
49+
PUBLIC_HEADER DESTINATION include/Application
50+
)

0 commit comments

Comments
 (0)