|
1 | 1 | #include "Commands.h" |
2 | 2 |
|
3 | 3 | #include <Analyze/Analyze.h> |
| 4 | +#include <Application/DeepSSM/DeepSSMJob.h> |
| 5 | +#include <Application/Job/PythonWorker.h> |
4 | 6 | #include <Groom/Groom.h> |
5 | 7 | #include <Logging.h> |
6 | 8 | #include <Optimize/Optimize.h> |
7 | 9 | #include <Optimize/OptimizeParameterFile.h> |
8 | 10 | #include <Optimize/OptimizeParameters.h> |
| 11 | +#include <Profiling.h> |
9 | 12 | #include <ShapeworksUtils.h> |
10 | 13 | #include <Utils/StringUtils.h> |
11 | 14 |
|
| 15 | +#include <QApplication> |
12 | 16 | #include <boost/filesystem.hpp> |
13 | 17 |
|
14 | | -#include <Profiling.h> |
15 | | - |
16 | 18 | namespace shapeworks { |
17 | 19 |
|
18 | 20 | // boilerplate for a command. Copy this to start a new command |
@@ -43,8 +45,6 @@ bool Example::execute(const optparse::Values &options, SharedCommandData &shared |
43 | 45 | } |
44 | 46 | #endif |
45 | 47 |
|
46 | | - |
47 | | - |
48 | 48 | /////////////////////////////////////////////////////////////////////////////// |
49 | 49 | // Seed |
50 | 50 | /////////////////////////////////////////////////////////////////////////////// |
@@ -331,4 +331,164 @@ bool ConvertProjectCommand::execute(const optparse::Values& options, SharedComma |
331 | 331 | return false; |
332 | 332 | } |
333 | 333 | } |
| 334 | + |
| 335 | +/////////////////////////////////////////////////////////////////////////////// |
| 336 | +// DeepSSM |
| 337 | +/////////////////////////////////////////////////////////////////////////////// |
| 338 | +void DeepSSMCommand::buildParser() { |
| 339 | + const std::string prog = "deepssm"; |
| 340 | + const std::string desc = "run deepssm steps"; |
| 341 | + parser.prog(prog).description(desc); |
| 342 | + |
| 343 | + parser.add_option("--name").action("store").type("string").set_default("").help( |
| 344 | + "Path to input project file (xlsx or swproj)."); |
| 345 | + |
| 346 | + // Create a vector of choices first |
| 347 | + std::vector<std::string> prep_choices = {"all", "groom_training", "optimize_training", "optimize_validation", |
| 348 | + "groom_images"}; |
| 349 | + |
| 350 | + // --prep option with choices |
| 351 | + parser.add_option("--prep") |
| 352 | + .action("store") |
| 353 | + .type("choice") |
| 354 | + .choices(prep_choices.begin(), prep_choices.end()) |
| 355 | + //.set_default("all") |
| 356 | + .help("Preparation step to run"); |
| 357 | + |
| 358 | + // Boolean flag options |
| 359 | + parser.add_option("--augment").action("store_true").help("Run data augmentation"); |
| 360 | + |
| 361 | + parser.add_option("--train").action("store_true").help("Run training"); |
| 362 | + |
| 363 | + parser.add_option("--test").action("store_true").help("Run testing"); |
| 364 | + |
| 365 | + parser.add_option("--all").action("store_true").help("Run all steps"); |
| 366 | + |
| 367 | + Command::buildParser(); |
| 368 | +} |
| 369 | + |
| 370 | +bool DeepSSMCommand::execute(const optparse::Values& options, SharedCommandData& sharedData) { |
| 371 | + // Create a non-gui QApplication instance |
| 372 | + int argc = 3; |
| 373 | + char* argv[3]; |
| 374 | + argv[0] = const_cast<char*>("shapeworks"); |
| 375 | + argv[1] = const_cast<char*>("-platform"); |
| 376 | + argv[2] = const_cast<char*>("offscreen"); |
| 377 | + |
| 378 | + QApplication app(argc, argv); |
| 379 | + |
| 380 | + // Handle project file: either from --name or first positional argument |
| 381 | + std::string project_file; |
| 382 | + if (options.is_set_by_user("name")) { |
| 383 | + // User explicitly provided --name |
| 384 | + project_file = options["name"]; |
| 385 | + } else if (!parser.args().empty()) { |
| 386 | + // Use first positional argument |
| 387 | + project_file = parser.args()[0]; |
| 388 | + } else { |
| 389 | + // No project file provided at all |
| 390 | + parser.error("Project file must be provided either as --name or as a positional argument"); |
| 391 | + } |
| 392 | + |
| 393 | + // Handle prep option with manual default |
| 394 | + std::string prep_step; |
| 395 | + if (options.is_set_by_user("prep")) { |
| 396 | + prep_step = options["prep"]; |
| 397 | + } else { |
| 398 | + prep_step = "all"; // Manual default |
| 399 | + } |
| 400 | + |
| 401 | + std::cout << "DeepSSM: Using project file: " << project_file << std::endl; |
| 402 | + |
| 403 | + bool do_prep = options.is_set("prep") || options.is_set("all"); |
| 404 | + bool do_augment = options.is_set("augment") || options.is_set("all"); |
| 405 | + bool do_train = options.is_set("train") || options.is_set("all"); |
| 406 | + bool do_test = options.is_set("test") || options.is_set("all"); |
| 407 | + |
| 408 | + std::cout << "Prep step: " << (do_prep ? "on" : "off") << "\n"; |
| 409 | + std::cout << "Augment step: " << (do_augment ? "on" : "off") << "\n"; |
| 410 | + std::cout << "Train step: " << (do_train ? "on" : "off") << "\n"; |
| 411 | + std::cout << "Test step: " << (do_test ? "on" : "off") << "\n"; |
| 412 | + |
| 413 | + if (!do_prep && !do_augment && !do_train && !do_test) { |
| 414 | + do_prep = true; |
| 415 | + do_augment = true; |
| 416 | + do_train = true; |
| 417 | + do_test = true; |
| 418 | + } |
| 419 | + |
| 420 | + ProjectHandle project = std::make_shared<Project>(); |
| 421 | + project->load(project_file); |
| 422 | + |
| 423 | + PythonWorker python_worker; |
| 424 | + python_worker.set_cli_mode(true); |
| 425 | + |
| 426 | + auto wait_for_job = [&](auto job) { |
| 427 | + // This lambda will block until the job is complete |
| 428 | + while (!job->is_complete()) { |
| 429 | + QCoreApplication::processEvents(); |
| 430 | + std::this_thread::sleep_for(std::chrono::milliseconds(100)); |
| 431 | + if (job->is_aborted()) { |
| 432 | + return false; |
| 433 | + } |
| 434 | + } |
| 435 | + return true; |
| 436 | + }; |
| 437 | + |
| 438 | + if (do_prep) { |
| 439 | + auto job = QSharedPointer<DeepSSMJob>::create(project, DeepSSMJob::JobType::DeepSSM_PrepType); |
| 440 | + if (prep_step == "all") { |
| 441 | + job->set_prep_step(DeepSSMJob::PrepStep::NOT_STARTED); |
| 442 | + } else if (prep_step == "groom_training") { |
| 443 | + job->set_prep_step(DeepSSMJob::PrepStep::GROOM_TRAINING); |
| 444 | + } else if (prep_step == "optimize_training") { |
| 445 | + job->set_prep_step(DeepSSMJob::PrepStep::OPTIMIZE_TRAINING); |
| 446 | + } else if (prep_step == "optimize_validation") { |
| 447 | + job->set_prep_step(DeepSSMJob::PrepStep::OPTIMIZE_VALIDATION); |
| 448 | + } else if (prep_step == "groom_images") { |
| 449 | + job->set_prep_step(DeepSSMJob::PrepStep::GROOM_IMAGES); |
| 450 | + } else { |
| 451 | + SW_ERROR("Unknown prep step: {}", prep_step); |
| 452 | + return false; |
| 453 | + } |
| 454 | + std::cout << "Running DeepSSM preparation step...\n"; |
| 455 | + python_worker.run_job(job); |
| 456 | + if (!wait_for_job(job)) { |
| 457 | + return false; |
| 458 | + } |
| 459 | + std::cout << "DeepSSM preparation step completed.\n"; |
| 460 | + } |
| 461 | + if (do_augment) { |
| 462 | + std::cout << "Running DeepSSM data augmentation...\n"; |
| 463 | + auto job = QSharedPointer<DeepSSMJob>::create(project, DeepSSMJob::JobType::DeepSSM_AugmentationType); |
| 464 | + python_worker.run_job(job); |
| 465 | + if (!wait_for_job(job)) { |
| 466 | + return false; |
| 467 | + } |
| 468 | + std::cout << "DeepSSM data augmentation completed.\n"; |
| 469 | + } |
| 470 | + if (do_train) { |
| 471 | + std::cout << "Running DeepSSM training...\n"; |
| 472 | + auto job = QSharedPointer<DeepSSMJob>::create(project, DeepSSMJob::JobType::DeepSSM_TrainingType); |
| 473 | + python_worker.run_job(job); |
| 474 | + if (!wait_for_job(job)) { |
| 475 | + return false; |
| 476 | + } |
| 477 | + std::cout << "DeepSSM training completed.\n"; |
| 478 | + } |
| 479 | + if (do_test) { |
| 480 | + std::cout << "Running DeepSSM testing...\n"; |
| 481 | + auto job = QSharedPointer<DeepSSMJob>::create(project, DeepSSMJob::JobType::DeepSSM_TestingType); |
| 482 | + python_worker.run_job(job); |
| 483 | + if (!wait_for_job(job)) { |
| 484 | + return false; |
| 485 | + } |
| 486 | + std::cout << "DeepSSM testing completed.\n"; |
| 487 | + } |
| 488 | + |
| 489 | + project->save(); |
| 490 | + |
| 491 | + return false; |
| 492 | +} |
| 493 | + |
334 | 494 | } // namespace shapeworks |
0 commit comments