diff --git a/Applications/shapeworks/Commands.cpp b/Applications/shapeworks/Commands.cpp index f595c5837e..da59fe0c67 100644 --- a/Applications/shapeworks/Commands.cpp +++ b/Applications/shapeworks/Commands.cpp @@ -364,6 +364,13 @@ void DeepSSMCommand::buildParser() { parser.add_option("--all").action("store_true").help("Run all steps"); + // add num_workers option + parser.add_option("--num_workers") + .action("store") + .type("int") + .set_default(0) + .help("Number of data loader workers (default: 0)"); + Command::buildParser(); } @@ -405,10 +412,13 @@ bool DeepSSMCommand::execute(const optparse::Values& options, SharedCommandData& bool do_train = options.is_set("train") || options.is_set("all"); bool do_test = options.is_set("test") || options.is_set("all"); + int num_workers = static_cast(options.get("num_workers")); + std::cout << "Prep step: " << (do_prep ? "on" : "off") << "\n"; std::cout << "Augment step: " << (do_augment ? "on" : "off") << "\n"; std::cout << "Train step: " << (do_train ? "on" : "off") << "\n"; std::cout << "Test step: " << (do_test ? "on" : "off") << "\n"; + std::cout << "Num dataloader workers: " << num_workers << "\n"; if (!do_prep && !do_augment && !do_train && !do_test) { do_prep = true; @@ -437,6 +447,7 @@ bool DeepSSMCommand::execute(const optparse::Values& options, SharedCommandData& if (do_prep) { auto job = QSharedPointer::create(project, DeepSSMJob::JobType::DeepSSM_PrepType); + job->set_num_dataloader_workers(num_workers); if (prep_step == "all") { job->set_prep_step(DeepSSMJob::PrepStep::NOT_STARTED); } else if (prep_step == "groom_training") { diff --git a/Libs/Application/DeepSSM/DeepSSMJob.cpp b/Libs/Application/DeepSSM/DeepSSMJob.cpp index 7de61cd643..ec47638877 100644 --- a/Libs/Application/DeepSSM/DeepSSMJob.cpp +++ b/Libs/Application/DeepSSM/DeepSSMJob.cpp @@ -11,13 +11,14 @@ using namespace pybind11::literals; // to bring in the `_a` literal #include // shapeworks -#include "DeepSSMJob.h" -#include #include #include #include #include #include +#include + +#include "DeepSSMJob.h" namespace shapeworks { @@ -274,8 +275,8 @@ void DeepSSMJob::run_training() { py::module py_deep_ssm_utils = py::module::import("DeepSSMUtils"); py::object prepare_data_loaders = py_deep_ssm_utils.attr("prepare_data_loaders"); - prepare_data_loaders(project_, batch_size, "train"); - prepare_data_loaders(project_, batch_size, "val"); + prepare_data_loaders(project_, batch_size, "train", num_dataloader_workers_); + prepare_data_loaders(project_, batch_size, "val", num_dataloader_workers_); std::string out_dir = "deepssm/"; std::string aug_dir = out_dir + "augmentation/"; @@ -387,6 +388,12 @@ std::vector DeepSSMJob::get_split(ProjectHandle project, SplitType split_ty return list; } +//--------------------------------------------------------------------------- +void DeepSSMJob::set_num_dataloader_workers(int num_workers) { num_dataloader_workers_ = num_workers; } + +//--------------------------------------------------------------------------- +int DeepSSMJob::get_num_dataloader_workers() { return num_dataloader_workers_; } + //--------------------------------------------------------------------------- void DeepSSMJob::update_prep_stage(PrepStep step) { /* diff --git a/Libs/Application/DeepSSM/DeepSSMJob.h b/Libs/Application/DeepSSM/DeepSSMJob.h index d7bfe2025f..b24ba753ec 100644 --- a/Libs/Application/DeepSSM/DeepSSMJob.h +++ b/Libs/Application/DeepSSM/DeepSSMJob.h @@ -52,6 +52,9 @@ class DeepSSMJob : public Job { static std::vector get_split(ProjectHandle project, DeepSSMJob::SplitType split_type); + void set_num_dataloader_workers(int num_workers); + int get_num_dataloader_workers(); + void set_prep_step(DeepSSMJob::PrepStep step) { std::lock_guard lock(mutex_); prep_step_ = step; @@ -68,6 +71,8 @@ class DeepSSMJob : public Job { QString prep_message_; DeepSSMJob::PrepStep prep_step_{DeepSSMJob::NOT_STARTED}; + int num_dataloader_workers_{0}; + // mutex std::mutex mutex_; }; diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/__init__.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/__init__.py index 306904bf42..135f4f0a05 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/__init__.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/__init__.py @@ -16,22 +16,22 @@ import torch -def getTrainValLoaders(loader_dir, aug_data_csv, batch_size=1, down_factor=1, down_dir=None, train_split=0.80): +def getTrainValLoaders(loader_dir, aug_data_csv, batch_size=1, down_factor=1, down_dir=None, train_split=0.80, num_workers=0): testPytorch() - loaders.get_train_val_loaders(loader_dir, aug_data_csv, batch_size, down_factor, down_dir, train_split) + loaders.get_train_val_loaders(loader_dir, aug_data_csv, batch_size, down_factor, down_dir, train_split, num_workers) -def getTrainLoader(loader_dir, aug_data_csv, batch_size=1, down_factor=1, down_dir=None, train_split=0.80): +def getTrainLoader(loader_dir, aug_data_csv, batch_size=1, down_factor=1, down_dir=None, train_split=0.80, num_workers=0): testPytorch() - loaders.get_train_loader(loader_dir, aug_data_csv, batch_size, down_factor, down_dir, train_split) + loaders.get_train_loader(loader_dir, aug_data_csv, batch_size, down_factor, down_dir, train_split, num_workers) -def getValidationLoader(loader_dir, val_img_list, val_particles, down_factor=1, down_dir=None): - loaders.get_validation_loader(loader_dir, val_img_list, val_particles, down_factor, down_dir) +def getValidationLoader(loader_dir, val_img_list, val_particles, down_factor=1, down_dir=None, num_workers=0): + loaders.get_validation_loader(loader_dir, val_img_list, val_particles, down_factor, down_dir, num_workers) -def getTestLoader(loader_dir, test_img_list, down_factor=1, down_dir=None): - loaders.get_test_loader(loader_dir, test_img_list, down_factor, down_dir) +def getTestLoader(loader_dir, test_img_list, down_factor=1, down_dir=None, num_workers=0): + loaders.get_test_loader(loader_dir, test_img_list, down_factor, down_dir, num_workers) def prepareConfigFile(config_filename, model_name, embedded_dim, out_dir, loader_dir, aug_dir, epochs, learning_rate, diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/loaders.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/loaders.py index 0c20fae98f..5573b7e9db 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/loaders.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/loaders.py @@ -23,7 +23,7 @@ def make_dir(dirPath): ''' Reads csv and makes both train and validation data loaders from it ''' -def get_train_val_loaders(loader_dir, data_csv, batch_size=1, down_factor=1, down_dir=None, train_split=0.80): +def get_train_val_loaders(loader_dir, data_csv, batch_size=1, down_factor=1, down_dir=None, train_split=0.80, num_workers=0): sw_message("Creating training and validation torch loaders:") make_dir(loader_dir) images, scores, models, prefixes = get_all_train_data(loader_dir, data_csv, down_factor, down_dir) @@ -41,7 +41,7 @@ def get_train_val_loaders(loader_dir, data_csv, batch_size=1, down_factor=1, dow train_data, batch_size=batch_size, shuffle=True, - num_workers=8, + num_workers=num_workers, pin_memory=torch.cuda.is_available() ) train_path = loader_dir + 'train' @@ -51,7 +51,7 @@ def get_train_val_loaders(loader_dir, data_csv, batch_size=1, down_factor=1, dow val_data, batch_size=1, shuffle=True, - num_workers=8, + num_workers=num_workers, pin_memory=torch.cuda.is_available() ) val_path = loader_dir + 'validation' @@ -62,7 +62,7 @@ def get_train_val_loaders(loader_dir, data_csv, batch_size=1, down_factor=1, dow ''' Reads csv and makes just train data loaders ''' -def get_train_loader(loader_dir, data_csv, batch_size=1, down_factor=1, down_dir=None, train_split=0.80): +def get_train_loader(loader_dir, data_csv, batch_size=1, down_factor=1, down_dir=None, train_split=0.80, num_workers=0): sw_message("Creating training torch loader...") # Get data make_dir(loader_dir) @@ -74,7 +74,7 @@ def get_train_loader(loader_dir, data_csv, batch_size=1, down_factor=1, down_dir train_data, batch_size=batch_size, shuffle=True, - num_workers=8, + num_workers=num_workers, pin_memory=torch.cuda.is_available() ) train_path = loader_dir + 'train' @@ -85,7 +85,7 @@ def get_train_loader(loader_dir, data_csv, batch_size=1, down_factor=1, down_dir ''' Makes validation data loader ''' -def get_validation_loader(loader_dir, val_img_list, val_particles, down_factor=1, down_dir=None): +def get_validation_loader(loader_dir, val_img_list, val_particles, down_factor=1, down_dir=None, num_workers=0): sw_message("Creating validation torch loader:") # Get data image_paths = [] @@ -113,7 +113,7 @@ def get_validation_loader(loader_dir, val_img_list, val_particles, down_factor=1 val_data, batch_size=1, shuffle=False, - num_workers=8, + num_workers=num_workers, pin_memory=torch.cuda.is_available() ) val_path = loader_dir + 'validation' @@ -124,7 +124,7 @@ def get_validation_loader(loader_dir, val_img_list, val_particles, down_factor=1 ''' Makes test data loader ''' -def get_test_loader(loader_dir, test_img_list, down_factor=1, down_dir=None): +def get_test_loader(loader_dir, test_img_list, down_factor=1, down_dir=None, num_workers=0): sw_message("Creating test torch loader...") # get data image_paths = [] @@ -152,7 +152,7 @@ def get_test_loader(loader_dir, test_img_list, down_factor=1, down_dir=None): test_data, batch_size=1, shuffle=False, - num_workers=8, + num_workers=num_workers, pin_memory=torch.cuda.is_available() ) test_path = loader_dir + 'test' diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py index 18a48e8a4a..7de1bd1e2c 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py @@ -443,7 +443,7 @@ def groom_val_test_images(project, indices): project.set_subjects(subjects) -def prepare_data_loaders(project, batch_size, split="all"): +def prepare_data_loaders(project, batch_size, split="all", num_workers=0): """ Prepare PyTorch laoders """ deepssm_dir = get_deepssm_dir(project) loader_dir = deepssm_dir + 'torch_loaders/' @@ -458,19 +458,19 @@ def prepare_data_loaders(project, batch_size, split="all"): val_image_files.append(deepssm_dir + f"/val_and_test_images/{i}.nrrd") particle_file = project.get_subjects()[i].get_world_particle_filenames()[0] val_world_particles.append(particle_file) - DeepSSMUtils.getValidationLoader(loader_dir, val_image_files, val_world_particles) + DeepSSMUtils.getValidationLoader(loader_dir, val_image_files, val_world_particles, num_workers=num_workers) if split == "all" or split == "train": aug_dir = deepssm_dir + "augmentation/" aug_data_csv = aug_dir + "TotalData.csv" - DeepSSMUtils.getTrainLoader(loader_dir, aug_data_csv, batch_size) + DeepSSMUtils.getTrainLoader(loader_dir, aug_data_csv, batch_size, num_workers=num_workers) if split == "all" or split == "test": test_image_files = [] test_indices = get_split_indices(project, "test") for i in test_indices: test_image_files.append(deepssm_dir + f"/val_and_test_images/{i}.nrrd") - DeepSSMUtils.getTestLoader(loader_dir, test_image_files) + DeepSSMUtils.getTestLoader(loader_dir, test_image_files, num_workers=num_workers) def get_test_alignment_transform(project, index): diff --git a/Studio/Data/Preferences.cpp b/Studio/Data/Preferences.cpp index 64a2481dad..8f1da93a96 100644 --- a/Studio/Data/Preferences.cpp +++ b/Studio/Data/Preferences.cpp @@ -1,5 +1,6 @@ #include #include +#include // std #include @@ -73,6 +74,21 @@ void Preferences::set_num_threads(int num_threads) { update_threads(); } +//----------------------------------------------------------------------------- +int Preferences::get_dataloader_num_workers() { + int default_workers = 0; + // if mac or linux, set to 4 + if (shapeworks::PlatformUtils::is_macos() || shapeworks::PlatformUtils::is_linux()) { + default_workers = 4; + } + return settings_.value("Studio/dataloader_num_workers", default_workers).toInt(); +} + +//----------------------------------------------------------------------------- +void Preferences::set_dataloader_num_workers(int num_workers) { + settings_.setValue("Studio/dataloader_num_workers", num_workers); +} + //----------------------------------------------------------------------------- float Preferences::get_glyph_size() { return settings_.value("Project/glyph_size", 5.0).toFloat(); } diff --git a/Studio/Data/Preferences.h b/Studio/Data/Preferences.h index 573d60a8ce..ed60fb17da 100644 --- a/Studio/Data/Preferences.h +++ b/Studio/Data/Preferences.h @@ -47,6 +47,9 @@ class Preferences : public QObject { int get_num_threads(); void set_num_threads(int num_threads); + int get_dataloader_num_workers(); + void set_dataloader_num_workers(int num_workers); + float get_glyph_size(); void set_glyph_size(float value); diff --git a/Studio/Data/PreferencesWindow.cpp b/Studio/Data/PreferencesWindow.cpp index 712c6b033b..10ce75cf18 100644 --- a/Studio/Data/PreferencesWindow.cpp +++ b/Studio/Data/PreferencesWindow.cpp @@ -116,6 +116,7 @@ void PreferencesWindow::set_values_from_preferences() { ui_->geodesic_cache_multiplier->setValue(preferences_.get_geodesic_cache_multiplier()); ui_->auto_update_checkbox->setChecked(preferences_.get_auto_update_check()); ui_->telemetry_enabled->setChecked(preferences_.get_telemetry_enabled()); + ui_->data_loader_num_workers->setText(QString::number(preferences_.get_dataloader_num_workers())); update_labels(); } @@ -147,6 +148,7 @@ void PreferencesWindow::save_to_preferences() { preferences_.set_reverse_color_map(ui_->reverse_color_map->isChecked()); preferences_.set_auto_update_check(ui_->auto_update_checkbox->isChecked()); preferences_.set_telemetry_enabled(ui_->telemetry_enabled->isChecked()); + preferences_.set_dataloader_num_workers(ui_->data_loader_num_workers->text().toInt()); update_labels(); Q_EMIT update_view(); } diff --git a/Studio/Data/PreferencesWindow.ui b/Studio/Data/PreferencesWindow.ui index 1656c68b10..9c79973764 100644 --- a/Studio/Data/PreferencesWindow.ui +++ b/Studio/Data/PreferencesWindow.ui @@ -7,7 +7,7 @@ 0 0 756 - 668 + 711 @@ -364,7 +364,24 @@ Parallel Processing - + + + + DeepSSM data loader workers + + + + + + + 0 + + + Qt::AlignCenter + + + + diff --git a/Studio/DeepSSM/DeepSSMTool.cpp b/Studio/DeepSSM/DeepSSMTool.cpp index 4a1ba2abe0..39ff1e6fa3 100644 --- a/Studio/DeepSSM/DeepSSMTool.cpp +++ b/Studio/DeepSSM/DeepSSMTool.cpp @@ -918,6 +918,7 @@ void DeepSSMTool::run_tool(DeepSSMJob::JobType job_type) { store_params(); deep_ssm_ = QSharedPointer::create(session_->get_project(), job_type, prep_step_); + deep_ssm_->set_num_dataloader_workers(preferences_.get_dataloader_num_workers()); connect(deep_ssm_.data(), &DeepSSMJob::progress, this, &DeepSSMTool::handle_progress); connect(deep_ssm_.data(), &DeepSSMJob::finished, this, &DeepSSMTool::handle_thread_complete);