4 changes: 2 additions & 2 deletions monai/deploy/core/__init__.py
@@ -1,4 +1,4 @@
# Copyright 2021 MONAI Consortium
# Copyright 2021-2025 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
@@ -30,7 +30,7 @@

# Need to import explicit ones to quiet mypy complaints
from holoscan.core import *
from holoscan.core import Application, Condition, ConditionType, Fragment, Operator, OperatorSpec
from holoscan.core import Application, Condition, ConditionType, Fragment, Operator, OperatorSpec, Resource

from .app_context import AppContext, init_app_context
from .arg_parser import parse_args
169 changes: 126 additions & 43 deletions monai/deploy/operators/monai_seg_inference_operator.py
@@ -1,4 +1,4 @@
# Copyright 2021-2023 MONAI Consortium
# Copyright 2021-2025 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
@@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import logging
import os
from pathlib import Path
@@ -40,7 +41,7 @@
# Dynamic class is not handled so make it Any for now: https://github.com/python/mypy/issues/2477
Compose: Any = Compose_

from monai.deploy.core import AppContext, ConditionType, Fragment, Image, OperatorSpec
from monai.deploy.core import AppContext, Condition, ConditionType, Fragment, Image, OperatorSpec, Resource

from .inference_operator import InferenceOperator

@@ -56,14 +57,20 @@ class InfererType(StrEnum):

# @md.env(pip_packages=["monai>=1.0.0", "torch>=1.10.2", "numpy>=1.21"])
class MonaiSegInferenceOperator(InferenceOperator):
"""This segmentation operator uses MONAI transforms and Sliding Window Inference.
"""This segmentation operator uses MONAI transforms and performs Simple or Sliding Window Inference.

This operator performs pre-transforms on an input image, inference
using a given model, and post-transforms. The segmentation image is saved
as a named Image object in memory.

If specified in the post transforms, results may also be saved to disk.

This operator uses the MONAI inference utility functions for sliding window and simple inference,
so the input parameters must match what those functions expect.

Any additional sliding window arguments not explicitly defined in this operator can be passed via
**kwargs for forwarding to 'sliding_window_inference'.

Named Input:
image: Image object of the input image.

@@ -74,6 +81,63 @@ class MonaiSegInferenceOperator(InferenceOperator):
# For testing the app directly, the model should be at the following path.
MODEL_LOCAL_PATH = Path(os.environ.get("HOLOSCAN_MODEL_PATH", Path.cwd() / "model/model.ts"))

@staticmethod
def filter_sw_kwargs(**kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""
Filters the keyword arguments into a tuple of two dictionaries:

1. A dictionary of named parameters to pass to the sliding_window_inference function that:
- Are not explicitly defined in the __init__ of this class
- Are not explicitly used when calling sliding_window_inference

2. A dictionary of named parameters to pass to the base class __init__ of this class that:
- Are not used when calling sliding_window_inference
- Can be successfully converted from Python to Holoscan's C++ layer

Args:
**kwargs: extra arguments passed into __init__ beyond the explicitly defined args.

Returns:
filtered_swi_params: A filtered dictionary of arguments to be passed to sliding_window_inference.
filtered_base_init_params: A filtered dictionary of arguments to be passed to the base class __init__.
"""

logger = logging.getLogger(f"{__name__}.{MonaiSegInferenceOperator.__name__}")

init_params = inspect.signature(MonaiSegInferenceOperator).parameters
swi_params = inspect.signature(sliding_window_inference).parameters

# inputs + predictor explicitly used when calling sliding_window_inference
explicit_used = {"inputs", "predictor"}

# Holoscan convertible types (not exhaustive)
# This will be revisited when there is a better way to handle this.
allowed_types = (str, int, float, bool, bytes, list, tuple, torch.Tensor, Condition, Resource)

filtered_swi_params = {}
filtered_base_init_params = {}

for name, val in kwargs.items():
# Drop explicitly defined kwargs
if name in init_params or name in explicit_used:
logger.warning(f"{name!r} is already explicitly defined or used; dropping kwarg.")
# SWI params
elif name in swi_params:
filtered_swi_params[name] = val
logger.debug(f"{name!r} used in sliding_window_inference; keeping kwarg for inference call.")
# Drop kwargs that can't be converted by Holoscan
elif not isinstance(val, allowed_types):
logger.warning(
f"{name!r} type of {type(val).__name__!r} is a non-convertible kwarg for Holoscan; dropping kwarg."
)
# Base __init__ params
else:
filtered_base_init_params[name] = val
logger.debug(
f"{name!r} type of {type(val).__name__!r} can be converted by Holoscan; keeping kwarg for base init."
)
return filtered_swi_params, filtered_base_init_params
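# Illustrative sketch of the split (hypothetical kwargs): `mode` and `progress` are
# sliding_window_inference parameters, while `name` is neither an explicit __init__ arg nor an
# SWI parameter but is a Holoscan-convertible str, so it falls through to the base __init__ dict:
#
#     swi_kwargs, base_kwargs = MonaiSegInferenceOperator.filter_sw_kwargs(
#         mode="gaussian",       # sliding_window_inference parameter -> swi_kwargs
#         progress=True,         # sliding_window_inference parameter -> swi_kwargs
#         name="seg_inference",  # plain str, convertible by Holoscan -> base_kwargs
#     )
#     # swi_kwargs == {"mode": "gaussian", "progress": True}
#     # base_kwargs == {"name": "seg_inference"}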

def __init__(
self,
fragment: Fragment,
@@ -83,7 +147,7 @@ def __init__(
post_transforms: Compose,
app_context: AppContext,
model_name: Optional[str] = "",
overlap: float = 0.25,
overlap: Union[Sequence[float], float] = 0.25,
sw_batch_size: int = 4,
inferer: Union[InfererType, str] = InfererType.SLIDING_WINDOW,
model_path: Path = MODEL_LOCAL_PATH,
@@ -93,19 +157,19 @@

Args:
fragment (Fragment): An instance of the Application class which is derived from Fragment.
roi_size (Union[Sequence[int], int]): The window size to execute "SLIDING_WINDOW" evaluation.
An optional input only to be passed for "SLIDING_WINDOW".
If using a "SIMPLE" Inferer, this input is ignored.
roi_size (Sequence[int], int, optional): The window size to execute "SLIDING_WINDOW" evaluation.
Applicable for "SLIDING_WINDOW" only.
pre_transforms (Compose): MONAI Compose object used for pre-transforms.
post_transforms (Compose): MONAI Compose object used for post-transforms.
app_context (AppContext): Object holding the I/O and model paths, and potentially loaded models.
model_name (str, optional): Name of the model. Defaults to "" for a single-model app.
overlap (float): The amount of overlap between scans along each spatial dimension. Defaults to 0.25.
overlap (Sequence[float], float): The amount of overlap between scans along each spatial dimension. Defaults to 0.25.
Applicable for "SLIDING_WINDOW" only.
sw_batch_size (int): The batch size to run window slices. Defaults to 4.
Applicable for "SLIDING_WINDOW" only.
sw_batch_size(int): The batch size to run window slices. Defaults to 4.
Applicable for "SLIDING_WINDOW" only.
inferer (InfererType): The type of inferer to use, "SIMPLE" or "SLIDING_WINDOW". Defaults to "SLIDING_WINDOW".
inferer (InfererType, str): The type of inferer to use, "SIMPLE" or "SLIDING_WINDOW". Defaults to "SLIDING_WINDOW".
model_path (Path): Path to the model file. Defaults to model/model.ts in the current working dir.
**kwargs: any other sliding window parameters to forward (e.g. `mode`, `cval`, etc.).
"""

self._logger = logging.getLogger("{}.{}".format(__name__, type(self).__name__))
@@ -115,26 +179,30 @@ def __init__(
self._pred_dataset_key = "pred"
self._input_image = None # Image will come in when compute is called.
self._reader: Any = None
self._roi_size = ensure_tuple(roi_size)
self._pre_transform = pre_transforms
self._post_transforms = post_transforms
self.roi_size = ensure_tuple(roi_size)
self.pre_transforms = pre_transforms
self.post_transforms = post_transforms
self._model_name = model_name.strip() if isinstance(model_name, str) else ""
self._overlap = overlap
self._sw_batch_size = sw_batch_size
self._inferer = inferer
self.overlap = overlap
self.sw_batch_size = sw_batch_size
self.inferer = inferer
self._filtered_swi_params, self._filtered_base_init_params = self.filter_sw_kwargs(
**kwargs
) # Filter keyword args

# Add this so that the local model path can be set from the calling app
self.model_path = model_path
self.input_name_image = "image"
self.output_name_seg = "seg_image"
self._input_name_image = "image"
self._output_name_seg = "seg_image"

# The execution context passed in on compute does not have the required model info, so need to
# get and keep the model via the AppContext obj on construction.
self.app_context = app_context

self.model = self._get_model(self.app_context, self.model_path, self._model_name)
self._model = self._get_model(self.app_context, self.model_path, self._model_name)

super().__init__(fragment, *args, **kwargs)
# Pass filtered base init params
super().__init__(fragment, *args, **self._filtered_base_init_params)
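# Construction sketch (fragment/transform variable names below are placeholders): extra
# sliding-window arguments such as `mode` ride along in **kwargs and are forwarded to
# sliding_window_inference at compute time, while `name` is routed to the Holoscan base class:
#
#     seg_op = MonaiSegInferenceOperator(
#         fragment,
#         roi_size=(96, 96, 96),
#         pre_transforms=my_pre_compose,      # a MONAI Compose (placeholder)
#         post_transforms=my_post_compose,    # a MONAI Compose (placeholder)
#         app_context=app_context,
#         overlap=(0.25, 0.25, 0.5),          # per-spatial-dimension overlap is accepted
#         inferer=InfererType.SLIDING_WINDOW,
#         mode="gaussian",                    # forwarded to sliding_window_inference
#         name="seg_inference_op",            # forwarded to the base Operator __init__
#     )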

def _get_model(self, app_context: AppContext, model_path: Path, model_name: str):
"""Load the model with the given name from context or model path
@@ -159,8 +227,8 @@ def _get_model(self, app_context: AppContext, model_path: Path, model_name: str)
return model

def setup(self, spec: OperatorSpec):
spec.input(self.input_name_image)
spec.output(self.output_name_seg).condition(ConditionType.NONE) # Downstream receiver optional.
spec.input(self._input_name_image)
spec.output(self._output_name_seg).condition(ConditionType.NONE) # Downstream receiver optional.
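# Wiring sketch for the named ports above (operator variables and the upstream port name are
# placeholders): in an Application's compose(), an upstream image output would connect to this
# operator's "image" input:
#
#     self.add_flow(series_to_vol_op, seg_op, {("image", "image")})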

@property
def roi_size(self):
@@ -199,9 +267,13 @@ def overlap(self):
return self._overlap

@overlap.setter
def overlap(self, val: float):
if val < 0 or val > 1:
raise ValueError("Overlap must be between 0 and 1.")
def overlap(self, val: Union[Sequence[float], float]):
if not isinstance(val, (Sequence, int, float)) or isinstance(val, str):
raise TypeError(f"Overlap must be type Sequence[float] | float, got {type(val).__name__}.")
elif isinstance(val, Sequence) and not all(isinstance(x, (int, float)) and 0 <= x < 1 for x in val):
raise ValueError("Each overlap value must be >= 0 and < 1.")
elif isinstance(val, (int, float)) and not (0 <= float(val) < 1):
raise ValueError(f"Overlap must be >= 0 and < 1, got {val}.")
self._overlap = val
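# Behavior sketch for the widened validation above (values are illustrative; `seg_op` is a
# constructed operator instance):
#
#     seg_op.overlap = 0.5                # accepted: scalar in [0, 1)
#     seg_op.overlap = (0.25, 0.25, 0.5)  # accepted: one value per spatial dimension
#     seg_op.overlap = 1.0                # raises ValueError: must be >= 0 and < 1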

@property
@@ -221,10 +293,20 @@ def inferer(self) -> Union[InfererType, str]:
return self._inferer

@inferer.setter
def inferer(self, val: InfererType):
if not isinstance(val, InfererType):
raise ValueError(f"Value must be of the correct type {InfererType}.")
self._inferer = val
def inferer(self, val: Union[InfererType, str]):
if isinstance(val, InfererType):
self._inferer = val
return

if isinstance(val, str):
s = val.strip().lower()
valid = (InfererType.SIMPLE.value, InfererType.SLIDING_WINDOW.value)
if s in valid:
self._inferer = InfererType(s)
return
raise ValueError(f"inferer must be one of {valid}, got {val!r}.")

raise TypeError(f"inferer must be InfererType or str, got {type(val).__name__}.")

def _convert_dicom_metadata_datatype(self, metadata: Dict):
"""Converts metadata in pydicom types to the corresponding native types.
@@ -279,10 +361,10 @@ def compute(self, op_input, op_output, context):
else:
self._executing = True
try:
input_image = op_input.receive(self.input_name_image)
input_image = op_input.receive(self._input_name_image)
if not input_image:
raise ValueError("Input is None.")
op_output.emit(self.compute_impl(input_image, context), self.output_name_seg)
op_output.emit(self.compute_impl(input_image, context), self._output_name_seg)
finally:
# Reset state on completing this method execution.
with self._lock:
@@ -297,12 +379,12 @@ def compute_impl(self, input_image, context):
# Need to give a name to the image as in-mem Image obj has no name.
img_name = str(input_img_metadata.get("SeriesInstanceUID", "Img_in_context"))

pre_transforms: Compose = self._pre_transform
post_transforms: Compose = self._post_transforms
pre_transforms: Compose = self.pre_transforms
post_transforms: Compose = self.post_transforms
self._reader = InMemImageReader(input_image)

pre_transforms = self._pre_transform if self._pre_transform else self.pre_process(self._reader)
post_transforms = self._post_transforms if self._post_transforms else self.post_process(pre_transforms)
pre_transforms = self.pre_transforms if self.pre_transforms else self.pre_process(self._reader)
post_transforms = self.post_transforms if self.post_transforms else self.post_process(pre_transforms)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = Dataset(data=[{self._input_dataset_key: img_name}], transform=pre_transforms)
@@ -314,20 +396,21 @@ def compute_impl(self, input_image, context):
for d in dataloader:
images = d[self._input_dataset_key].to(device)
self._logger.info(f"Input of {type(images)} shape: {images.shape}")
if self._inferer == InfererType.SLIDING_WINDOW:
if self.inferer == InfererType.SLIDING_WINDOW:
d[self._pred_dataset_key] = sliding_window_inference(
inputs=images,
roi_size=self._roi_size,
sw_batch_size=self.sw_batch_size,
roi_size=self.roi_size,
overlap=self.overlap,
predictor=self.model,
sw_batch_size=self.sw_batch_size,
predictor=self._model,
**self._filtered_swi_params, # Additional sliding window arguments
)
elif self._inferer == InfererType.SIMPLE:
elif self.inferer == InfererType.SIMPLE:
# Instantiates the SimpleInferer and directly uses its __call__ function
d[self._pred_dataset_key] = simple_inference()(inputs=images, network=self.model)
d[self._pred_dataset_key] = simple_inference()(inputs=images, network=self._model)
else:
raise ValueError(
f"Unknown inferer: {self._inferer!r}. Available options are "
f"Unknown inferer: {self.inferer!r}. Available options are "
f"{InfererType.SLIDING_WINDOW!r} and {InfererType.SIMPLE!r}."
)
