diff --git a/monai/deploy/core/__init__.py b/monai/deploy/core/__init__.py
index deba6250..3967467f 100644
--- a/monai/deploy/core/__init__.py
+++ b/monai/deploy/core/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2021 MONAI Consortium
+# Copyright 2021-2025 MONAI Consortium
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -30,7 +30,7 @@
 
 # Need to import explicit ones to quiet mypy complaints
 from holoscan.core import *
-from holoscan.core import Application, Condition, ConditionType, Fragment, Operator, OperatorSpec
+from holoscan.core import Application, Condition, ConditionType, Fragment, Operator, OperatorSpec, Resource
 
 from .app_context import AppContext, init_app_context
 from .arg_parser import parse_args
diff --git a/monai/deploy/operators/monai_seg_inference_operator.py b/monai/deploy/operators/monai_seg_inference_operator.py
index 63d6007a..20003f57 100644
--- a/monai/deploy/operators/monai_seg_inference_operator.py
+++ b/monai/deploy/operators/monai_seg_inference_operator.py
@@ -1,4 +1,4 @@
-# Copyright 2021-2023 MONAI Consortium
+# Copyright 2021-2025 MONAI Consortium
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -9,6 +9,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
 import logging
 import os
 from pathlib import Path
@@ -40,7 +41,7 @@
 # Dynamic class is not handled so make it Any for now: https://github.com/python/mypy/issues/2477
 Compose: Any = Compose_
 
-from monai.deploy.core import AppContext, ConditionType, Fragment, Image, OperatorSpec
+from monai.deploy.core import AppContext, Condition, ConditionType, Fragment, Image, OperatorSpec, Resource
 
 from .inference_operator import InferenceOperator
 
@@ -56,7 +57,7 @@ class InfererType(StrEnum):
 
 # @md.env(pip_packages=["monai>=1.0.0", "torch>=1.10.2", "numpy>=1.21"])
 class MonaiSegInferenceOperator(InferenceOperator):
-    """This segmentation operator uses MONAI transforms and Sliding Window Inference.
+    """This segmentation operator uses MONAI transforms and performs Simple or Sliding Window Inference.
 
     This operator performs pre-transforms on a input image, inference
     using a given model, and post-transforms. The segmentation image is saved
@@ -64,6 +65,12 @@ class MonaiSegInferenceOperator(InferenceOperator):
 
     If specified in the post transforms, results may also be saved to disk.
 
+    This operator uses the MONAI inference utility functions for sliding window and simple inference,
+    and thus the input parameters need to be as expected by these functions.
+
+    Any additional sliding window arguments not explicitly defined in this operator can be passed via
+    **kwargs for forwarding to 'sliding_window_inference'.
+
     Named Input:
         image: Image object of the input image.
 
@@ -74,6 +81,63 @@ class MonaiSegInferenceOperator(InferenceOperator):
 
     # For testing the app directly, the model should be at the following path.
     MODEL_LOCAL_PATH = Path(os.environ.get("HOLOSCAN_MODEL_PATH", Path.cwd() / "model/model.ts"))
 
+    @staticmethod
+    def filter_sw_kwargs(**kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+        """
+        Filters the keyword arguments into a tuple of two dictionaries:
+
+        1. A dictionary of named parameters to pass to the sliding_window_inference function that:
+           - Are not explicitly defined in the __init__ of this class
+           - Are not explicitly used when calling sliding_window_inference
+
+        2. A dictionary of named parameters to pass to the base class __init__ of this class that:
+           - Are not used when calling sliding_window_inference
+           - Can be successfully converted from Python to Holoscan's C++ layer
+
+        Args:
+            **kwargs: extra arguments passed into __init__ beyond the explicitly defined args.
+
+        Returns:
+            filtered_swi_params: A filtered dictionary of arguments to be passed to sliding_window_inference.
+            filtered_base_init_params: A filtered dictionary of arguments to be passed to the base class __init__.
+        """
+
+        logger = logging.getLogger(f"{__name__}.{MonaiSegInferenceOperator.__name__}")
+
+        init_params = inspect.signature(MonaiSegInferenceOperator).parameters
+        swi_params = inspect.signature(sliding_window_inference).parameters
+
+        # inputs + predictor explicitly used when calling sliding_window_inference
+        explicit_used = {"inputs", "predictor"}
+
+        # Holoscan convertible types (not exhaustive)
+        # This will be revisited when there is a better way to handle this.
+        allowed_types = (str, int, float, bool, bytes, list, tuple, torch.Tensor, Condition, Resource)
+
+        filtered_swi_params = {}
+        filtered_base_init_params = {}
+
+        for name, val in kwargs.items():
+            # Drop explicitly defined kwargs
+            if name in init_params or name in explicit_used:
+                logger.warning(f"{name!r} is already explicitly defined or used; dropping kwarg.")
+            # SWI params
+            elif name in swi_params:
+                filtered_swi_params[name] = val
+                logger.debug(f"{name!r} used in sliding_window_inference; keeping kwarg for inference call.")
+            # Drop kwargs that can't be converted by Holoscan
+            elif not isinstance(val, allowed_types):
+                logger.warning(
+                    f"{name!r} type of {type(val).__name__!r} is a non-convertible kwarg for Holoscan; dropping kwarg."
+                )
+            # Base __init__ params
+            else:
+                filtered_base_init_params[name] = val
+                logger.debug(
+                    f"{name!r} type of {type(val).__name__!r} can be converted by Holoscan; keeping kwarg for base init."
+                )
+        return filtered_swi_params, filtered_base_init_params
+
     def __init__(
         self,
         fragment: Fragment,
@@ -83,7 +147,7 @@ def __init__(
         post_transforms: Compose,
         app_context: AppContext,
         model_name: Optional[str] = "",
-        overlap: float = 0.25,
+        overlap: Union[Sequence[float], float] = 0.25,
         sw_batch_size: int = 4,
         inferer: Union[InfererType, str] = InfererType.SLIDING_WINDOW,
         model_path: Path = MODEL_LOCAL_PATH,
@@ -93,19 +157,19 @@
         """
         Args:
             fragment (Fragment): An instance of the Application class which is derived from Fragment.
-            roi_size (Union[Sequence[int], int]): The window size to execute "SLIDING_WINDOW" evaluation.
-                An optional input only to be passed for "SLIDING_WINDOW".
-                If using a "SIMPLE" Inferer, this input is ignored.
+            roi_size (Sequence[int], int, optional): The window size to execute "SLIDING_WINDOW" evaluation.
+                Applicable for "SLIDING_WINDOW" only.
             pre_transforms (Compose): MONAI Compose object used for pre-transforms.
             post_transforms (Compose): MONAI Compose object used for post-transforms.
             app_context (AppContext): Object holding the I/O and model paths, and potentially loaded models.
             model_name (str, optional): Name of the model. Default to "" for single model app.
-            overlap (float): The amount of overlap between scans along each spatial dimension. Defaults to 0.25.
+            overlap (Sequence[float], float): The amount of overlap between scans along each spatial dimension. Defaults to 0.25.
+                Applicable for "SLIDING_WINDOW" only.
+            sw_batch_size (int): The batch size to run window slices. Defaults to 4. Applicable for "SLIDING_WINDOW" only.
-            sw_batch_size(int): The batch size to run window slices. Defaults to 4.
-                Applicable for "SLIDING_WINDOW" only.
-            inferer (InfererType): The type of inferer to use, "SIMPLE" or "SLIDING_WINDOW". Defaults to "SLIDING_WINDOW".
+            inferer (InfererType, str): The type of inferer to use, "SIMPLE" or "SLIDING_WINDOW". Defaults to "SLIDING_WINDOW".
             model_path (Path): Path to the model file. Defaults to model/models.ts of current working dir.
+            **kwargs: any other sliding window parameters to forward (e.g. `mode`, `cval`, etc.).
         """
 
         self._logger = logging.getLogger("{}.{}".format(__name__, type(self).__name__))
@@ -115,26 +179,30 @@ def __init__(
         self._pred_dataset_key = "pred"
         self._input_image = None  # Image will come in when compute is called.
         self._reader: Any = None
-        self._roi_size = ensure_tuple(roi_size)
-        self._pre_transform = pre_transforms
-        self._post_transforms = post_transforms
+        self.roi_size = ensure_tuple(roi_size)
+        self.pre_transforms = pre_transforms
+        self.post_transforms = post_transforms
         self._model_name = model_name.strip() if isinstance(model_name, str) else ""
-        self._overlap = overlap
-        self._sw_batch_size = sw_batch_size
-        self._inferer = inferer
+        self.overlap = overlap
+        self.sw_batch_size = sw_batch_size
+        self.inferer = inferer
+        self._filtered_swi_params, self._filtered_base_init_params = self.filter_sw_kwargs(
+            **kwargs
+        )  # Filter keyword args
 
         # Add this so that the local model path can be set from the calling app
         self.model_path = model_path
-        self.input_name_image = "image"
-        self.output_name_seg = "seg_image"
+        self._input_name_image = "image"
+        self._output_name_seg = "seg_image"
 
         # The execution context passed in on compute does not have the required model info, so need to
         # get and keep the model via the AppContext obj on construction.
         self.app_context = app_context
-        self.model = self._get_model(self.app_context, self.model_path, self._model_name)
+        self._model = self._get_model(self.app_context, self.model_path, self._model_name)
 
-        super().__init__(fragment, *args, **kwargs)
+        # Pass filtered base init params
+        super().__init__(fragment, *args, **self._filtered_base_init_params)
 
     def _get_model(self, app_context: AppContext, model_path: Path, model_name: str):
         """Load the model with the given name from context or model path
@@ -159,8 +227,8 @@ def _get_model(self, app_context: AppContext, model_path: Path, model_name: str)
         return model
 
     def setup(self, spec: OperatorSpec):
-        spec.input(self.input_name_image)
-        spec.output(self.output_name_seg).condition(ConditionType.NONE)  # Downstream receiver optional.
+        spec.input(self._input_name_image)
+        spec.output(self._output_name_seg).condition(ConditionType.NONE)  # Downstream receiver optional.
 
     @property
     def roi_size(self):
@@ -199,9 +267,13 @@ def overlap(self):
         return self._overlap
 
     @overlap.setter
-    def overlap(self, val: float):
-        if val < 0 or val > 1:
-            raise ValueError("Overlap must be between 0 and 1.")
+    def overlap(self, val: Union[Sequence[float], float]):
+        if not isinstance(val, (Sequence, int, float)) or isinstance(val, str):
+            raise TypeError(f"Overlap must be type Sequence[float] | float, got {type(val).__name__}.")
+        elif isinstance(val, Sequence) and not all(isinstance(x, (int, float)) and 0 <= x < 1 for x in val):
+            raise ValueError("Each overlap value must be >= 0 and < 1.")
+        elif isinstance(val, (int, float)) and not (0 <= float(val) < 1):
+            raise ValueError(f"Overlap must be >= 0 and < 1, got {val}.")
         self._overlap = val
 
     @property
@@ -221,10 +293,20 @@ def inferer(self) -> Union[InfererType, str]:
         return self._inferer
 
     @inferer.setter
-    def inferer(self, val: InfererType):
-        if not isinstance(val, InfererType):
-            raise ValueError(f"Value must be of the correct type {InfererType}.")
-        self._inferer = val
+    def inferer(self, val: Union[InfererType, str]):
+        if isinstance(val, InfererType):
+            self._inferer = val
+            return
+
+        if isinstance(val, str):
+            s = val.strip().lower()
+            valid = (InfererType.SIMPLE.value, InfererType.SLIDING_WINDOW.value)
+            if s in valid:
+                self._inferer = InfererType(s)
+                return
+            raise ValueError(f"inferer must be one of {valid}, got {val!r}.")
+
+        raise TypeError(f"inferer must be InfererType or str, got {type(val).__name__}.")
 
     def _convert_dicom_metadata_datatype(self, metadata: Dict):
         """Converts metadata in pydicom types to the corresponding native types.
@@ -279,10 +361,10 @@ def compute(self, op_input, op_output, context):
             else:
                 self._executing = True
         try:
-            input_image = op_input.receive(self.input_name_image)
+            input_image = op_input.receive(self._input_name_image)
             if not input_image:
                 raise ValueError("Input is None.")
-            op_output.emit(self.compute_impl(input_image, context), self.output_name_seg)
+            op_output.emit(self.compute_impl(input_image, context), self._output_name_seg)
         finally:
             # Reset state on completing this method execution.
             with self._lock:
@@ -297,12 +379,12 @@ def compute_impl(self, input_image, context):
         # Need to give a name to the image as in-mem Image obj has no name.
         img_name = str(input_img_metadata.get("SeriesInstanceUID", "Img_in_context"))
 
-        pre_transforms: Compose = self._pre_transform
-        post_transforms: Compose = self._post_transforms
+        pre_transforms: Compose = self.pre_transforms
+        post_transforms: Compose = self.post_transforms
         self._reader = InMemImageReader(input_image)
 
-        pre_transforms = self._pre_transform if self._pre_transform else self.pre_process(self._reader)
-        post_transforms = self._post_transforms if self._post_transforms else self.post_process(pre_transforms)
+        pre_transforms = self.pre_transforms if self.pre_transforms else self.pre_process(self._reader)
+        post_transforms = self.post_transforms if self.post_transforms else self.post_process(pre_transforms)
 
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         dataset = Dataset(data=[{self._input_dataset_key: img_name}], transform=pre_transforms)
@@ -314,20 +396,21 @@
         for d in dataloader:
             images = d[self._input_dataset_key].to(device)
             self._logger.info(f"Input of {type(images)} shape: {images.shape}")
-            if self._inferer == InfererType.SLIDING_WINDOW:
+            if self.inferer == InfererType.SLIDING_WINDOW:
                 d[self._pred_dataset_key] = sliding_window_inference(
                     inputs=images,
-                    roi_size=self._roi_size,
-                    sw_batch_size=self.sw_batch_size,
+                    roi_size=self.roi_size,
                     overlap=self.overlap,
-                    predictor=self.model,
+                    sw_batch_size=self.sw_batch_size,
+                    predictor=self._model,
+                    **self._filtered_swi_params,  # Additional sliding window arguments
                 )
-            elif self._inferer == InfererType.SIMPLE:
+            elif self.inferer == InfererType.SIMPLE:
                 # Instantiates the SimpleInferer and directly uses its __call__ function
-                d[self._pred_dataset_key] = simple_inference()(inputs=images, network=self.model)
+                d[self._pred_dataset_key] = simple_inference()(inputs=images, network=self._model)
             else:
                 raise ValueError(
-                    f"Unknown inferer: {self._inferer!r}. Available options are "
+                    f"Unknown inferer: {self.inferer!r}. Available options are "
                     f"{InfererType.SLIDING_WINDOW!r} and {InfererType.SIMPLE!r}."
                 )
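
Illustrative usage (not part of the patch above): a minimal sketch of how a calling application might exercise the new **kwargs forwarding. Names such as my_pre_transforms, my_post_transforms, and app_context are placeholders for objects the app already builds; mode and padding_mode are existing monai.inferers.sliding_window_inference parameters that filter_sw_kwargs would route to the inference call, while name, being a Holoscan-convertible str, would be forwarded to the base Operator __init__.

# Sketch only; assumes this runs where the app composes its operators, with `self` being the
# Fragment/Application, and that the transforms and AppContext are prepared as in the example apps.
from monai.deploy.operators.monai_seg_inference_operator import InfererType, MonaiSegInferenceOperator

seg_op = MonaiSegInferenceOperator(
    self,
    roi_size=(96, 96, 96),
    pre_transforms=my_pre_transforms,    # placeholder Compose
    post_transforms=my_post_transforms,  # placeholder Compose
    app_context=app_context,             # placeholder AppContext
    inferer=InfererType.SLIDING_WINDOW,
    overlap=(0.5, 0.5, 0.5),
    sw_batch_size=4,
    # Extra kwargs not named in __init__: filter_sw_kwargs keeps `mode` and `padding_mode`
    # for the sliding_window_inference call and passes `name` on to the base Operator __init__.
    mode="gaussian",
    padding_mode="replicate",
    name="monai_seg_inference_op",
)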