1414import os
1515from pathlib import Path
1616from threading import Lock
17- from typing import Any , Dict , List , Optional , Sequence , Tuple , Union
17+ from typing import Any , Callable , Dict , List , Optional , Sequence , Tuple , Union
1818
1919import numpy as np
20+ import torch
2021
2122from monai .deploy .utils .importutil import optional_import
2223from monai .utils import StrEnum # Will use the built-in StrEnum when SDK requires Python 3.11.
23- from monai .utils import BlendMode , PytorchPadMode
2424
2525MONAI_UTILS = "monai.utils"
26- torch , _ = optional_import ("torch" , "1.5" )
2726np_str_obj_array_pattern , _ = optional_import ("torch.utils.data._utils.collate" , name = "np_str_obj_array_pattern" )
2827Dataset , _ = optional_import ("monai.data" , name = "Dataset" )
2928DataLoader , _ = optional_import ("monai.data" , name = "DataLoader" )
4241# Dynamic class is not handled so make it Any for now: https://github.com/python/mypy/issues/2477
4342Compose : Any = Compose_
4443
45- from monai .deploy .core import AppContext , ConditionType , Fragment , Image , OperatorSpec
44+ from monai .deploy .core import AppContext , Condition , ConditionType , Fragment , Image , OperatorSpec , Resource
4645
4746from .inference_operator import InferenceOperator
4847
@@ -56,11 +55,6 @@ class InfererType(StrEnum):
5655 SLIDING_WINDOW = "sliding_window"
5756
5857
59- # define other StrEnum types
60- BlendModeType = BlendMode
61- PytorchPadModeType = PytorchPadMode
62-
63-
6458# @md.env(pip_packages=["monai>=1.0.0", "torch>=1.10.2", "numpy>=1.21"])
6559class MonaiSegInferenceOperator (InferenceOperator ):
6660 """This segmentation operator uses MONAI transforms and performs Simple or Sliding Window Inference.
@@ -90,9 +84,10 @@ class MonaiSegInferenceOperator(InferenceOperator):
9084 @staticmethod
9185 def filter_sw_kwargs (** kwargs ) -> Dict [str , Any ]:
9286 """
93- Returns a dictionary of named parameters of the sliding_window_inference function that are:
94- - Not explicitly defined in the __init__ of this class
95- - Not explicitly used when calling sliding_window_inference
87+ Returns a dictionary of named parameters of the sliding_window_inference function that:
88+ - Are not explicitly defined in the __init__ of this class
89+ - Are not explicitly used when calling sliding_window_inference
90+ - Can be successfully converted from Python --> Holoscan's C++ layer
9691
9792 Args:
9893 **kwargs: extra arguments passed into __init__ beyond the explicitly defined args.
@@ -101,19 +96,30 @@ def filter_sw_kwargs(**kwargs) -> Dict[str, Any]:
10196 filtered_params: A filtered dictionary of arguments to be passed to sliding_window_inference.
10297 """
10398
99+ logger = logging .getLogger (f"{ __name__ } .{ MonaiSegInferenceOperator .__name__ } " )
100+
104101 init_params = inspect .signature (MonaiSegInferenceOperator ).parameters
105102
106103 # inputs + predictor explicitly used when calling sliding_window_inference
107- explicit_used = ["inputs" , "predictor" ]
104+ explicit_used = {"inputs" , "predictor" }
105+
106+ # Holoscan convertible types (not exhaustive)
107+ # This will be revisited when there is a better way to handle this.
108+ allowed_types = (str , int , float , bool , bytes , list , tuple , torch .Tensor , Condition , Resource )
108109
109110 filtered_params = {}
110111 for name , val in kwargs .items ():
112+ # Drop explicitly defined kwargs
111113 if name in init_params or name in explicit_used :
112- # match log formatting
113- logger = logging .getLogger (f"{ __name__ } .{ MonaiSegInferenceOperator .__name__ } " )
114- logger .warning (f"{ name !r} is already explicity defined or used; ignoring input arg" )
115- else :
116- filtered_params [name ] = val
114+ logger .warning (f"{ name !r} is already explicitly defined or used; dropping kwarg." )
115+ continue
116+ # Drop kwargs that can't be converted by Holoscan
117+ if not isinstance (val , allowed_types ):
118+ logger .warning (
119+ f"{ name !r} type of { type (val ).__name__ !r} is a non-convertible kwarg for Holoscan; dropping kwarg."
120+ )
121+ continue
122+ filtered_params [name ] = val
117123 return filtered_params
118124
119125 def __init__ (
@@ -125,10 +131,11 @@ def __init__(
125131 post_transforms : Compose ,
126132 app_context : AppContext ,
127133 model_name : Optional [str ] = "" ,
128- overlap : float = 0.25 ,
129134 sw_batch_size : int = 4 ,
130- mode : Union [BlendModeType , str ] = BlendModeType .CONSTANT ,
131- padding_mode : Union [PytorchPadModeType , str ] = PytorchPadModeType .CONSTANT ,
135+ overlap : Union [Sequence [float ], float ] = 0.25 ,
136+ sw_device : Optional [Union [torch .device , str ]] = None ,
137+ device : Optional [Union [torch .device , str ]] = None ,
138+ process_fn : Optional [Callable ] = None ,
132139 inferer : Union [InfererType , str ] = InfererType .SLIDING_WINDOW ,
133140 model_path : Path = MODEL_LOCAL_PATH ,
134141 ** kwargs ,
@@ -137,25 +144,25 @@ def __init__(
137144
138145 Args:
139146 fragment (Fragment): An instance of the Application class which is derived from Fragment.
140- roi_size (Union[Sequence[int], int]): The window size to execute "SLIDING_WINDOW" evaluation.
141- An optional input only to be passed for "SLIDING_WINDOW".
142- If using a "SIMPLE" Inferer, this input is ignored.
147+ roi_size (Sequence[int], int, optional): The window size to execute "SLIDING_WINDOW" evaluation.
148+ Applicable for "SLIDING_WINDOW" only.
143149 pre_transforms (Compose): MONAI Compose object used for pre-transforms.
144150 post_transforms (Compose): MONAI Compose object used for post-transforms.
145151 app_context (AppContext): Object holding the I/O and model paths, and potentially loaded models.
146152 model_name (str, optional): Name of the model. Default to "" for single model app.
147- overlap (float): The amount of overlap between scans along each spatial dimension. Defaults to 0.25.
153+ sw_batch_size (int): The batch size to run window slices. Defaults to 4.
154+ Applicable for "SLIDING_WINDOW" only.
155+ overlap (Sequence[float], float): The amount of overlap between scans along each spatial dimension. Defaults to 0.25.
148156 Applicable for "SLIDING_WINDOW" only.
149- sw_batch_size(int ): The batch size to run window slices . Defaults to 4 .
157+ sw_device (torch.device, str, optional ): Device for the window data . Defaults to None .
150158 Applicable for "SLIDING_WINDOW" only.
151- mode (BlendModeType ): How to blend output of overlapping windows, "CONSTANT" or "GAUSSIAN" . Defaults to "CONSTANT" .
159+ device: (torch.device, str, optional ): Device for the stitched output prediction . Defaults to None .
152160 Applicable for "SLIDING_WINDOW" only.
153- padding_mode (PytorchPadModeType): Padding mode for ``inputs``, when ``roi_size`` is larger than inputs,
154- "CONSTANT", "REFLECT", "REPLICATE", or "CIRCULAR". Defaults to "CONSTANT".
161+ process_fn: (Callable, optional): process inference output and adjust the importance map per window. Defaults to None.
155162 Applicable for "SLIDING_WINDOW" only.
156- inferer (InfererType): The type of inferer to use, "SIMPLE" or "SLIDING_WINDOW". Defaults to "SLIDING_WINDOW".
163+ inferer (InfererType, str ): The type of inferer to use, "SIMPLE" or "SLIDING_WINDOW". Defaults to "SLIDING_WINDOW".
157164 model_path (Path): Path to the model file. Defaults to model/models.ts of current working dir.
158- **kwargs: any other sliding window parameters to forward (e.g. `sigma_scale `, `cval`, etc.).
165+ **kwargs: any other sliding window parameters to forward (e.g. `mode `, `cval`, etc.).
159166 """
160167
161168 self ._logger = logging .getLogger ("{}.{}" .format (__name__ , type (self ).__name__ ))
@@ -165,29 +172,33 @@ def __init__(
165172 self ._pred_dataset_key = "pred"
166173 self ._input_image = None # Image will come in when compute is called.
167174 self ._reader : Any = None
168- self ._roi_size = ensure_tuple (roi_size )
169- self ._pre_transform = pre_transforms
170- self ._post_transforms = post_transforms
175+ self .roi_size = ensure_tuple (roi_size )
176+ self .pre_transforms = pre_transforms
177+ self .post_transforms = post_transforms
171178 self ._model_name = model_name .strip () if isinstance (model_name , str ) else ""
172- self ._overlap = overlap
173- self ._sw_batch_size = sw_batch_size
174- self ._mode = mode
175- self ._padding_mode = padding_mode
176- self ._inferer = inferer
179+ self .overlap = overlap
180+ self .sw_batch_size = sw_batch_size
181+ self .inferer = inferer
177182 self ._implicit_params = self .filter_sw_kwargs (** kwargs ) # Filter keyword args
178183
184+ # Sliding window inference args whose type Holoscan can't convert - define explicitly
185+ self .sw_device = sw_device
186+ self .device = device
187+ self .process_fn = process_fn
188+
179189 # Add this so that the local model path can be set from the calling app
180190 self .model_path = model_path
181- self .input_name_image = "image"
182- self .output_name_seg = "seg_image"
191+ self ._input_name_image = "image"
192+ self ._output_name_seg = "seg_image"
183193
184194 # The execution context passed in on compute does not have the required model info, so need to
185195 # get and keep the model via the AppContext obj on construction.
186196 self .app_context = app_context
187197
188- self .model = self ._get_model (self .app_context , self .model_path , self ._model_name )
198+ self ._model = self ._get_model (self .app_context , self .model_path , self ._model_name )
189199
190- super ().__init__ (fragment , * args , ** kwargs )
200+ # Pass filtered kwargs
201+ super ().__init__ (fragment , * args , ** self ._implicit_params )
191202
192203 def _get_model (self , app_context : AppContext , model_path : Path , model_name : str ):
193204 """Load the model with the given name from context or model path
@@ -212,8 +223,8 @@ def _get_model(self, app_context: AppContext, model_path: Path, model_name: str)
212223 return model
213224
214225 def setup (self , spec : OperatorSpec ):
215- spec .input (self .input_name_image )
216- spec .output (self .output_name_seg ).condition (ConditionType .NONE ) # Downstream receiver optional.
226+ spec .input (self ._input_name_image )
227+ spec .output (self ._output_name_seg ).condition (ConditionType .NONE ) # Downstream receiver optional.
217228
218229 @property
219230 def roi_size (self ):
@@ -252,9 +263,13 @@ def overlap(self):
252263 return self ._overlap
253264
254265 @overlap .setter
255- def overlap (self , val : float ):
256- if val < 0 or val > 1 :
257- raise ValueError ("Overlap must be between 0 and 1." )
266+ def overlap (self , val : Union [Sequence [float ], float ]):
267+ if not isinstance (val , (Sequence , int , float )) or isinstance (val , str ):
268+ raise TypeError (f"Overlap must be type Sequence[float] | float, got { type (val ).__name__ } ." )
269+ elif isinstance (val , Sequence ) and not all (isinstance (x , (int , float )) and 0 <= x < 1 for x in val ):
270+ raise ValueError ("Each overlap value must be >= 0 and < 1." )
271+ elif isinstance (val , (int , float )) and not (0 <= float (val ) < 1 ):
272+ raise ValueError (f"Overlap must be >= 0 and < 1, got { val } ." )
258273 self ._overlap = val
259274
260275 @property
@@ -269,37 +284,58 @@ def sw_batch_size(self, val: int):
269284 self ._sw_batch_size = val
270285
271286 @property
272- def mode (self ) -> Union [BlendModeType , str ]:
273- """The blend mode used during sliding window inference"""
274- return self ._mode
287+ def sw_device (self ):
288+ """Device for the window data."""
289+ return self ._sw_device
290+
291+ @sw_device .setter
292+ def sw_device (self , val : Optional [Union [torch .device , str ]]):
293+ if val is not None and not isinstance (val , (torch .device , str )):
294+ raise TypeError (f"sw_device must be type torch.device | str | None, got { type (val ).__name__ } ." )
295+ self ._sw_device = val
296+
297+ @property
298+ def device (self ):
299+ """Device for the stitched output prediction."""
300+ return self ._device
275301
276- @mode .setter
277- def mode (self , val : BlendModeType ):
278- if not isinstance (val , BlendModeType ):
279- raise ValueError (f"Value must be of the correct type { BlendModeType } ." )
280- self ._mode = val
302+ @device .setter
303+ def device (self , val : Optional [ Union [ torch . device , str ]] ):
304+ if val is not None and not isinstance (val , ( torch . device , str ) ):
305+ raise TypeError (f"device must be type torch.device | str | None, got { type ( val ). __name__ } ." )
306+ self ._device = val
281307
282308 @property
283- def padding_mode (self ) -> Union [ PytorchPadModeType , str ] :
284- """The padding mode to use when padding input images for inference """
285- return self ._padding_mode
309+ def process_fn (self ):
310+ """Process inference output and adjust the importance map per window. """
311+ return self ._process_fn
286312
287- @padding_mode .setter
288- def padding_mode (self , val : PytorchPadModeType ):
289- if not isinstance (val , PytorchPadModeType ):
290- raise ValueError (f"Value must be of the correct type { PytorchPadModeType } ." )
291- self ._padding_mode = val
313+ @process_fn .setter
314+ def process_fn (self , val : Optional [ Callable ] ):
315+ if val is not None and not callable (val ):
316+ raise TypeError (f"process_fn must be type Callable | None, got { type ( val ). __name__ } ." )
317+ self ._process_fn = val
292318
293319 @property
294320 def inferer (self ) -> Union [InfererType , str ]:
295321 """The type of inferer to use"""
296322 return self ._inferer
297323
298324 @inferer .setter
299- def inferer (self , val : InfererType ):
300- if not isinstance (val , InfererType ):
301- raise ValueError (f"Value must be of the correct type { InfererType } ." )
302- self ._inferer = val
325+ def inferer (self , val : Union [InfererType , str ]):
326+ if isinstance (val , InfererType ):
327+ self ._inferer = val
328+ return
329+
330+ if isinstance (val , str ):
331+ s = val .strip ().lower ()
332+ valid = (InfererType .SIMPLE .value , InfererType .SLIDING_WINDOW .value )
333+ if s in valid :
334+ self ._inferer = InfererType (s )
335+ return
336+ raise ValueError (f"inferer must be one of { valid } , got { val !r} ." )
337+
338+ raise TypeError (f"inferer must be InfererType or str, got { type (val ).__name__ } ." )
303339
304340 def _convert_dicom_metadata_datatype (self , metadata : Dict ):
305341 """Converts metadata in pydicom types to the corresponding native types.
@@ -354,10 +390,10 @@ def compute(self, op_input, op_output, context):
354390 else :
355391 self ._executing = True
356392 try :
357- input_image = op_input .receive (self .input_name_image )
393+ input_image = op_input .receive (self ._input_name_image )
358394 if not input_image :
359395 raise ValueError ("Input is None." )
360- op_output .emit (self .compute_impl (input_image , context ), self .output_name_seg )
396+ op_output .emit (self .compute_impl (input_image , context ), self ._output_name_seg )
361397 finally :
362398 # Reset state on completing this method execution.
363399 with self ._lock :
@@ -372,12 +408,12 @@ def compute_impl(self, input_image, context):
372408 # Need to give a name to the image as in-mem Image obj has no name.
373409 img_name = str (input_img_metadata .get ("SeriesInstanceUID" , "Img_in_context" ))
374410
375- pre_transforms : Compose = self ._pre_transform
376- post_transforms : Compose = self ._post_transforms
411+ pre_transforms : Compose = self .pre_transforms
412+ post_transforms : Compose = self .post_transforms
377413 self ._reader = InMemImageReader (input_image )
378414
379- pre_transforms = self ._pre_transform if self ._pre_transform else self .pre_process (self ._reader )
380- post_transforms = self ._post_transforms if self ._post_transforms else self .post_process (pre_transforms )
415+ pre_transforms = self .pre_transforms if self .pre_transforms else self .pre_process (self ._reader )
416+ post_transforms = self .post_transforms if self .post_transforms else self .post_process (pre_transforms )
381417
382418 device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
383419 dataset = Dataset (data = [{self ._input_dataset_key : img_name }], transform = pre_transforms )
@@ -389,23 +425,24 @@ def compute_impl(self, input_image, context):
389425 for d in dataloader :
390426 images = d [self ._input_dataset_key ].to (device )
391427 self ._logger .info (f"Input of { type (images )} shape: { images .shape } " )
392- if self ._inferer == InfererType .SLIDING_WINDOW :
428+ if self .inferer == InfererType .SLIDING_WINDOW :
393429 d [self ._pred_dataset_key ] = sliding_window_inference (
394430 inputs = images ,
395- roi_size = self ._roi_size ,
396- sw_batch_size = self .sw_batch_size ,
431+ roi_size = self .roi_size ,
397432 overlap = self .overlap ,
398- mode = self ._mode ,
399- padding_mode = self ._padding_mode ,
400- predictor = self .model ,
401- ** self ._implicit_params , # additional sliding window arguments
433+ sw_batch_size = self .sw_batch_size ,
434+ sw_device = self .sw_device ,
435+ device = self .device ,
436+ process_fn = self .process_fn ,
437+ predictor = self ._model ,
438+ ** self ._implicit_params , # Additional sliding window arguments
402439 )
403- elif self ._inferer == InfererType .SIMPLE :
440+ elif self .inferer == InfererType .SIMPLE :
404441 # Instantiates the SimpleInferer and directly uses its __call__ function
405- d [self ._pred_dataset_key ] = simple_inference ()(inputs = images , network = self .model )
442+ d [self ._pred_dataset_key ] = simple_inference ()(inputs = images , network = self ._model )
406443 else :
407444 raise ValueError (
408- f"Unknown inferer: { self ._inferer !r} . Available options are "
445+ f"Unknown inferer: { self .inferer !r} . Available options are "
409446 f"{ InfererType .SLIDING_WINDOW !r} and { InfererType .SIMPLE !r} ."
410447 )
411448
# (end of diff — trailing web-page residue "0 commit comments" removed)