|
1 | 1 | #include "Commands.h" |
2 | 2 |
|
3 | 3 | #include <Analyze/Analyze.h> |
| 4 | +#include <Application/DeepSSM/DeepSSMJob.h> |
4 | 5 | #include <Groom/Groom.h> |
5 | 6 | #include <Logging.h> |
6 | 7 | #include <Optimize/Optimize.h> |
7 | 8 | #include <Optimize/OptimizeParameterFile.h> |
8 | 9 | #include <Optimize/OptimizeParameters.h> |
| 10 | +#include <Profiling.h> |
9 | 11 | #include <ShapeworksUtils.h> |
10 | 12 | #include <Utils/StringUtils.h> |
11 | 13 |
|
12 | 14 | #include <boost/filesystem.hpp> |
13 | 15 |
|
14 | | -#include <Profiling.h> |
15 | | - |
16 | 16 | namespace shapeworks { |
17 | 17 |
|
18 | 18 | // boilerplate for a command. Copy this to start a new command |
@@ -43,8 +43,6 @@ bool Example::execute(const optparse::Values &options, SharedCommandData &shared |
43 | 43 | } |
44 | 44 | #endif |
45 | 45 |
|
46 | | - |
47 | | - |
48 | 46 | /////////////////////////////////////////////////////////////////////////////// |
49 | 47 | // Seed |
50 | 48 | /////////////////////////////////////////////////////////////////////////////// |
@@ -331,4 +329,110 @@ bool ConvertProjectCommand::execute(const optparse::Values& options, SharedComma |
331 | 329 | return false; |
332 | 330 | } |
333 | 331 | } |
| 332 | + |
| 333 | +/////////////////////////////////////////////////////////////////////////////// |
| 334 | +// DeepSSM |
| 335 | +/////////////////////////////////////////////////////////////////////////////// |
| 336 | +void DeepSSMCommand::buildParser() { |
| 337 | + const std::string prog = "deepssm"; |
| 338 | + const std::string desc = "run deepssm steps"; |
| 339 | + parser.prog(prog).description(desc); |
| 340 | + |
| 341 | + parser.add_option("--name").action("store").type("string").set_default("").help( |
| 342 | + "Path to input project file (xlsx or swproj)."); |
| 343 | + |
| 344 | + // Create a vector of choices first |
| 345 | + std::vector<std::string> prep_choices = {"all", "groom_training", "optimize_training", "optimize_validation", |
| 346 | + "groom_images"}; |
| 347 | + |
| 348 | + // --prep option with choices |
| 349 | + parser.add_option("--prep") |
| 350 | + .action("store") |
| 351 | + .type("choice") |
| 352 | + .choices(prep_choices.begin(), prep_choices.end()) |
| 353 | + .set_default("all") |
| 354 | + .help("Preparation step to run"); |
| 355 | + |
| 356 | + // Boolean flag options |
| 357 | + parser.add_option("--augment").action("store_true").help("Run data augmentation"); |
| 358 | + |
| 359 | + parser.add_option("--train").action("store_true").help("Run training"); |
| 360 | + |
| 361 | + parser.add_option("--test").action("store_true").help("Run testing"); |
| 362 | + |
| 363 | + parser.add_option("--all").action("store_true").help("Run all steps"); |
| 364 | + |
| 365 | + Command::buildParser(); |
| 366 | +} |
| 367 | + |
| 368 | +bool DeepSSMCommand::execute(const optparse::Values& options, SharedCommandData& sharedData) { |
| 369 | + // Handle project file: either from --name or first positional argument |
| 370 | + std::string project_file; |
| 371 | + if (options.is_set_by_user("name")) { |
| 372 | + // User explicitly provided --name |
| 373 | + project_file = options["name"]; |
| 374 | + } else if (!parser.args().empty()) { |
| 375 | + // Use first positional argument |
| 376 | + project_file = parser.args()[0]; |
| 377 | + } else { |
| 378 | + // No project file provided at all |
| 379 | + parser.error("Project file must be provided either as --name or as a positional argument"); |
| 380 | + } |
| 381 | + |
| 382 | + std::cout << "DeepSSM: Using project file: " << project_file << std::endl; |
| 383 | + |
| 384 | + bool do_prep = options.is_set("prep") || options.is_set("all"); |
| 385 | + std::string prep_step = options["prep"]; |
| 386 | + bool do_augment = options.is_set("augment") || options.is_set("all"); |
| 387 | + bool do_train = options.is_set("train") || options.is_set("all"); |
| 388 | + bool do_test = options.is_set("test") || options.is_set("all"); |
| 389 | + if (!do_prep && !do_augment && !do_train && !do_test) { |
| 390 | + do_prep = true; |
| 391 | + do_augment = true; |
| 392 | + do_train = true; |
| 393 | + do_test = true; |
| 394 | + } |
| 395 | + |
| 396 | + ProjectHandle project = std::make_shared<Project>(); |
| 397 | + project->load(project_file); |
| 398 | + |
| 399 | + DeepSSMJob job(project, DeepSSMJob::JobType::DeepSSM_PrepType); |
| 400 | + |
| 401 | + if (do_prep) { |
| 402 | + if (prep_step == "all") { |
| 403 | + job.set_prep_step(DeepSSMJob::PrepStep::NOT_STARTED); |
| 404 | + } else if (prep_step == "groom_training") { |
| 405 | + job.set_prep_step(DeepSSMJob::PrepStep::GROOM_TRAINING); |
| 406 | + } else if (prep_step == "optimize_training") { |
| 407 | + job.set_prep_step(DeepSSMJob::PrepStep::OPTIMIZE_TRAINING); |
| 408 | + } else if (prep_step == "optimize_validation") { |
| 409 | + job.set_prep_step(DeepSSMJob::PrepStep::OPTIMIZE_VALIDATION); |
| 410 | + } else if (prep_step == "groom_images") { |
| 411 | + job.set_prep_step(DeepSSMJob::PrepStep::GROOM_IMAGES); |
| 412 | + } else { |
| 413 | + SW_ERROR("Unknown prep step: {}", prep_step); |
| 414 | + return false; |
| 415 | + } |
| 416 | + std::cout << "Running DeepSSM preparation step...\n"; |
| 417 | + job.run_prep(); |
| 418 | + } |
| 419 | + if (do_augment) { |
| 420 | + std::cout << "Running DeepSSM data augmentation...\n"; |
| 421 | + job.run_augmentation(); |
| 422 | + } |
| 423 | + if (do_train) { |
| 424 | + std::cout << "Running DeepSSM training...\n"; |
| 425 | + job.run_training(); |
| 426 | + } |
| 427 | + if (do_test) { |
| 428 | + std::cout << "Running DeepSSM testing...\n"; |
| 429 | + job.run_testing(); |
| 430 | + } |
| 431 | + |
| 432 | + project->save(); |
| 433 | + |
| 434 | + SW_ERROR("DeepSSM command is not implemented yet."); |
| 435 | + return false; |
| 436 | +} |
| 437 | + |
334 | 438 | } // namespace shapeworks |
0 commit comments