Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions Applications/shapeworks/Commands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,13 @@ void DeepSSMCommand::buildParser() {

parser.add_option("--all").action("store_true").help("Run all steps");

// add num_workers option
parser.add_option("--num_workers")
.action("store")
.type("int")
.set_default(0)
.help("Number of data loader workers (default: 0)");

Command::buildParser();
}

Expand Down Expand Up @@ -405,10 +412,13 @@ bool DeepSSMCommand::execute(const optparse::Values& options, SharedCommandData&
bool do_train = options.is_set("train") || options.is_set("all");
bool do_test = options.is_set("test") || options.is_set("all");

int num_workers = static_cast<int>(options.get("num_workers"));

std::cout << "Prep step: " << (do_prep ? "on" : "off") << "\n";
std::cout << "Augment step: " << (do_augment ? "on" : "off") << "\n";
std::cout << "Train step: " << (do_train ? "on" : "off") << "\n";
std::cout << "Test step: " << (do_test ? "on" : "off") << "\n";
std::cout << "Num dataloader workers: " << num_workers << "\n";

if (!do_prep && !do_augment && !do_train && !do_test) {
do_prep = true;
Expand Down Expand Up @@ -437,6 +447,7 @@ bool DeepSSMCommand::execute(const optparse::Values& options, SharedCommandData&

if (do_prep) {
auto job = QSharedPointer<DeepSSMJob>::create(project, DeepSSMJob::JobType::DeepSSM_PrepType);
job->set_num_dataloader_workers(num_workers);
if (prep_step == "all") {
job->set_prep_step(DeepSSMJob::PrepStep::NOT_STARTED);
} else if (prep_step == "groom_training") {
Expand Down
15 changes: 11 additions & 4 deletions Libs/Application/DeepSSM/DeepSSMJob.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@ using namespace pybind11::literals; // to bring in the `_a` literal
#include <QThread>

// shapeworks
#include "DeepSSMJob.h"
#include <Project/DeepSSMParameters.h>
#include <Groom.h>
#include <Logging.h>
#include <Mesh/MeshUtils.h>
#include <Optimize.h>
#include <Optimize/OptimizeParameters.h>
#include <Project/DeepSSMParameters.h>

#include "DeepSSMJob.h"

namespace shapeworks {

Expand Down Expand Up @@ -274,8 +275,8 @@ void DeepSSMJob::run_training() {
py::module py_deep_ssm_utils = py::module::import("DeepSSMUtils");

py::object prepare_data_loaders = py_deep_ssm_utils.attr("prepare_data_loaders");
prepare_data_loaders(project_, batch_size, "train");
prepare_data_loaders(project_, batch_size, "val");
prepare_data_loaders(project_, batch_size, "train", num_dataloader_workers_);
prepare_data_loaders(project_, batch_size, "val", num_dataloader_workers_);

std::string out_dir = "deepssm/";
std::string aug_dir = out_dir + "augmentation/";
Expand Down Expand Up @@ -387,6 +388,12 @@ std::vector<int> DeepSSMJob::get_split(ProjectHandle project, SplitType split_ty
return list;
}

//---------------------------------------------------------------------------
void DeepSSMJob::set_num_dataloader_workers(int num_workers) { num_dataloader_workers_ = num_workers; }

//---------------------------------------------------------------------------
int DeepSSMJob::get_num_dataloader_workers() { return num_dataloader_workers_; }

//---------------------------------------------------------------------------
void DeepSSMJob::update_prep_stage(PrepStep step) {
/*
Expand Down
5 changes: 5 additions & 0 deletions Libs/Application/DeepSSM/DeepSSMJob.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ class DeepSSMJob : public Job {

static std::vector<int> get_split(ProjectHandle project, DeepSSMJob::SplitType split_type);

void set_num_dataloader_workers(int num_workers);
int get_num_dataloader_workers();

void set_prep_step(DeepSSMJob::PrepStep step) {
std::lock_guard<std::mutex> lock(mutex_);
prep_step_ = step;
Expand All @@ -68,6 +71,8 @@ class DeepSSMJob : public Job {
QString prep_message_;
DeepSSMJob::PrepStep prep_step_{DeepSSMJob::NOT_STARTED};

int num_dataloader_workers_{0};

// mutex
std::mutex mutex_;
};
Expand Down
16 changes: 8 additions & 8 deletions Python/DeepSSMUtilsPackage/DeepSSMUtils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,22 @@
import torch


def getTrainValLoaders(loader_dir, aug_data_csv, batch_size=1, down_factor=1, down_dir=None, train_split=0.80):
def getTrainValLoaders(loader_dir, aug_data_csv, batch_size=1, down_factor=1, down_dir=None, train_split=0.80, num_workers=0):
testPytorch()
loaders.get_train_val_loaders(loader_dir, aug_data_csv, batch_size, down_factor, down_dir, train_split)
loaders.get_train_val_loaders(loader_dir, aug_data_csv, batch_size, down_factor, down_dir, train_split, num_workers)


def getTrainLoader(loader_dir, aug_data_csv, batch_size=1, down_factor=1, down_dir=None, train_split=0.80):
def getTrainLoader(loader_dir, aug_data_csv, batch_size=1, down_factor=1, down_dir=None, train_split=0.80, num_workers=0):
testPytorch()
loaders.get_train_loader(loader_dir, aug_data_csv, batch_size, down_factor, down_dir, train_split)
loaders.get_train_loader(loader_dir, aug_data_csv, batch_size, down_factor, down_dir, train_split, num_workers)


def getValidationLoader(loader_dir, val_img_list, val_particles, down_factor=1, down_dir=None):
loaders.get_validation_loader(loader_dir, val_img_list, val_particles, down_factor, down_dir)
def getValidationLoader(loader_dir, val_img_list, val_particles, down_factor=1, down_dir=None, num_workers=0):
loaders.get_validation_loader(loader_dir, val_img_list, val_particles, down_factor, down_dir, num_workers)


def getTestLoader(loader_dir, test_img_list, down_factor=1, down_dir=None):
loaders.get_test_loader(loader_dir, test_img_list, down_factor, down_dir)
def getTestLoader(loader_dir, test_img_list, down_factor=1, down_dir=None, num_workers=0):
loaders.get_test_loader(loader_dir, test_img_list, down_factor, down_dir, num_workers)


def prepareConfigFile(config_filename, model_name, embedded_dim, out_dir, loader_dir, aug_dir, epochs, learning_rate,
Expand Down
18 changes: 9 additions & 9 deletions Python/DeepSSMUtilsPackage/DeepSSMUtils/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def make_dir(dirPath):
'''
Reads csv and makes both train and validation data loaders from it
'''
def get_train_val_loaders(loader_dir, data_csv, batch_size=1, down_factor=1, down_dir=None, train_split=0.80):
def get_train_val_loaders(loader_dir, data_csv, batch_size=1, down_factor=1, down_dir=None, train_split=0.80, num_workers=0):
sw_message("Creating training and validation torch loaders:")
make_dir(loader_dir)
images, scores, models, prefixes = get_all_train_data(loader_dir, data_csv, down_factor, down_dir)
Expand All @@ -41,7 +41,7 @@ def get_train_val_loaders(loader_dir, data_csv, batch_size=1, down_factor=1, dow
train_data,
batch_size=batch_size,
shuffle=True,
num_workers=8,
num_workers=num_workers,
pin_memory=torch.cuda.is_available()
)
train_path = loader_dir + 'train'
Expand All @@ -51,7 +51,7 @@ def get_train_val_loaders(loader_dir, data_csv, batch_size=1, down_factor=1, dow
val_data,
batch_size=1,
shuffle=True,
num_workers=8,
num_workers=num_workers,
pin_memory=torch.cuda.is_available()
)
val_path = loader_dir + 'validation'
Expand All @@ -62,7 +62,7 @@ def get_train_val_loaders(loader_dir, data_csv, batch_size=1, down_factor=1, dow
'''
Reads csv and makes just train data loaders
'''
def get_train_loader(loader_dir, data_csv, batch_size=1, down_factor=1, down_dir=None, train_split=0.80):
def get_train_loader(loader_dir, data_csv, batch_size=1, down_factor=1, down_dir=None, train_split=0.80, num_workers=0):
sw_message("Creating training torch loader...")
# Get data
make_dir(loader_dir)
Expand All @@ -74,7 +74,7 @@ def get_train_loader(loader_dir, data_csv, batch_size=1, down_factor=1, down_dir
train_data,
batch_size=batch_size,
shuffle=True,
num_workers=8,
num_workers=num_workers,
pin_memory=torch.cuda.is_available()
)
train_path = loader_dir + 'train'
Expand All @@ -85,7 +85,7 @@ def get_train_loader(loader_dir, data_csv, batch_size=1, down_factor=1, down_dir
'''
Makes validation data loader
'''
def get_validation_loader(loader_dir, val_img_list, val_particles, down_factor=1, down_dir=None):
def get_validation_loader(loader_dir, val_img_list, val_particles, down_factor=1, down_dir=None, num_workers=0):
sw_message("Creating validation torch loader:")
# Get data
image_paths = []
Expand Down Expand Up @@ -113,7 +113,7 @@ def get_validation_loader(loader_dir, val_img_list, val_particles, down_factor=1
val_data,
batch_size=1,
shuffle=False,
num_workers=8,
num_workers=num_workers,
pin_memory=torch.cuda.is_available()
)
val_path = loader_dir + 'validation'
Expand All @@ -124,7 +124,7 @@ def get_validation_loader(loader_dir, val_img_list, val_particles, down_factor=1
'''
Makes test data loader
'''
def get_test_loader(loader_dir, test_img_list, down_factor=1, down_dir=None):
def get_test_loader(loader_dir, test_img_list, down_factor=1, down_dir=None, num_workers=0):
sw_message("Creating test torch loader...")
# get data
image_paths = []
Expand Down Expand Up @@ -152,7 +152,7 @@ def get_test_loader(loader_dir, test_img_list, down_factor=1, down_dir=None):
test_data,
batch_size=1,
shuffle=False,
num_workers=8,
num_workers=num_workers,
pin_memory=torch.cuda.is_available()
)
test_path = loader_dir + 'test'
Expand Down
8 changes: 4 additions & 4 deletions Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ def groom_val_test_images(project, indices):
project.set_subjects(subjects)


def prepare_data_loaders(project, batch_size, split="all"):
def prepare_data_loaders(project, batch_size, split="all", num_workers=0):
""" Prepare PyTorch laoders """
deepssm_dir = get_deepssm_dir(project)
loader_dir = deepssm_dir + 'torch_loaders/'
Expand All @@ -458,19 +458,19 @@ def prepare_data_loaders(project, batch_size, split="all"):
val_image_files.append(deepssm_dir + f"/val_and_test_images/{i}.nrrd")
particle_file = project.get_subjects()[i].get_world_particle_filenames()[0]
val_world_particles.append(particle_file)
DeepSSMUtils.getValidationLoader(loader_dir, val_image_files, val_world_particles)
DeepSSMUtils.getValidationLoader(loader_dir, val_image_files, val_world_particles, num_workers=num_workers)

if split == "all" or split == "train":
aug_dir = deepssm_dir + "augmentation/"
aug_data_csv = aug_dir + "TotalData.csv"
DeepSSMUtils.getTrainLoader(loader_dir, aug_data_csv, batch_size)
DeepSSMUtils.getTrainLoader(loader_dir, aug_data_csv, batch_size, num_workers=num_workers)

if split == "all" or split == "test":
test_image_files = []
test_indices = get_split_indices(project, "test")
for i in test_indices:
test_image_files.append(deepssm_dir + f"/val_and_test_images/{i}.nrrd")
DeepSSMUtils.getTestLoader(loader_dir, test_image_files)
DeepSSMUtils.getTestLoader(loader_dir, test_image_files, num_workers=num_workers)


def get_test_alignment_transform(project, index):
Expand Down
16 changes: 16 additions & 0 deletions Studio/Data/Preferences.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <Data/Preferences.h>
#include <Logging.h>
#include <Utils/PlatformUtils.h>

// std
#include <iostream>
Expand Down Expand Up @@ -73,6 +74,21 @@ void Preferences::set_num_threads(int num_threads) {
update_threads();
}

//-----------------------------------------------------------------------------
int Preferences::get_dataloader_num_workers() {
int default_workers = 0;
// if mac or linux, set to 4
if (shapeworks::PlatformUtils::is_macos() || shapeworks::PlatformUtils::is_linux()) {
default_workers = 4;
}
return settings_.value("Studio/dataloader_num_workers", default_workers).toInt();
}

//-----------------------------------------------------------------------------
void Preferences::set_dataloader_num_workers(int num_workers) {
settings_.setValue("Studio/dataloader_num_workers", num_workers);
}

//-----------------------------------------------------------------------------
float Preferences::get_glyph_size() { return settings_.value("Project/glyph_size", 5.0).toFloat(); }

Expand Down
3 changes: 3 additions & 0 deletions Studio/Data/Preferences.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ class Preferences : public QObject {
int get_num_threads();
void set_num_threads(int num_threads);

int get_dataloader_num_workers();
void set_dataloader_num_workers(int num_workers);

float get_glyph_size();
void set_glyph_size(float value);

Expand Down
2 changes: 2 additions & 0 deletions Studio/Data/PreferencesWindow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ void PreferencesWindow::set_values_from_preferences() {
ui_->geodesic_cache_multiplier->setValue(preferences_.get_geodesic_cache_multiplier());
ui_->auto_update_checkbox->setChecked(preferences_.get_auto_update_check());
ui_->telemetry_enabled->setChecked(preferences_.get_telemetry_enabled());
ui_->data_loader_num_workers->setText(QString::number(preferences_.get_dataloader_num_workers()));
update_labels();
}

Expand Down Expand Up @@ -147,6 +148,7 @@ void PreferencesWindow::save_to_preferences() {
preferences_.set_reverse_color_map(ui_->reverse_color_map->isChecked());
preferences_.set_auto_update_check(ui_->auto_update_checkbox->isChecked());
preferences_.set_telemetry_enabled(ui_->telemetry_enabled->isChecked());
preferences_.set_dataloader_num_workers(ui_->data_loader_num_workers->text().toInt());
update_labels();
Q_EMIT update_view();
}
Expand Down
21 changes: 19 additions & 2 deletions Studio/Data/PreferencesWindow.ui
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
<x>0</x>
<y>0</y>
<width>756</width>
<height>668</height>
<height>711</height>
</rect>
</property>
<property name="windowTitle">
Expand Down Expand Up @@ -364,7 +364,24 @@
<string>Parallel Processing</string>
</property>
<layout class="QGridLayout" name="gridLayout_10">
<item row="0" column="0">
<item row="2" column="0">
<widget class="QLabel" name="label_9">
<property name="text">
<string>DeepSSM data loader workers</string>
</property>
</widget>
</item>
<item row="2" column="1">
<widget class="QLineEdit" name="data_loader_num_workers">
<property name="text">
<string>0</string>
</property>
<property name="alignment">
<set>Qt::AlignCenter</set>
</property>
</widget>
</item>
<item row="0" column="0" colspan="2">
<layout class="QGridLayout" name="gridLayout_9">
<item row="2" column="0">
<widget class="QLabel" name="label_19">
Expand Down
1 change: 1 addition & 0 deletions Studio/DeepSSM/DeepSSMTool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -918,6 +918,7 @@ void DeepSSMTool::run_tool(DeepSSMJob::JobType job_type) {
store_params();

deep_ssm_ = QSharedPointer<DeepSSMJob>::create(session_->get_project(), job_type, prep_step_);
deep_ssm_->set_num_dataloader_workers(preferences_.get_dataloader_num_workers());
connect(deep_ssm_.data(), &DeepSSMJob::progress, this, &DeepSSMTool::handle_progress);
connect(deep_ssm_.data(), &DeepSSMJob::finished, this, &DeepSSMTool::handle_thread_complete);

Expand Down
Loading