Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion monai/deploy/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

# Need to import explicit ones to quiet mypy complaints
from holoscan.core import *
from holoscan.core import Application, Condition, ConditionType, Fragment, Operator, OperatorSpec
from holoscan.core import Application, Condition, ConditionType, Fragment, Operator, OperatorSpec, Resource

from .app_context import AppContext, init_app_context
from .arg_parser import parse_args
Expand Down
205 changes: 160 additions & 45 deletions monai/deploy/operators/monai_seg_inference_operator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2021-2023 MONAI Consortium
# Copyright 2021-2025 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand All @@ -9,19 +9,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import logging
import os
from pathlib import Path
from threading import Lock
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch

from monai.deploy.utils.importutil import optional_import
from monai.utils import StrEnum # Will use the built-in StrEnum when SDK requires Python 3.11.

MONAI_UTILS = "monai.utils"
torch, _ = optional_import("torch", "1.5")
np_str_obj_array_pattern, _ = optional_import("torch.utils.data._utils.collate", name="np_str_obj_array_pattern")
Dataset, _ = optional_import("monai.data", name="Dataset")
DataLoader, _ = optional_import("monai.data", name="DataLoader")
Expand All @@ -40,7 +41,7 @@
# Dynamic class is not handled so make it Any for now: https://github.com/python/mypy/issues/2477
Compose: Any = Compose_

from monai.deploy.core import AppContext, ConditionType, Fragment, Image, OperatorSpec
from monai.deploy.core import AppContext, Condition, ConditionType, Fragment, Image, OperatorSpec, Resource

from .inference_operator import InferenceOperator

Expand All @@ -56,14 +57,20 @@ class InfererType(StrEnum):

# @md.env(pip_packages=["monai>=1.0.0", "torch>=1.10.2", "numpy>=1.21"])
class MonaiSegInferenceOperator(InferenceOperator):
"""This segmentation operator uses MONAI transforms and Sliding Window Inference.
"""This segmentation operator uses MONAI transforms and performs Simple or Sliding Window Inference.

This operator performs pre-transforms on a input image, inference
using a given model, and post-transforms. The segmentation image is saved
as a named Image object in memory.

If specified in the post transforms, results may also be saved to disk.

This operator uses the MONAI inference utils functions for sliding window and simple inference,
and thus input parameters need to be as expected by these functions.

Any additional sliding window arguments not explicitly defined in this operator can be passed via
**kwargs for forwarding to 'sliding_window_inference'.

Named Input:
image: Image object of the input image.

Expand All @@ -74,6 +81,47 @@ class MonaiSegInferenceOperator(InferenceOperator):
# For testing the app directly, the model should be at the following path.
MODEL_LOCAL_PATH = Path(os.environ.get("HOLOSCAN_MODEL_PATH", Path.cwd() / "model/model.ts"))

@staticmethod
def filter_sw_kwargs(**kwargs) -> Dict[str, Any]:
"""
Returns a dictionary of named parameters of the sliding_window_inference function that:
- Are not explicitly defined in the __init__ of this class
- Are not explicitly used when calling sliding_window_inference
- Can be successfully converted from Python --> Holoscan's C++ layer

Args:
**kwargs: extra arguments passed into __init__ beyond the explicitly defined args.

Returns:
filtered_params: A filtered dictionary of arguments to be passed to sliding_window_inference.
"""

logger = logging.getLogger(f"{__name__}.{MonaiSegInferenceOperator.__name__}")

init_params = inspect.signature(MonaiSegInferenceOperator).parameters

# inputs + predictor explicitly used when calling sliding_window_inference
explicit_used = {"inputs", "predictor"}

# Holoscan convertible types (not exhaustive)
# This will be revisited when there is a better way to handle this.
allowed_types = (str, int, float, bool, bytes, list, tuple, torch.Tensor, Condition, Resource)

filtered_params = {}
for name, val in kwargs.items():
# Drop explicitly defined kwargs
if name in init_params or name in explicit_used:
logger.warning(f"{name!r} is already explicitly defined or used; dropping kwarg.")
continue
# Drop kwargs that can't be converted by Holoscan
if not isinstance(val, allowed_types):
logger.warning(
f"{name!r} type of {type(val).__name__!r} is a non-convertible kwarg for Holoscan; dropping kwarg."
)
continue
filtered_params[name] = val
return filtered_params

def __init__(
self,
fragment: Fragment,
Expand All @@ -83,8 +131,11 @@ def __init__(
post_transforms: Compose,
app_context: AppContext,
model_name: Optional[str] = "",
overlap: float = 0.25,
sw_batch_size: int = 4,
overlap: Union[Sequence[float], float] = 0.25,
sw_device: Optional[Union[torch.device, str]] = None,
device: Optional[Union[torch.device, str]] = None,
process_fn: Optional[Callable] = None,
inferer: Union[InfererType, str] = InfererType.SLIDING_WINDOW,
model_path: Path = MODEL_LOCAL_PATH,
**kwargs,
Expand All @@ -93,19 +144,25 @@ def __init__(

Args:
fragment (Fragment): An instance of the Application class which is derived from Fragment.
roi_size (Union[Sequence[int], int]): The window size to execute "SLIDING_WINDOW" evaluation.
An optional input only to be passed for "SLIDING_WINDOW".
If using a "SIMPLE" Inferer, this input is ignored.
roi_size (Sequence[int], int, optional): The window size to execute "SLIDING_WINDOW" evaluation.
Applicable for "SLIDING_WINDOW" only.
pre_transforms (Compose): MONAI Compose object used for pre-transforms.
post_transforms (Compose): MONAI Compose object used for post-transforms.
app_context (AppContext): Object holding the I/O and model paths, and potentially loaded models.
model_name (str, optional): Name of the model. Default to "" for single model app.
overlap (float): The amount of overlap between scans along each spatial dimension. Defaults to 0.25.
sw_batch_size (int): The batch size to run window slices. Defaults to 4.
Applicable for "SLIDING_WINDOW" only.
overlap (Sequence[float], float): The amount of overlap between scans along each spatial dimension. Defaults to 0.25.
Applicable for "SLIDING_WINDOW" only.
sw_batch_size(int): The batch size to run window slices. Defaults to 4.
Applicable for "SLIDING_WINDOW" only.
inferer (InfererType): The type of inferer to use, "SIMPLE" or "SLIDING_WINDOW". Defaults to "SLIDING_WINDOW".
sw_device (torch.device, str, optional): Device for the window data. Defaults to None.
Applicable for "SLIDING_WINDOW" only.
device: (torch.device, str, optional): Device for the stitched output prediction. Defaults to None.
Applicable for "SLIDING_WINDOW" only.
process_fn: (Callable, optional): process inference output and adjust the importance map per window. Defaults to None.
Applicable for "SLIDING_WINDOW" only.
inferer (InfererType, str): The type of inferer to use, "SIMPLE" or "SLIDING_WINDOW". Defaults to "SLIDING_WINDOW".
model_path (Path): Path to the model file. Defaults to model/models.ts of current working dir.
**kwargs: any other sliding window parameters to forward (e.g. `mode`, `cval`, etc.).
"""

self._logger = logging.getLogger("{}.{}".format(__name__, type(self).__name__))
Expand All @@ -115,26 +172,33 @@ def __init__(
self._pred_dataset_key = "pred"
self._input_image = None # Image will come in when compute is called.
self._reader: Any = None
self._roi_size = ensure_tuple(roi_size)
self._pre_transform = pre_transforms
self._post_transforms = post_transforms
self.roi_size = ensure_tuple(roi_size)
self.pre_transforms = pre_transforms
self.post_transforms = post_transforms
self._model_name = model_name.strip() if isinstance(model_name, str) else ""
self._overlap = overlap
self._sw_batch_size = sw_batch_size
self._inferer = inferer
self.overlap = overlap
self.sw_batch_size = sw_batch_size
self.inferer = inferer
self._implicit_params = self.filter_sw_kwargs(**kwargs) # Filter keyword args

# Sliding window inference args whose type Holoscan can't convert - define explicitly
self.sw_device = sw_device
self.device = device
self.process_fn = process_fn

# Add this so that the local model path can be set from the calling app
self.model_path = model_path
self.input_name_image = "image"
self.output_name_seg = "seg_image"
self._input_name_image = "image"
self._output_name_seg = "seg_image"

# The execution context passed in on compute does not have the required model info, so need to
# get and keep the model via the AppContext obj on construction.
self.app_context = app_context

self.model = self._get_model(self.app_context, self.model_path, self._model_name)
self._model = self._get_model(self.app_context, self.model_path, self._model_name)

super().__init__(fragment, *args, **kwargs)
# Pass filtered kwargs
super().__init__(fragment, *args, **self._implicit_params)

def _get_model(self, app_context: AppContext, model_path: Path, model_name: str):
"""Load the model with the given name from context or model path
Expand All @@ -159,8 +223,8 @@ def _get_model(self, app_context: AppContext, model_path: Path, model_name: str)
return model

def setup(self, spec: OperatorSpec):
spec.input(self.input_name_image)
spec.output(self.output_name_seg).condition(ConditionType.NONE) # Downstream receiver optional.
spec.input(self._input_name_image)
spec.output(self._output_name_seg).condition(ConditionType.NONE) # Downstream receiver optional.

@property
def roi_size(self):
Expand Down Expand Up @@ -199,9 +263,13 @@ def overlap(self):
return self._overlap

@overlap.setter
def overlap(self, val: float):
if val < 0 or val > 1:
raise ValueError("Overlap must be between 0 and 1.")
def overlap(self, val: Union[Sequence[float], float]):
if not isinstance(val, (Sequence, int, float)) or isinstance(val, str):
raise TypeError(f"Overlap must be type Sequence[float] | float, got {type(val).__name__}.")
elif isinstance(val, Sequence) and not all(isinstance(x, (int, float)) and 0 <= x < 1 for x in val):
raise ValueError("Each overlap value must be >= 0 and < 1.")
elif isinstance(val, (int, float)) and not (0 <= float(val) < 1):
raise ValueError(f"Overlap must be >= 0 and < 1, got {val}.")
self._overlap = val

@property
Expand All @@ -215,16 +283,59 @@ def sw_batch_size(self, val: int):
raise ValueError("sw_batch_size must be a positive integer.")
self._sw_batch_size = val

@property
def sw_device(self):
"""Device for the window data."""
return self._sw_device

@sw_device.setter
def sw_device(self, val: Optional[Union[torch.device, str]]):
if val is not None and not isinstance(val, (torch.device, str)):
raise TypeError(f"sw_device must be type torch.device | str | None, got {type(val).__name__}.")
self._sw_device = val

@property
def device(self):
"""Device for the stitched output prediction."""
return self._device

@device.setter
def device(self, val: Optional[Union[torch.device, str]]):
if val is not None and not isinstance(val, (torch.device, str)):
raise TypeError(f"device must be type torch.device | str | None, got {type(val).__name__}.")
self._device = val

@property
def process_fn(self):
"""Process inference output and adjust the importance map per window."""
return self._process_fn

@process_fn.setter
def process_fn(self, val: Optional[Callable]):
if val is not None and not callable(val):
raise TypeError(f"process_fn must be type Callable | None, got {type(val).__name__}.")
self._process_fn = val

@property
def inferer(self) -> Union[InfererType, str]:
"""The type of inferer to use"""
return self._inferer

@inferer.setter
def inferer(self, val: InfererType):
if not isinstance(val, InfererType):
raise ValueError(f"Value must be of the correct type {InfererType}.")
self._inferer = val
def inferer(self, val: Union[InfererType, str]):
if isinstance(val, InfererType):
self._inferer = val
return

if isinstance(val, str):
s = val.strip().lower()
valid = (InfererType.SIMPLE.value, InfererType.SLIDING_WINDOW.value)
if s in valid:
self._inferer = InfererType(s)
return
raise ValueError(f"inferer must be one of {valid}, got {val!r}.")

raise TypeError(f"inferer must be InfererType or str, got {type(val).__name__}.")

def _convert_dicom_metadata_datatype(self, metadata: Dict):
"""Converts metadata in pydicom types to the corresponding native types.
Expand Down Expand Up @@ -279,10 +390,10 @@ def compute(self, op_input, op_output, context):
else:
self._executing = True
try:
input_image = op_input.receive(self.input_name_image)
input_image = op_input.receive(self._input_name_image)
if not input_image:
raise ValueError("Input is None.")
op_output.emit(self.compute_impl(input_image, context), self.output_name_seg)
op_output.emit(self.compute_impl(input_image, context), self._output_name_seg)
finally:
# Reset state on completing this method execution.
with self._lock:
Expand All @@ -297,12 +408,12 @@ def compute_impl(self, input_image, context):
# Need to give a name to the image as in-mem Image obj has no name.
img_name = str(input_img_metadata.get("SeriesInstanceUID", "Img_in_context"))

pre_transforms: Compose = self._pre_transform
post_transforms: Compose = self._post_transforms
pre_transforms: Compose = self.pre_transforms
post_transforms: Compose = self.post_transforms
self._reader = InMemImageReader(input_image)

pre_transforms = self._pre_transform if self._pre_transform else self.pre_process(self._reader)
post_transforms = self._post_transforms if self._post_transforms else self.post_process(pre_transforms)
pre_transforms = self.pre_transforms if self.pre_transforms else self.pre_process(self._reader)
post_transforms = self.post_transforms if self.post_transforms else self.post_process(pre_transforms)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = Dataset(data=[{self._input_dataset_key: img_name}], transform=pre_transforms)
Expand All @@ -314,20 +425,24 @@ def compute_impl(self, input_image, context):
for d in dataloader:
images = d[self._input_dataset_key].to(device)
self._logger.info(f"Input of {type(images)} shape: {images.shape}")
if self._inferer == InfererType.SLIDING_WINDOW:
if self.inferer == InfererType.SLIDING_WINDOW:
d[self._pred_dataset_key] = sliding_window_inference(
inputs=images,
roi_size=self._roi_size,
sw_batch_size=self.sw_batch_size,
roi_size=self.roi_size,
overlap=self.overlap,
predictor=self.model,
sw_batch_size=self.sw_batch_size,
sw_device=self.sw_device,
device=self.device,
process_fn=self.process_fn,
predictor=self._model,
**self._implicit_params, # Additional sliding window arguments
)
elif self._inferer == InfererType.SIMPLE:
elif self.inferer == InfererType.SIMPLE:
# Instantiates the SimpleInferer and directly uses its __call__ function
d[self._pred_dataset_key] = simple_inference()(inputs=images, network=self.model)
d[self._pred_dataset_key] = simple_inference()(inputs=images, network=self._model)
else:
raise ValueError(
f"Unknown inferer: {self._inferer!r}. Available options are "
f"Unknown inferer: {self.inferer!r}. Available options are "
f"{InfererType.SLIDING_WINDOW!r} and {InfererType.SIMPLE!r}."
)

Expand Down