Skip to content

Commit 099be51

Browse files
committed
Improve monai label threading and progressbars. Update formatting.
1 parent a7b1a0c commit 099be51

File tree

4 files changed

+102
-165
lines changed

4 files changed

+102
-165
lines changed

Studio/ShapeWorksMONAI/MonaiLabelJob.cpp

Lines changed: 30 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,8 @@ const std::string MonaiLabelJob::MONAI_RESULT_EXTENSION(".nrrd");
2525
const std::string MonaiLabelJob::MONAI_RESULT_DTYPE("uint8");
2626

2727
//---------------------------------------------------------------------------
28-
MonaiLabelJob::MonaiLabelJob(QSharedPointer<Session> session,
29-
const std::string &server_url,
30-
const std::string &client_id,
31-
const std::string &strategy,
32-
const std::string &model_type)
28+
MonaiLabelJob::MonaiLabelJob(QSharedPointer<Session> session, const std::string &server_url,
29+
const std::string &client_id, const std::string &strategy, const std::string &model_type)
3330
: session_(session),
3431
server_url_(server_url),
3532
client_id_(client_id),
@@ -40,8 +37,7 @@ MonaiLabelJob::MonaiLabelJob(QSharedPointer<Session> session,
4037
model_type_(model_type),
4138
monai_client_(nullptr) {
4239
project_ = session_->get_project();
43-
QDir projectDir(
44-
QString::fromStdString(session->get_project()->get_project_path()));
40+
QDir projectDir(QString::fromStdString(session->get_project()->get_project_path()));
4541
QString labels_dir = projectDir.filePath("labels-prediction");
4642

4743
if (MonaiLabelUtils::createDir(labels_dir)) {
@@ -75,17 +71,14 @@ void MonaiLabelJob::setCurrentSampleNumber(int n) { sample_number_ = n; }
7571

7672
//---------------------------------------------------------------------------
7773
void MonaiLabelJob::initializeClient() {
78-
SW_DEBUG(
79-
"Initializing MONAI Client with server: {} tmp dir: {} client_id: {}",
80-
server_url_, tmp_dir_, client_id_);
74+
SW_DEBUG("Initializing MONAI Client with server: {} tmp dir: {} client_id: {}", server_url_, tmp_dir_, client_id_);
8175
try {
8276
py::module monai_label = py::module::import("MONAILabel");
8377
py::object monai_client_class = monai_label.attr("MONAILabelClient");
8478
py::str py_server_url(server_url_);
8579
py::str py_tmp_dir(tmp_dir_);
8680
py::str py_client_id(client_id_);
87-
monai_client_ = std::make_shared<py::object>(
88-
monai_client_class(py_server_url, py_tmp_dir, py_client_id));
81+
monai_client_ = std::make_shared<py::object>(monai_client_class(py_server_url, py_tmp_dir, py_client_id));
8982
if (!monai_client_) {
9083
SW_ERROR("Error in instantiating MONAI client");
9184
return;
@@ -95,8 +88,7 @@ void MonaiLabelJob::initializeClient() {
9588
models_available_[model_type_] = {model_name_};
9689
Q_EMIT triggerClientInitialized();
9790
} catch (std::exception &e) {
98-
std::cerr << "Error importing MONAILabel or initializing MONAILabelClient: "
99-
<< e.what() << std::endl;
91+
std::cerr << "Error importing MONAILabel or initializing MONAILabelClient: " << e.what() << std::endl;
10092
return;
10193
}
10294
}
@@ -120,8 +112,7 @@ py::dict MonaiLabelJob::getInfo() {
120112
return response;
121113
}
122114

123-
std::vector<std::string> MonaiLabelJob::getModelNames(
124-
const std::string &model_type) {
115+
std::vector<std::string> MonaiLabelJob::getModelNames(const std::string &model_type) {
125116
auto it = models_available_.find(model_type);
126117
if (it != models_available_.end()) {
127118
return it->second;
@@ -136,8 +127,7 @@ std::string MonaiLabelJob::getModelName(std::string modelType) {
136127
py::dict info = getInfo();
137128
std::string modelName = "";
138129
if (info.contains(mappedSection)) {
139-
py::dict sectionConfig =
140-
info[mappedSection].cast<py::dict>(); // models dict
130+
py::dict sectionConfig = info[mappedSection].cast<py::dict>(); // models dict
141131
for (const auto &item : sectionConfig) {
142132
std::string nameFound = py::str(item.first);
143133
py::dict modelConfig = item.second.cast<py::dict>();
@@ -168,14 +158,10 @@ std::string MonaiLabelJob::getSessionId() {
168158
}
169159

170160
//---------------------------------------------------------------------------
171-
py::dict MonaiLabelJob::getParamsFromConfig(std::string section,
172-
std::string name) {
161+
py::dict MonaiLabelJob::getParamsFromConfig(std::string section, std::string name) {
173162
py::dict info = getInfo();
174163
std::unordered_map<std::string, std::string> mapping = {
175-
{"infer", "models"},
176-
{"train", "trainers"},
177-
{"activelearning", "strategies"},
178-
{"scoring", "scoring"}};
164+
{"infer", "models"}, {"train", "trainers"}, {"activelearning", "strategies"}, {"scoring", "scoring"}};
179165
auto it = mapping.find(section);
180166
std::string mappedSection = (it != mapping.end()) ? it->second : section;
181167
py::dict result;
@@ -193,7 +179,6 @@ py::dict MonaiLabelJob::getParamsFromConfig(std::string section,
193179
if (!valueList.empty()) {
194180
// result[key] = py::str(valueList[0]);
195181
result[key] = valueList[0];
196-
197182
}
198183
} else {
199184
// result[key] = py::str(value);
@@ -220,21 +205,16 @@ py::dict MonaiLabelJob::nextSample(std::string strategy, py::dict params) {
220205
}
221206

222207
//---------------------------------------------------------------------------
223-
py::dict MonaiLabelJob::uploadImage(std::string image_in,
224-
std::string image_id) {
208+
py::dict MonaiLabelJob::uploadImage(std::string image_in, std::string image_id) {
225209
py::dict response;
226210

227211
try {
228212
if (!monai_client_) {
229213
SW_ERROR("MONAI client not initialized yet");
230214
return response;
231215
}
232-
SW_DEBUG("Uploading sample number {} to MONAI Label server",
233-
sample_number_);
234-
response =
235-
(*monai_client_)
236-
.attr("upload_image")(
237-
image_in, image_id.empty() ? py::none() : py::cast(image_id));
216+
SW_DEBUG("Uploading sample number {} to MONAI Label server", sample_number_);
217+
response = (*monai_client_).attr("upload_image")(image_in, image_id.empty() ? py::none() : py::cast(image_id));
238218
// SW_DEBUG("Upload sample response: " +
239219
// py::repr(response).cast<std::string>());
240220

@@ -250,8 +230,7 @@ py::dict MonaiLabelJob::uploadImage(std::string image_in,
250230
}
251231

252232
//---------------------------------------------------------------------------
253-
py::tuple MonaiLabelJob::infer(std::string model, std::string image_in,
254-
py::dict params, std::string label_in,
233+
py::tuple MonaiLabelJob::infer(std::string model, std::string image_in, py::dict params, std::string label_in,
255234
std::string file, std::string session_id) {
256235
py::tuple result = py::make_tuple();
257236

@@ -263,11 +242,9 @@ py::tuple MonaiLabelJob::infer(std::string model, std::string image_in,
263242
params[py::str("result_extension")] = MonaiLabelJob::MONAI_RESULT_EXTENSION;
264243
params[py::str("result_dtype")] = MonaiLabelJob::MONAI_RESULT_DTYPE;
265244
result = (*monai_client_)
266-
.attr("infer")(
267-
model, image_in, params,
268-
label_in.empty() ? py::none() : py::cast(label_in),
269-
file.empty() ? py::none() : py::cast(file),
270-
session_id.empty() ? py::none() : py::cast(session_id));
245+
.attr("infer")(model, image_in, params, label_in.empty() ? py::none() : py::cast(label_in),
246+
file.empty() ? py::none() : py::cast(file),
247+
session_id.empty() ? py::none() : py::cast(session_id));
271248
// std::cout << "DEBUG | infer call successfully made " << py::repr(result).cast<std::string>() << std::endl;
272249
// SW_DEBUG("Infer response: " + py::repr(result).cast<std::string>());
273250
}
@@ -284,16 +261,14 @@ py::tuple MonaiLabelJob::infer(std::string model, std::string image_in,
284261
}
285262

286263
//---------------------------------------------------------------------------
287-
py::dict MonaiLabelJob::saveLabel(std::string image_in, std::string label_in,
288-
py::dict params) {
264+
py::dict MonaiLabelJob::saveLabel(std::string image_in, std::string label_in, py::dict params) {
289265
py::dict response;
290266
try {
291267
if (!monai_client_) {
292268
SW_ERROR("MONAI client not initialized yet");
293269
return response;
294270
}
295-
response =
296-
(*monai_client_).attr("save_label")(image_in, label_in, "", params);
271+
response = (*monai_client_).attr("save_label")(image_in, label_in, "", params);
297272
// SW_DEBUG("Save Label response: " +
298273
// py::repr(response).cast<std::string>());
299274
}
@@ -320,6 +295,7 @@ void MonaiLabelJob::runSegmentationModel() {
320295
waitingForLabelSubmission = false;
321296
}
322297

298+
//---------------------------------------------------------------------------
323299
void MonaiLabelJob::run() { runSegmentationModel(); }
324300

325301
//---------------------------------------------------------------------------
@@ -333,6 +309,7 @@ void MonaiLabelJob::python_message(std::string str) { SW_LOG(str); }
333309
//---------------------------------------------------------------------------
334310
void MonaiLabelJob::onUploadSampleClicked() {
335311
if (waitingForUpload) {
312+
progress(-1);
336313
py::dict params = getParamsFromConfig("activelearning", strategy_);
337314
auto subjects = session_->get_project()->get_subjects();
338315
auto shapes = session_->get_shapes();
@@ -349,11 +326,9 @@ void MonaiLabelJob::onUploadSampleClicked() {
349326
QString subjectName = imageFileInfo.completeBaseName();
350327
QString extension = imageFileInfo.completeSuffix();
351328

352-
py::dict response = uploadImage(absoluteImagePath.toStdString(),
353-
subjectName.toStdString());
329+
py::dict response = uploadImage(absoluteImagePath.toStdString(), subjectName.toStdString());
354330

355-
if (response[py::str("image")].cast<std::string>() ==
356-
subjectName.toStdString()) {
331+
if (response[py::str("image")].cast<std::string>() == subjectName.toStdString()) {
357332
currentSampleId_ = subjectName.toStdString();
358333
} else {
359334
SW_ERROR("Upload source volume failed!");
@@ -375,15 +350,15 @@ void MonaiLabelJob::onUploadSampleClicked() {
375350
//---------------------------------------------------------------------------
376351
void MonaiLabelJob::onRunSegmentationClicked() {
377352
if (waitingForSegmentation) {
353+
progress(-1);
378354
if (currentSampleId_.empty()) {
379355
SW_ERROR("Sample not uploaded yet!");
380356
return;
381357
}
382358
SW_LOG("⚙️ Processing inference on the current subject");
383359
py::dict params = getParamsFromConfig("infer", model_name_);
384360

385-
py::tuple result =
386-
infer(model_name_, currentSampleId_, params, "", "", getSessionId());
361+
py::tuple result = infer(model_name_, currentSampleId_, params, "", "", getSessionId());
387362

388363
currentSegmentationPath_ = result[0].cast<std::string>();
389364
py::dict result_params = result[1].cast<py::dict>();
@@ -402,15 +377,13 @@ void MonaiLabelJob::onRunSegmentationClicked() {
402377
}
403378
}
404379

405-
MonaiLabelUtils::processSegmentation(currentSegmentationPath_, organLabels, tmp_dir_, currentSampleId_, currentSegmentationPaths_);
380+
MonaiLabelUtils::processSegmentation(currentSegmentationPath_, organLabels, tmp_dir_, currentSampleId_,
381+
currentSegmentationPaths_);
406382

407383
QDir projDir(QString::fromStdString(tmp_dir_));
408-
QString destPath =
409-
projDir.filePath(QString::fromStdString(currentSampleId_ + ".nrrd"));
410-
MonaiLabelUtils::copySegmentation(
411-
QString::fromStdString(currentSegmentationPath_), destPath);
412-
MonaiLabelUtils::deleteTempFile(
413-
QString::fromStdString(currentSegmentationPath_));
384+
QString destPath = projDir.filePath(QString::fromStdString(currentSampleId_ + ".nrrd"));
385+
MonaiLabelUtils::copySegmentation(QString::fromStdString(currentSegmentationPath_), destPath);
386+
MonaiLabelUtils::deleteTempFile(QString::fromStdString(currentSegmentationPath_));
414387
currentSegmentationPath_ = destPath.toStdString();
415388
SW_DEBUG("Prediction label saved at {}", currentSegmentationPath_);
416389

Studio/ShapeWorksMONAI/MonaiLabelJob.h

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,17 @@
1616
namespace py = pybind11;
1717

1818
namespace shapeworks {
19-
class ShapeWorksStudioApp;
20-
class Session;
21-
class Job;
22-
}
19+
class ShapeWorksStudioApp;
20+
class Session;
21+
class Job;
22+
} // namespace shapeworks
2323

2424
namespace monailabel {
2525

26+
using shapeworks::Job;
2627
using shapeworks::ProjectHandle;
2728
using shapeworks::Session;
2829
using shapeworks::ShapeWorksStudioApp;
29-
using shapeworks::Job;
3030

3131
class MonaiLabelJob : public Job {
3232
Q_OBJECT;
@@ -35,19 +35,16 @@ class MonaiLabelJob : public Job {
3535
const static std::string MONAI_RESULT_EXTENSION;
3636
const static std::string MONAI_RESULT_DTYPE;
3737

38-
MonaiLabelJob(QSharedPointer<Session> session, const std::string& server_url,
39-
const std::string& client_id, const std::string& strategy,
40-
const std::string& model_type);
38+
MonaiLabelJob(QSharedPointer<Session> session, const std::string& server_url, const std::string& client_id,
39+
const std::string& strategy, const std::string& model_type);
4140
~MonaiLabelJob();
4241
void setServer(const std::string& server_url);
4342
void setModelType(const std::string& model_type);
4443
inline const std::string& getServer() { return server_url_; }
4544
void setClientId(const std::string& client_id = "");
4645
inline const std::string& getClientId() { return client_id_; };
4746
void initializeClient();
48-
inline std::shared_ptr<py::object> getClient() const {
49-
return monai_client_;
50-
};
47+
inline std::shared_ptr<py::object> getClient() const { return monai_client_; };
5148

5249
py::dict getInfo();
5350
std::string getModelName(std::string modelType);
@@ -57,11 +54,9 @@ class MonaiLabelJob : public Job {
5754

5855
// MONAI Client callers
5956
py::dict nextSample(std::string strategy, py::dict params);
60-
py::tuple infer(std::string model, std::string image_in, py::dict params,
61-
std::string label_in, std::string file,
57+
py::tuple infer(std::string model, std::string image_in, py::dict params, std::string label_in, std::string file,
6258
std::string session_id);
63-
py::dict saveLabel(std::string image_in, std::string label_in,
64-
py::dict params);
59+
py::dict saveLabel(std::string image_in, std::string label_in, py::dict params);
6560
py::dict uploadImage(std::string image_in, std::string image_id);
6661

6762
void updateShapes();
@@ -73,9 +68,9 @@ class MonaiLabelJob : public Job {
7368
void setCurrentSampleNumber(int n);
7469

7570
public Q_SLOTS:
76-
void onUploadSampleClicked(); // Triggered when upload Sample button is clicked
77-
void onRunSegmentationClicked(); // Triggered when Run Segmentation is clicked
78-
void onSubmitLabelClicked(); // Triggered when Submit Label is clicked
71+
void onUploadSampleClicked(); // Triggered when upload Sample button is clicked
72+
void onRunSegmentationClicked(); // Triggered when Run Segmentation is clicked
73+
void onSubmitLabelClicked(); // Triggered when Submit Label is clicked
7974

8075
Q_SIGNALS:
8176
void triggerUpdateView();
@@ -95,7 +90,8 @@ class MonaiLabelJob : public Job {
9590
std::string strategy_;
9691
std::string model_type_ = "";
9792
std::string model_name_ = "";
98-
std::unordered_map<std::string, std::vector<std::string>> models_available_; // TODO: add functionality to interchange between models from UI
93+
std::unordered_map<std::string, std::vector<std::string>>
94+
models_available_; // TODO: add functionality to interchange between models from UI
9995
std::shared_ptr<py::object> monai_client_;
10096

10197
// QT states

0 commit comments

Comments
 (0)