Skip to content

Commit c068604

Browse files
committed
Address #2438 by adding num_workers options.
This is a Studio preference so that it's not tied to the project. For the CLI: --num_workers=INT Number of data loader workers (default: 0). In Studio, the default on Windows is 0, since data loader workers seem to cause a lot of problems on Windows in general; the default is 4 on other platforms.
1 parent be87c23 commit c068604

File tree

11 files changed

+89
-27
lines changed

11 files changed

+89
-27
lines changed

Applications/shapeworks/Commands.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,13 @@ void DeepSSMCommand::buildParser() {
364364

365365
parser.add_option("--all").action("store_true").help("Run all steps");
366366

367+
// add num_workers option
368+
parser.add_option("--num_workers")
369+
.action("store")
370+
.type("int")
371+
.set_default(0)
372+
.help("Number of data loader workers (default: 0)");
373+
367374
Command::buildParser();
368375
}
369376

@@ -405,10 +412,13 @@ bool DeepSSMCommand::execute(const optparse::Values& options, SharedCommandData&
405412
bool do_train = options.is_set("train") || options.is_set("all");
406413
bool do_test = options.is_set("test") || options.is_set("all");
407414

415+
int num_workers = static_cast<int>(options.get("num_workers"));
416+
408417
std::cout << "Prep step: " << (do_prep ? "on" : "off") << "\n";
409418
std::cout << "Augment step: " << (do_augment ? "on" : "off") << "\n";
410419
std::cout << "Train step: " << (do_train ? "on" : "off") << "\n";
411420
std::cout << "Test step: " << (do_test ? "on" : "off") << "\n";
421+
std::cout << "Num dataloader workers: " << num_workers << "\n";
412422

413423
if (!do_prep && !do_augment && !do_train && !do_test) {
414424
do_prep = true;
@@ -437,6 +447,7 @@ bool DeepSSMCommand::execute(const optparse::Values& options, SharedCommandData&
437447

438448
if (do_prep) {
439449
auto job = QSharedPointer<DeepSSMJob>::create(project, DeepSSMJob::JobType::DeepSSM_PrepType);
450+
job->set_num_dataloader_workers(num_workers);
440451
if (prep_step == "all") {
441452
job->set_prep_step(DeepSSMJob::PrepStep::NOT_STARTED);
442453
} else if (prep_step == "groom_training") {

Libs/Application/DeepSSM/DeepSSMJob.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,14 @@ using namespace pybind11::literals; // to bring in the `_a` literal
1111
#include <QThread>
1212

1313
// shapeworks
14-
#include "DeepSSMJob.h"
15-
#include <Project/DeepSSMParameters.h>
1614
#include <Groom.h>
1715
#include <Logging.h>
1816
#include <Mesh/MeshUtils.h>
1917
#include <Optimize.h>
2018
#include <Optimize/OptimizeParameters.h>
19+
#include <Project/DeepSSMParameters.h>
20+
21+
#include "DeepSSMJob.h"
2122

2223
namespace shapeworks {
2324

@@ -274,8 +275,8 @@ void DeepSSMJob::run_training() {
274275
py::module py_deep_ssm_utils = py::module::import("DeepSSMUtils");
275276

276277
py::object prepare_data_loaders = py_deep_ssm_utils.attr("prepare_data_loaders");
277-
prepare_data_loaders(project_, batch_size, "train");
278-
prepare_data_loaders(project_, batch_size, "val");
278+
prepare_data_loaders(project_, batch_size, "train", num_dataloader_workers_);
279+
prepare_data_loaders(project_, batch_size, "val", num_dataloader_workers_);
279280

280281
std::string out_dir = "deepssm/";
281282
std::string aug_dir = out_dir + "augmentation/";
@@ -387,6 +388,12 @@ std::vector<int> DeepSSMJob::get_split(ProjectHandle project, SplitType split_ty
387388
return list;
388389
}
389390

391+
//---------------------------------------------------------------------------
392+
// Stores the data loader worker count; run_training() passes this value to the Python prepare_data_loaders call.
void DeepSSMJob::set_num_dataloader_workers(int num_workers) { num_dataloader_workers_ = num_workers; }
393+
394+
//---------------------------------------------------------------------------
395+
// Returns the data loader worker count last stored via set_num_dataloader_workers().
int DeepSSMJob::get_num_dataloader_workers() { return num_dataloader_workers_; }
396+
390397
//---------------------------------------------------------------------------
391398
void DeepSSMJob::update_prep_stage(PrepStep step) {
392399
/*

Libs/Application/DeepSSM/DeepSSMJob.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ class DeepSSMJob : public Job {
5252

5353
static std::vector<int> get_split(ProjectHandle project, DeepSSMJob::SplitType split_type);
5454

55+
void set_num_dataloader_workers(int num_workers);
56+
int get_num_dataloader_workers();
57+
5558
void set_prep_step(DeepSSMJob::PrepStep step) {
5659
std::lock_guard<std::mutex> lock(mutex_);
5760
prep_step_ = step;
@@ -68,6 +71,8 @@ class DeepSSMJob : public Job {
6871
QString prep_message_;
6972
DeepSSMJob::PrepStep prep_step_{DeepSSMJob::NOT_STARTED};
7073

74+
int num_dataloader_workers_{0};
75+
7176
// mutex
7277
std::mutex mutex_;
7378
};

Python/DeepSSMUtilsPackage/DeepSSMUtils/__init__.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,22 @@
1616
import torch
1717

1818

19-
def getTrainValLoaders(loader_dir, aug_data_csv, batch_size=1, down_factor=1, down_dir=None, train_split=0.80):
19+
def getTrainValLoaders(loader_dir, aug_data_csv, batch_size=1, down_factor=1, down_dir=None, train_split=0.80, num_workers=0):
2020
testPytorch()
21-
loaders.get_train_val_loaders(loader_dir, aug_data_csv, batch_size, down_factor, down_dir, train_split)
21+
loaders.get_train_val_loaders(loader_dir, aug_data_csv, batch_size, down_factor, down_dir, train_split, num_workers)
2222

2323

24-
def getTrainLoader(loader_dir, aug_data_csv, batch_size=1, down_factor=1, down_dir=None, train_split=0.80):
24+
def getTrainLoader(loader_dir, aug_data_csv, batch_size=1, down_factor=1, down_dir=None, train_split=0.80, num_workers=0):
2525
testPytorch()
26-
loaders.get_train_loader(loader_dir, aug_data_csv, batch_size, down_factor, down_dir, train_split)
26+
loaders.get_train_loader(loader_dir, aug_data_csv, batch_size, down_factor, down_dir, train_split, num_workers)
2727

2828

29-
def getValidationLoader(loader_dir, val_img_list, val_particles, down_factor=1, down_dir=None):
30-
loaders.get_validation_loader(loader_dir, val_img_list, val_particles, down_factor, down_dir)
29+
def getValidationLoader(loader_dir, val_img_list, val_particles, down_factor=1, down_dir=None, num_workers=0):
30+
loaders.get_validation_loader(loader_dir, val_img_list, val_particles, down_factor, down_dir, num_workers)
3131

3232

33-
def getTestLoader(loader_dir, test_img_list, down_factor=1, down_dir=None):
34-
loaders.get_test_loader(loader_dir, test_img_list, down_factor, down_dir)
33+
def getTestLoader(loader_dir, test_img_list, down_factor=1, down_dir=None, num_workers=0):
34+
loaders.get_test_loader(loader_dir, test_img_list, down_factor, down_dir, num_workers)
3535

3636

3737
def prepareConfigFile(config_filename, model_name, embedded_dim, out_dir, loader_dir, aug_dir, epochs, learning_rate,

Python/DeepSSMUtilsPackage/DeepSSMUtils/loaders.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def make_dir(dirPath):
2323
'''
2424
Reads csv and makes both train and validation data loaders from it
2525
'''
26-
def get_train_val_loaders(loader_dir, data_csv, batch_size=1, down_factor=1, down_dir=None, train_split=0.80):
26+
def get_train_val_loaders(loader_dir, data_csv, batch_size=1, down_factor=1, down_dir=None, train_split=0.80, num_workers=0):
2727
sw_message("Creating training and validation torch loaders:")
2828
make_dir(loader_dir)
2929
images, scores, models, prefixes = get_all_train_data(loader_dir, data_csv, down_factor, down_dir)
@@ -41,7 +41,7 @@ def get_train_val_loaders(loader_dir, data_csv, batch_size=1, down_factor=1, dow
4141
train_data,
4242
batch_size=batch_size,
4343
shuffle=True,
44-
num_workers=8,
44+
num_workers=num_workers,
4545
pin_memory=torch.cuda.is_available()
4646
)
4747
train_path = loader_dir + 'train'
@@ -51,7 +51,7 @@ def get_train_val_loaders(loader_dir, data_csv, batch_size=1, down_factor=1, dow
5151
val_data,
5252
batch_size=1,
5353
shuffle=True,
54-
num_workers=8,
54+
num_workers=num_workers,
5555
pin_memory=torch.cuda.is_available()
5656
)
5757
val_path = loader_dir + 'validation'
@@ -62,7 +62,7 @@ def get_train_val_loaders(loader_dir, data_csv, batch_size=1, down_factor=1, dow
6262
'''
6363
Reads csv and makes just train data loaders
6464
'''
65-
def get_train_loader(loader_dir, data_csv, batch_size=1, down_factor=1, down_dir=None, train_split=0.80):
65+
def get_train_loader(loader_dir, data_csv, batch_size=1, down_factor=1, down_dir=None, train_split=0.80, num_workers=0):
6666
sw_message("Creating training torch loader...")
6767
# Get data
6868
make_dir(loader_dir)
@@ -74,7 +74,7 @@ def get_train_loader(loader_dir, data_csv, batch_size=1, down_factor=1, down_dir
7474
train_data,
7575
batch_size=batch_size,
7676
shuffle=True,
77-
num_workers=8,
77+
num_workers=num_workers,
7878
pin_memory=torch.cuda.is_available()
7979
)
8080
train_path = loader_dir + 'train'
@@ -85,7 +85,7 @@ def get_train_loader(loader_dir, data_csv, batch_size=1, down_factor=1, down_dir
8585
'''
8686
Makes validation data loader
8787
'''
88-
def get_validation_loader(loader_dir, val_img_list, val_particles, down_factor=1, down_dir=None):
88+
def get_validation_loader(loader_dir, val_img_list, val_particles, down_factor=1, down_dir=None, num_workers=0):
8989
sw_message("Creating validation torch loader:")
9090
# Get data
9191
image_paths = []
@@ -113,7 +113,7 @@ def get_validation_loader(loader_dir, val_img_list, val_particles, down_factor=1
113113
val_data,
114114
batch_size=1,
115115
shuffle=False,
116-
num_workers=8,
116+
num_workers=num_workers,
117117
pin_memory=torch.cuda.is_available()
118118
)
119119
val_path = loader_dir + 'validation'
@@ -124,7 +124,7 @@ def get_validation_loader(loader_dir, val_img_list, val_particles, down_factor=1
124124
'''
125125
Makes test data loader
126126
'''
127-
def get_test_loader(loader_dir, test_img_list, down_factor=1, down_dir=None):
127+
def get_test_loader(loader_dir, test_img_list, down_factor=1, down_dir=None, num_workers=0):
128128
sw_message("Creating test torch loader...")
129129
# get data
130130
image_paths = []
@@ -152,7 +152,7 @@ def get_test_loader(loader_dir, test_img_list, down_factor=1, down_dir=None):
152152
test_data,
153153
batch_size=1,
154154
shuffle=False,
155-
num_workers=8,
155+
num_workers=num_workers,
156156
pin_memory=torch.cuda.is_available()
157157
)
158158
test_path = loader_dir + 'test'

Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,7 @@ def groom_val_test_images(project, indices):
443443
project.set_subjects(subjects)
444444

445445

446-
def prepare_data_loaders(project, batch_size, split="all"):
446+
def prepare_data_loaders(project, batch_size, split="all", num_workers=0):
447447
""" Prepare PyTorch laoders """
448448
deepssm_dir = get_deepssm_dir(project)
449449
loader_dir = deepssm_dir + 'torch_loaders/'
@@ -458,19 +458,19 @@ def prepare_data_loaders(project, batch_size, split="all"):
458458
val_image_files.append(deepssm_dir + f"/val_and_test_images/{i}.nrrd")
459459
particle_file = project.get_subjects()[i].get_world_particle_filenames()[0]
460460
val_world_particles.append(particle_file)
461-
DeepSSMUtils.getValidationLoader(loader_dir, val_image_files, val_world_particles)
461+
DeepSSMUtils.getValidationLoader(loader_dir, val_image_files, val_world_particles, num_workers=num_workers)
462462

463463
if split == "all" or split == "train":
464464
aug_dir = deepssm_dir + "augmentation/"
465465
aug_data_csv = aug_dir + "TotalData.csv"
466-
DeepSSMUtils.getTrainLoader(loader_dir, aug_data_csv, batch_size)
466+
DeepSSMUtils.getTrainLoader(loader_dir, aug_data_csv, batch_size, num_workers=num_workers)
467467

468468
if split == "all" or split == "test":
469469
test_image_files = []
470470
test_indices = get_split_indices(project, "test")
471471
for i in test_indices:
472472
test_image_files.append(deepssm_dir + f"/val_and_test_images/{i}.nrrd")
473-
DeepSSMUtils.getTestLoader(loader_dir, test_image_files)
473+
DeepSSMUtils.getTestLoader(loader_dir, test_image_files, num_workers=num_workers)
474474

475475

476476
def get_test_alignment_transform(project, index):

Studio/Data/Preferences.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include <Data/Preferences.h>
22
#include <Logging.h>
3+
#include <Utils/PlatformUtils.h>
34

45
// std
56
#include <iostream>
@@ -73,6 +74,21 @@ void Preferences::set_num_threads(int num_threads) {
7374
update_threads();
7475
}
7576

77+
//-----------------------------------------------------------------------------
78+
int Preferences::get_dataloader_num_workers() {
79+
int default_workers = 0;
80+
// if mac or linux, set to 4
81+
if (shapeworks::PlatformUtils::is_macos() || shapeworks::PlatformUtils::is_linux()) {
82+
default_workers = 4;
83+
}
84+
return settings_.value("Studio/dataloader_num_workers", default_workers).toInt();
85+
}
86+
87+
//-----------------------------------------------------------------------------
88+
void Preferences::set_dataloader_num_workers(int num_workers) {
89+
settings_.setValue("Studio/dataloader_num_workers", num_workers);
90+
}
91+
7692
//-----------------------------------------------------------------------------
7793
float Preferences::get_glyph_size() { return settings_.value("Project/glyph_size", 5.0).toFloat(); }
7894

Studio/Data/Preferences.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ class Preferences : public QObject {
4747
int get_num_threads();
4848
void set_num_threads(int num_threads);
4949

50+
int get_dataloader_num_workers();
51+
void set_dataloader_num_workers(int num_workers);
52+
5053
float get_glyph_size();
5154
void set_glyph_size(float value);
5255

Studio/Data/PreferencesWindow.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ void PreferencesWindow::set_values_from_preferences() {
116116
ui_->geodesic_cache_multiplier->setValue(preferences_.get_geodesic_cache_multiplier());
117117
ui_->auto_update_checkbox->setChecked(preferences_.get_auto_update_check());
118118
ui_->telemetry_enabled->setChecked(preferences_.get_telemetry_enabled());
119+
ui_->data_loader_num_workers->setText(QString::number(preferences_.get_dataloader_num_workers()));
119120
update_labels();
120121
}
121122

@@ -147,6 +148,7 @@ void PreferencesWindow::save_to_preferences() {
147148
preferences_.set_reverse_color_map(ui_->reverse_color_map->isChecked());
148149
preferences_.set_auto_update_check(ui_->auto_update_checkbox->isChecked());
149150
preferences_.set_telemetry_enabled(ui_->telemetry_enabled->isChecked());
151+
preferences_.set_dataloader_num_workers(ui_->data_loader_num_workers->text().toInt());
150152
update_labels();
151153
Q_EMIT update_view();
152154
}

Studio/Data/PreferencesWindow.ui

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
<x>0</x>
88
<y>0</y>
99
<width>756</width>
10-
<height>668</height>
10+
<height>711</height>
1111
</rect>
1212
</property>
1313
<property name="windowTitle">
@@ -364,7 +364,24 @@
364364
<string>Parallel Processing</string>
365365
</property>
366366
<layout class="QGridLayout" name="gridLayout_10">
367-
<item row="0" column="0">
367+
<item row="2" column="0">
368+
<widget class="QLabel" name="label_9">
369+
<property name="text">
370+
<string>DeepSSM data loader workers</string>
371+
</property>
372+
</widget>
373+
</item>
374+
<item row="2" column="1">
375+
<widget class="QLineEdit" name="data_loader_num_workers">
376+
<property name="text">
377+
<string>0</string>
378+
</property>
379+
<property name="alignment">
380+
<set>Qt::AlignCenter</set>
381+
</property>
382+
</widget>
383+
</item>
384+
<item row="0" column="0" colspan="2">
368385
<layout class="QGridLayout" name="gridLayout_9">
369386
<item row="2" column="0">
370387
<widget class="QLabel" name="label_19">

0 commit comments

Comments
 (0)