Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 30 additions & 8 deletions Studio/ShapeWorksMONAI/MonaiLabelJob.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,13 @@ py::dict MonaiLabelJob::getParamsFromConfig(std::string section,
if (py::isinstance<py::list>(value)) {
py::list valueList = value.cast<py::list>();
if (!valueList.empty()) {
result[key] = py::str(valueList[0]);
// result[key] = py::str(valueList[0]);
result[key] = valueList[0];

}
} else {
result[key] = py::str(value);
// result[key] = py::str(value);
result[key] = value;
}
}
}
Expand Down Expand Up @@ -265,6 +268,7 @@ py::tuple MonaiLabelJob::infer(std::string model, std::string image_in,
label_in.empty() ? py::none() : py::cast(label_in),
file.empty() ? py::none() : py::cast(file),
session_id.empty() ? py::none() : py::cast(session_id));
// std::cout << "DEBUG | infer call successfully made " << py::repr(result).cast<std::string>() << std::endl;
// SW_DEBUG("Infer response: " + py::repr(result).cast<std::string>());
}

Expand Down Expand Up @@ -375,13 +379,30 @@ void MonaiLabelJob::onRunSegmentationClicked() {
SW_ERROR("Sample not uploaded yet!");
return;
}
SW_LOG("⚙️ Processing inference on the current subject");
py::dict params = getParamsFromConfig("infer", model_name_);

py::tuple result =
infer(model_name_, currentSampleId_, params, "", "", getSessionId());

currentSegmentationPath_ =
result[0].cast<std::string>(); // temp result for segmentation
currentSegmentationPath_ = result[0].cast<std::string>();
py::dict result_params = result[1].cast<py::dict>();

// Extract label names from result_params
py::dict label_dict = result_params["label_names"].cast<py::dict>();
std::map<int, std::string> organLabels;
organNames_.resize(0);

for (auto &item : label_dict) {
std::string organName = item.first.cast<std::string>();
int label = item.second.cast<int>();
if (label > 0) { // Exclude background (0)
organLabels[label] = organName;
organNames_.push_back(organName);
}
}

MonaiLabelUtils::processSegmentation(currentSegmentationPath_, organLabels, tmp_dir_, currentSampleId_, currentSegmentationPaths_);

QDir projDir(QString::fromStdString(tmp_dir_));
QString destPath =
Expand Down Expand Up @@ -417,7 +438,7 @@ void MonaiLabelJob::onSubmitLabelClicked() {
std::string label_in = currentSegmentationPath_;

entry[py::str("name")] = image_in;
entry[py::str("idx")] = 1; // TODO: handle multi-organ label submission
entry[py::str("idx")] = 1;
label_info.append(entry);

py::dict params;
Expand All @@ -430,13 +451,14 @@ void MonaiLabelJob::onSubmitLabelClicked() {

//---------------------------------------------------------------------------
void MonaiLabelJob::updateShapes() {
if (!currentSampleId_.empty() && !currentSegmentationPath_.empty()) {
if (!currentSampleId_.empty() && !currentSegmentationPaths_.empty()) {
auto shapes = session_->get_shapes();

session_->get_project()->set_domain_names(organNames_);
if (sample_number_ < shapes.size()) {
auto cur_shape = shapes[sample_number_];
auto cur_subject = cur_shape->get_subject();
cur_subject->set_original_filenames({currentSegmentationPath_});
cur_subject->set_number_of_domains(currentSegmentationPaths_.size());
cur_subject->set_original_filenames(currentSegmentationPaths_);
}

} else {
Expand Down
2 changes: 2 additions & 0 deletions Studio/ShapeWorksMONAI/MonaiLabelJob.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ class MonaiLabelJob : public Job {
int sample_number_;
std::string currentSampleId_;
std::string currentSegmentationPath_;
std::vector<std::string> currentSegmentationPaths_;
std::vector<std::string> organNames_;

QSharedPointer<Session> session_;
ProjectHandle project_;
Expand Down
12 changes: 7 additions & 5 deletions Studio/ShapeWorksMONAI/MonaiLabelTool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,11 @@ void MonaiLabelTool::onConnectServer() {
"establishing connection with MONAI Label server");
return;
}
SW_LOG("Connecting to MONAI Label Server...")
SW_LOG("Connecting to MONAI Label Server...")
ui_->connectServerButton->setText("Connecting...");
ui_->connectServerButton->setEnabled(false);
loadParamsFromUi();
if (model_type_ == MONAI_MODE_SEGMENTATION) {
SW_LOG("Connecting to the server...");
runSegmentationTool();
} else {
SW_ERROR(
Expand Down Expand Up @@ -194,7 +193,7 @@ void MonaiLabelTool::runSegmentationTool() {

//---------------------------------------------------------------------------
void MonaiLabelTool::handleClientInitialized() {
SW_LOG("Connection successfully established to the server, continue with segmentation!");
SW_LOG("Connection successfully established to the server, continue with segmentation!");
tool_is_running_ = true;
if (session_->get_shapes().size() > 1)
ui_->uploadSampleButton->setEnabled(true);
Expand All @@ -221,7 +220,7 @@ void MonaiLabelTool::handleClientInitialized() {

//---------------------------------------------------------------------------
void MonaiLabelTool::handleUploadSampleCompleted() {
SW_LOG("Upload complete! Run {} model on the uploaded sample.", model_type_);
SW_LOG("Upload complete! Run {} model on the uploaded sample.", model_type_);
ui_->uploadSampleButton->setEnabled(false);
ui_->runSegmentationButton->setEnabled(true);
ui_->submitLabelButton->setEnabled(false);
Expand All @@ -237,6 +236,9 @@ void MonaiLabelTool::handleSegmentationCompleted() {
ui_->runSegmentationButton->setEnabled(false);
ui_->submitLabelButton->setEnabled(true);
session_->get_project()->save();
SW_LOG(
"✅ Segmentation for the current sample done! Submit the prediction label to server or "
"proceed with next sample!");
Q_EMIT progress(66);
}

Expand All @@ -250,7 +252,7 @@ void MonaiLabelTool::handleSubmitLabelCompleted() {
ui_->uploadSampleButton->setEnabled(false);
ui_->runSegmentationButton->setEnabled(false);
ui_->submitLabelButton->setEnabled(false);
SW_LOG("Label submitted to the server. Proceed with next source volume.")
SW_LOG("Label submitted to the server. Proceed with next sample.")
samples_processed_++;
// Q_EMIT
// progress((int)(samples_processed_/session_->get_shapes().size())*100);
Expand Down
97 changes: 97 additions & 0 deletions Studio/ShapeWorksMONAI/MonaiLabelUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,17 @@
#include <QFileInfo>
#include <QSharedPointer>

#include <itkImageFileReader.h>
#include <itkImageFileWriter.h>
#include <itkBinaryThresholdImageFilter.h>
#include <itkCastImageFilter.h>
#include <itkImageRegionIterator.h>
#include <Data/Session.h>

using namespace shapeworks;

typedef float PixelType;
typedef itk::Image< PixelType, 3 > ImageType;
namespace monailabel {

bool MonaiLabelUtils::createDir(const QString& dirPath) {
Expand Down Expand Up @@ -45,4 +52,94 @@ std::string MonaiLabelUtils::getFeatureName(QSharedPointer<Session> session) {
return feature_name;
}

//---------------------------------------------------------------------------
ImageType::Pointer MonaiLabelUtils::loadNRRD(const std::string& filePath) {
using ReaderType = itk::ImageFileReader<ImageType>;
ReaderType::Pointer reader = ReaderType::New();
reader->SetFileName(filePath);
reader->Update();
return reader->GetOutput();
}

//---------------------------------------------------------------------------
bool MonaiLabelUtils::isOrganPresent(ImageType::Pointer image) {
itk::ImageRegionIterator<ImageType> it(image, image->GetRequestedRegion());
while (!it.IsAtEnd()) {
if (it.Get() > 0) { // If any non-background pixel is found
return true;
}
++it;
}
return false;
}
//---------------------------------------------------------------------------
ImageType::Pointer MonaiLabelUtils::extractOrganSegmentation(ImageType::Pointer inputImage, int label) {
using ThresholdFilterType = itk::BinaryThresholdImageFilter<ImageType, ImageType>;
ThresholdFilterType::Pointer thresholdFilter = ThresholdFilterType::New();
thresholdFilter->SetInput(inputImage);
thresholdFilter->SetLowerThreshold(label);
thresholdFilter->SetUpperThreshold(label);
// thresholdFilter->SetInsideValue(label);
thresholdFilter->SetInsideValue(1); // dont save as label
thresholdFilter->SetOutsideValue(0); // Background set to 0
thresholdFilter->Update();

ImageType::Pointer organImage = thresholdFilter->GetOutput();
// Check if organ is present in segmentation
if (thresholdFilter->GetOutput()->GetBufferedRegion().GetNumberOfPixels() == 0) {
return nullptr; // Return null if the organ is not present
} // save all
// if (!isOrganPresent(organImage)) {
// return nullptr; // Return nullptr if no valid organ pixels exist
// }
// return thresholdFilter->GetOutput();
return organImage;
}

//---------------------------------------------------------------------------
void MonaiLabelUtils::saveNRRD(ImageType::Pointer image, const std::string& outputPath) {
using WriterType = itk::ImageFileWriter<ImageType>;
WriterType::Pointer writer = WriterType::New();
writer->SetFileName(outputPath);
writer->SetInput(image);
writer->UseCompressionOn();
writer->Update();
}

//---------------------------------------------------------------------------
void MonaiLabelUtils::processSegmentation(
const std::string& segmentationPath,
const std::map<int, std::string>& organLabels, const std::string& outputDir,
const std::string& sampleId,
std::vector<std::string>& organSegmentationPaths) {

organSegmentationPaths.resize(0);
ImageType::Pointer inputImage = loadNRRD(segmentationPath);
if (!inputImage) {
SW_ERROR("Failed to load segmentation file: {}", segmentationPath);
return;
}

QDir projDir(QString::fromStdString(outputDir));
// if (!projDir.exists()) {
// projDir.mkpath(".");
// }

// Extract and save each organ segmentation
for (const auto& [label, organName] : organLabels) {
ImageType::Pointer organImage = extractOrganSegmentation(inputImage, label);

if (!organImage) {
SW_LOG("Warning: {} (Label {}) not found in segmentation.", organName, label);
continue;
}

QString destPath = projDir.filePath(
QString::fromStdString(sampleId + "_" + organName + ".nrrd"));
saveNRRD(organImage, destPath.toStdString());
SW_LOG("✅ Saved segmented organ: {}", destPath.toStdString());
organSegmentationPaths.push_back(destPath.toStdString());
}
}

} // namespace monailabel
18 changes: 17 additions & 1 deletion Studio/ShapeWorksMONAI/MonaiLabelUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,15 @@
#include <QValidator>
#include <QSharedPointer>
#include <Data/Session.h>
#include <itkImage.h>

namespace shapeworks {
class Session;
}
namespace monailabel {

typedef float PixelType;
typedef itk::Image< PixelType, 3 > ImageType;

class UrlValidator : public QValidator {
public:
Expand Down Expand Up @@ -45,7 +49,19 @@ class MonaiLabelUtils {
static bool copySegmentation(const QString& sourcePath,
const QString& destinationPath);
static bool deleteTempFile(const QString& filePath);
static std::string getFeatureName(QSharedPointer<shapeworks::Session> session);
static std::string getFeatureName(
QSharedPointer<shapeworks::Session> session);
static ImageType::Pointer loadNRRD(const std::string& filePath);
static ImageType::Pointer extractOrganSegmentation(ImageType::Pointer inputImage,
int label);
static void saveNRRD(ImageType::Pointer image, const std::string& outputPath);
static bool isOrganPresent(ImageType::Pointer image);

static void processSegmentation(const std::string& segmentationPath,
const std::map<int, std::string>& organLabels,
const std::string& outputDir,
const std::string& sampleId,
std::vector<std::string>& organSegmentationPaths);
};

} // namespace monailabel
Binary file added docs/img/swmonai/illustration-monai-0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/img/swmonai/illustration-monai-1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/img/swmonai/illustration-monai-2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/img/swmonai/illustration-monai-3.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/img/swmonai/illustration-monai-4.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
13 changes: 13 additions & 0 deletions docs/new/ai-assisted-segmentation.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@

# AI-Assisted Segmentation in ShapeWorks

## Getting Started with MONAI Label in ShapeWorks

[`Medical Open Network for AI (MONAI) Label`](https://monai.io/) is a deep learning framework designed for efficient annotation and segmentation of medical images.

## What’s New?
ShapeWorks Studio now integrates MONAI Label, enabling seamless access to fully automated and interactive deep learning models for segmenting radiology images across various modalities.

For a detailed demo and step-by-step instructions on using MONAI Label within ShapeWorks Studio, refer to the following guide:

### [Getting Started with AI-Assisted Segmentation](../studio/ai-assisted-segmentation.md)
Loading
Loading