1- # Copyright 2021-2023 MONAI Consortium
1+ # Copyright 2021-2025 MONAI Consortium
22# Licensed under the Apache License, Version 2.0 (the "License");
33# you may not use this file except in compliance with the License.
44# You may obtain a copy of the License at
99# See the License for the specific language governing permissions and
1010# limitations under the License.
1111
12+ import inspect
1213import logging
1314import os
1415from pathlib import Path
4041# Dynamic class is not handled so make it Any for now: https://github.com/python/mypy/issues/2477
4142Compose : Any = Compose_
4243
43- from monai .deploy .core import AppContext , ConditionType , Fragment , Image , OperatorSpec
44+ from monai .deploy .core import AppContext , Condition , ConditionType , Fragment , Image , OperatorSpec , Resource
4445
4546from .inference_operator import InferenceOperator
4647
@@ -56,14 +57,20 @@ class InfererType(StrEnum):
5657
5758# @md.env(pip_packages=["monai>=1.0.0", "torch>=1.10.2", "numpy>=1.21"])
5859class MonaiSegInferenceOperator (InferenceOperator ):
59- """This segmentation operator uses MONAI transforms and Sliding Window Inference.
60+ """This segmentation operator uses MONAI transforms and performs Simple or Sliding Window Inference.
6061
6162 This operator performs pre-transforms on a input image, inference
6263 using a given model, and post-transforms. The segmentation image is saved
6364 as a named Image object in memory.
6465
6566 If specified in the post transforms, results may also be saved to disk.
6667
68+ This operator uses the MONAI inference utils functions for sliding window and simple inference,
69+ and thus input parameters need to be as expected by these functions.
70+
71+ Any additional sliding window arguments not explicitly defined in this operator can be passed via
72+ **kwargs for forwarding to 'sliding_window_inference'.
73+
6774 Named Input:
6875 image: Image object of the input image.
6976
@@ -74,6 +81,63 @@ class MonaiSegInferenceOperator(InferenceOperator):
7481 # For testing the app directly, the model should be at the following path.
7582 MODEL_LOCAL_PATH = Path (os .environ .get ("HOLOSCAN_MODEL_PATH" , Path .cwd () / "model/model.ts" ))
7683
84+ @staticmethod
85+ def filter_sw_kwargs (** kwargs ) -> Tuple [Dict [str , Any ], Dict [str , Any ]]:
86+ """
87+ Filters the keyword arguments into a tuple of two dictionaries:
88+
89+ 1. A dictionary of named parameters to pass to the sliding_window_inference function that:
90+ - Are not explicitly defined in the __init__ of this class
91+ - Are not explicitly used when calling sliding_window_inference
92+
93+ 2. A dicionary of named parameters to pass to the base class __init__ of this class that:
94+ - Are not used when calling sliding_window_inference
95+ - Can be successfully converted from Python --> Holoscan's C++ layer
96+
97+ Args:
98+ **kwargs: extra arguments passed into __init__ beyond the explicitly defined args.
99+
100+ Returns:
101+ filtered_swi_params: A filtered dictionary of arguments to be passed to sliding_window_inference.
102+ filtered_base_init_params: A filtered dictionary of arguments to be passed to the base class __init__.
103+ """
104+
105+ logger = logging .getLogger (f"{ __name__ } .{ MonaiSegInferenceOperator .__name__ } " )
106+
107+ init_params = inspect .signature (MonaiSegInferenceOperator ).parameters
108+ swi_params = inspect .signature (sliding_window_inference ).parameters
109+
110+ # inputs + predictor explicitly used when calling sliding_window_inference
111+ explicit_used = {"inputs" , "predictor" }
112+
113+ # Holoscan convertible types (not exhaustive)
114+ # This will be revisited when there is a better way to handle this.
115+ allowed_types = (str , int , float , bool , bytes , list , tuple , torch .Tensor , Condition , Resource )
116+
117+ filtered_swi_params = {}
118+ filtered_base_init_params = {}
119+
120+ for name , val in kwargs .items ():
121+ # Drop explicitly defined kwargs
122+ if name in init_params or name in explicit_used :
123+ logger .warning (f"{ name !r} is already explicitly defined or used; dropping kwarg." )
124+ # SWI params
125+ elif name in swi_params :
126+ filtered_swi_params [name ] = val
127+ logger .debug (f"{ name !r} used in sliding_window_inference; keeping kwarg for inference call." )
128+ # Drop kwargs that can't be converted by Holoscan
129+ elif not isinstance (val , allowed_types ):
130+ logger .warning (
131+ f"{ name !r} type of { type (val ).__name__ !r} is a non-convertible kwarg for Holoscan; dropping kwarg."
132+ )
133+ # Base __init__ params
134+ else :
135+ filtered_base_init_params [name ] = val
136+ logger .debug (
137+ f"{ name !r} type of { type (val ).__name__ !r} can be converted by Holoscan; keeping kwarg for base init."
138+ )
139+ return filtered_swi_params , filtered_base_init_params
140+
77141 def __init__ (
78142 self ,
79143 fragment : Fragment ,
@@ -83,7 +147,7 @@ def __init__(
83147 post_transforms : Compose ,
84148 app_context : AppContext ,
85149 model_name : Optional [str ] = "" ,
86- overlap : float = 0.25 ,
150+ overlap : Union [ Sequence [ float ], float ] = 0.25 ,
87151 sw_batch_size : int = 4 ,
88152 inferer : Union [InfererType , str ] = InfererType .SLIDING_WINDOW ,
89153 model_path : Path = MODEL_LOCAL_PATH ,
@@ -93,19 +157,19 @@ def __init__(
93157
94158 Args:
95159 fragment (Fragment): An instance of the Application class which is derived from Fragment.
96- roi_size (Union[Sequence[int], int]): The window size to execute "SLIDING_WINDOW" evaluation.
97- An optional input only to be passed for "SLIDING_WINDOW".
98- If using a "SIMPLE" Inferer, this input is ignored.
160+ roi_size (Sequence[int], int, optional): The window size to execute "SLIDING_WINDOW" evaluation.
161+ Applicable for "SLIDING_WINDOW" only.
99162 pre_transforms (Compose): MONAI Compose object used for pre-transforms.
100163 post_transforms (Compose): MONAI Compose object used for post-transforms.
101164 app_context (AppContext): Object holding the I/O and model paths, and potentially loaded models.
102165 model_name (str, optional): Name of the model. Default to "" for single model app.
103- overlap (float): The amount of overlap between scans along each spatial dimension. Defaults to 0.25.
166+ overlap (Sequence[float], float): The amount of overlap between scans along each spatial dimension. Defaults to 0.25.
167+ Applicable for "SLIDING_WINDOW" only.
168+ sw_batch_size (int): The batch size to run window slices. Defaults to 4.
104169 Applicable for "SLIDING_WINDOW" only.
105- sw_batch_size(int): The batch size to run window slices. Defaults to 4.
106- Applicable for "SLIDING_WINDOW" only.
107- inferer (InfererType): The type of inferer to use, "SIMPLE" or "SLIDING_WINDOW". Defaults to "SLIDING_WINDOW".
170+ inferer (InfererType, str): The type of inferer to use, "SIMPLE" or "SLIDING_WINDOW". Defaults to "SLIDING_WINDOW".
108171 model_path (Path): Path to the model file. Defaults to model/models.ts of current working dir.
172+ **kwargs: any other sliding window parameters to forward (e.g. `mode`, `cval`, etc.).
109173 """
110174
111175 self ._logger = logging .getLogger ("{}.{}" .format (__name__ , type (self ).__name__ ))
@@ -115,26 +179,30 @@ def __init__(
115179 self ._pred_dataset_key = "pred"
116180 self ._input_image = None # Image will come in when compute is called.
117181 self ._reader : Any = None
118- self ._roi_size = ensure_tuple (roi_size )
119- self ._pre_transform = pre_transforms
120- self ._post_transforms = post_transforms
182+ self .roi_size = ensure_tuple (roi_size )
183+ self .pre_transforms = pre_transforms
184+ self .post_transforms = post_transforms
121185 self ._model_name = model_name .strip () if isinstance (model_name , str ) else ""
122- self ._overlap = overlap
123- self ._sw_batch_size = sw_batch_size
124- self ._inferer = inferer
186+ self .overlap = overlap
187+ self .sw_batch_size = sw_batch_size
188+ self .inferer = inferer
189+ self ._filtered_swi_params , self ._filtered_base_init_params = self .filter_sw_kwargs (
190+ ** kwargs
191+ ) # Filter keyword args
125192
126193 # Add this so that the local model path can be set from the calling app
127194 self .model_path = model_path
128- self .input_name_image = "image"
129- self .output_name_seg = "seg_image"
195+ self ._input_name_image = "image"
196+ self ._output_name_seg = "seg_image"
130197
131198 # The execution context passed in on compute does not have the required model info, so need to
132199 # get and keep the model via the AppContext obj on construction.
133200 self .app_context = app_context
134201
135- self .model = self ._get_model (self .app_context , self .model_path , self ._model_name )
202+ self ._model = self ._get_model (self .app_context , self .model_path , self ._model_name )
136203
137- super ().__init__ (fragment , * args , ** kwargs )
204+ # Pass filtered base init params
205+ super ().__init__ (fragment , * args , ** self ._filtered_base_init_params )
138206
139207 def _get_model (self , app_context : AppContext , model_path : Path , model_name : str ):
140208 """Load the model with the given name from context or model path
@@ -159,8 +227,8 @@ def _get_model(self, app_context: AppContext, model_path: Path, model_name: str)
159227 return model
160228
161229 def setup (self , spec : OperatorSpec ):
162- spec .input (self .input_name_image )
163- spec .output (self .output_name_seg ).condition (ConditionType .NONE ) # Downstream receiver optional.
230+ spec .input (self ._input_name_image )
231+ spec .output (self ._output_name_seg ).condition (ConditionType .NONE ) # Downstream receiver optional.
164232
165233 @property
166234 def roi_size (self ):
@@ -199,9 +267,13 @@ def overlap(self):
199267 return self ._overlap
200268
201269 @overlap .setter
202- def overlap (self , val : float ):
203- if val < 0 or val > 1 :
204- raise ValueError ("Overlap must be between 0 and 1." )
270+ def overlap (self , val : Union [Sequence [float ], float ]):
271+ if not isinstance (val , (Sequence , int , float )) or isinstance (val , str ):
272+ raise TypeError (f"Overlap must be type Sequence[float] | float, got { type (val ).__name__ } ." )
273+ elif isinstance (val , Sequence ) and not all (isinstance (x , (int , float )) and 0 <= x < 1 for x in val ):
274+ raise ValueError ("Each overlap value must be >= 0 and < 1." )
275+ elif isinstance (val , (int , float )) and not (0 <= float (val ) < 1 ):
276+ raise ValueError (f"Overlap must be >= 0 and < 1, got { val } ." )
205277 self ._overlap = val
206278
207279 @property
@@ -221,10 +293,20 @@ def inferer(self) -> Union[InfererType, str]:
221293 return self ._inferer
222294
223295 @inferer .setter
224- def inferer (self , val : InfererType ):
225- if not isinstance (val , InfererType ):
226- raise ValueError (f"Value must be of the correct type { InfererType } ." )
227- self ._inferer = val
296+ def inferer (self , val : Union [InfererType , str ]):
297+ if isinstance (val , InfererType ):
298+ self ._inferer = val
299+ return
300+
301+ if isinstance (val , str ):
302+ s = val .strip ().lower ()
303+ valid = (InfererType .SIMPLE .value , InfererType .SLIDING_WINDOW .value )
304+ if s in valid :
305+ self ._inferer = InfererType (s )
306+ return
307+ raise ValueError (f"inferer must be one of { valid } , got { val !r} ." )
308+
309+ raise TypeError (f"inferer must be InfererType or str, got { type (val ).__name__ } ." )
228310
229311 def _convert_dicom_metadata_datatype (self , metadata : Dict ):
230312 """Converts metadata in pydicom types to the corresponding native types.
@@ -279,10 +361,10 @@ def compute(self, op_input, op_output, context):
279361 else :
280362 self ._executing = True
281363 try :
282- input_image = op_input .receive (self .input_name_image )
364+ input_image = op_input .receive (self ._input_name_image )
283365 if not input_image :
284366 raise ValueError ("Input is None." )
285- op_output .emit (self .compute_impl (input_image , context ), self .output_name_seg )
367+ op_output .emit (self .compute_impl (input_image , context ), self ._output_name_seg )
286368 finally :
287369 # Reset state on completing this method execution.
288370 with self ._lock :
@@ -297,12 +379,12 @@ def compute_impl(self, input_image, context):
297379 # Need to give a name to the image as in-mem Image obj has no name.
298380 img_name = str (input_img_metadata .get ("SeriesInstanceUID" , "Img_in_context" ))
299381
300- pre_transforms : Compose = self ._pre_transform
301- post_transforms : Compose = self ._post_transforms
382+ pre_transforms : Compose = self .pre_transforms
383+ post_transforms : Compose = self .post_transforms
302384 self ._reader = InMemImageReader (input_image )
303385
304- pre_transforms = self ._pre_transform if self ._pre_transform else self .pre_process (self ._reader )
305- post_transforms = self ._post_transforms if self ._post_transforms else self .post_process (pre_transforms )
386+ pre_transforms = self .pre_transforms if self .pre_transforms else self .pre_process (self ._reader )
387+ post_transforms = self .post_transforms if self .post_transforms else self .post_process (pre_transforms )
306388
307389 device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
308390 dataset = Dataset (data = [{self ._input_dataset_key : img_name }], transform = pre_transforms )
@@ -314,20 +396,21 @@ def compute_impl(self, input_image, context):
314396 for d in dataloader :
315397 images = d [self ._input_dataset_key ].to (device )
316398 self ._logger .info (f"Input of { type (images )} shape: { images .shape } " )
317- if self ._inferer == InfererType .SLIDING_WINDOW :
399+ if self .inferer == InfererType .SLIDING_WINDOW :
318400 d [self ._pred_dataset_key ] = sliding_window_inference (
319401 inputs = images ,
320- roi_size = self ._roi_size ,
321- sw_batch_size = self .sw_batch_size ,
402+ roi_size = self .roi_size ,
322403 overlap = self .overlap ,
323- predictor = self .model ,
404+ sw_batch_size = self .sw_batch_size ,
405+ predictor = self ._model ,
406+ ** self ._filtered_swi_params , # Additional sliding window arguments
324407 )
325- elif self ._inferer == InfererType .SIMPLE :
408+ elif self .inferer == InfererType .SIMPLE :
326409 # Instantiates the SimpleInferer and directly uses its __call__ function
327- d [self ._pred_dataset_key ] = simple_inference ()(inputs = images , network = self .model )
410+ d [self ._pred_dataset_key ] = simple_inference ()(inputs = images , network = self ._model )
328411 else :
329412 raise ValueError (
330- f"Unknown inferer: { self ._inferer !r} . Available options are "
413+ f"Unknown inferer: { self .inferer !r} . Available options are "
331414 f"{ InfererType .SLIDING_WINDOW !r} and { InfererType .SIMPLE !r} ."
332415 )
333416
0 commit comments