 import os
 from pathlib import Path
 from threading import Lock
-from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

 import numpy as np
-import torch

 from monai.deploy.utils.importutil import optional_import
 from monai.utils import StrEnum  # Will use the built-in StrEnum when SDK requires Python 3.11.

 MONAI_UTILS = "monai.utils"
+torch, _ = optional_import("torch", "1.5")
 np_str_obj_array_pattern, _ = optional_import("torch.utils.data._utils.collate", name="np_str_obj_array_pattern")
 Dataset, _ = optional_import("monai.data", name="Dataset")
 DataLoader, _ = optional_import("monai.data", name="DataLoader")
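Reviewer note: the hard `import torch` above is replaced with MONAI's deferred-import helper. A minimal sketch of how that pattern behaves (not part of this patch; assumes the same `optional_import` helper used in this module):

```python
from monai.deploy.utils.importutil import optional_import

# Returns (module, flag); the flag is False when torch is missing or older than 1.5.
torch, torch_ok = optional_import("torch", "1.5")

if torch_ok:
    print(torch.__version__)
else:
    # The returned placeholder raises only when an attribute is actually used,
    # so merely importing this module does not require torch.
    print("torch>=1.5 not available")
```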
@@ -82,23 +82,30 @@ class MonaiSegInferenceOperator(InferenceOperator):
     MODEL_LOCAL_PATH = Path(os.environ.get("HOLOSCAN_MODEL_PATH", Path.cwd() / "model/model.ts"))

     @staticmethod
-    def filter_sw_kwargs(**kwargs) -> Dict[str, Any]:
+    def filter_sw_kwargs(**kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
         """
-        Returns a dictionary of named parameters of the sliding_window_inference function that:
+        Filters the keyword arguments into a tuple of two dictionaries:
+
+        1. A dictionary of named parameters to pass to the sliding_window_inference function that:
         - Are not explicitly defined in the __init__ of this class
         - Are not explicitly used when calling sliding_window_inference
+
+        2. A dictionary of named parameters to pass to the base class __init__ of this class that:
+        - Are not used when calling sliding_window_inference
         - Can be successfully converted from Python --> Holoscan's C++ layer

         Args:
             **kwargs: extra arguments passed into __init__ beyond the explicitly defined args.

         Returns:
-            filtered_params: A filtered dictionary of arguments to be passed to sliding_window_inference.
+            filtered_swi_params: A filtered dictionary of arguments to be passed to sliding_window_inference.
+            filtered_base_init_params: A filtered dictionary of arguments to be passed to the base class __init__.
         """

         logger = logging.getLogger(f"{__name__}.{MonaiSegInferenceOperator.__name__}")

         init_params = inspect.signature(MonaiSegInferenceOperator).parameters
+        swi_params = inspect.signature(sliding_window_inference).parameters

         # inputs + predictor explicitly used when calling sliding_window_inference
         explicit_used = {"inputs", "predictor"}
@@ -107,20 +114,33 @@ def filter_sw_kwargs(**kwargs) -> Dict[str, Any]:
         # This will be revisited when there is a better way to handle this.
         allowed_types = (str, int, float, bool, bytes, list, tuple, torch.Tensor, Condition, Resource)

-        filtered_params = {}
+        filtered_swi_params = {}
+        filtered_base_init_params = {}
+
         for name, val in kwargs.items():
             # Drop explicitly defined kwargs
             if name in init_params or name in explicit_used:
                 logger.warning(f"{name!r} is already explicitly defined or used; dropping kwarg.")
                 continue
+            # SWI params
+            elif name in swi_params:
+                filtered_swi_params[name] = val
+                logger.debug(f"{name!r} used in sliding_window_inference; keeping kwarg for inference call.")
+                continue
             # Drop kwargs that can't be converted by Holoscan
-            if not isinstance(val, allowed_types):
+            elif not isinstance(val, allowed_types):
                 logger.warning(
                     f"{name!r} type of {type(val).__name__!r} is a non-convertible kwarg for Holoscan; dropping kwarg."
                 )
                 continue
-            filtered_params[name] = val
-        return filtered_params
+            # Base __init__ params
+            else:
+                filtered_base_init_params[name] = val
+                logger.debug(
+                    f"{name!r} type of {type(val).__name__!r} can be converted by Holoscan; keeping kwarg for base init."
+                )
+                continue
+        return filtered_swi_params, filtered_base_init_params

     def __init__(
         self,
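Reviewer note: the routing above keys off `inspect.signature`. A simplified, standalone sketch of the same idea (the toy `window_infer` and `ALLOWED` names are placeholders, not the operator's actual symbols):

```python
import inspect

ALLOWED = (str, int, float, bool, bytes, list, tuple)  # simplified stand-in for Holoscan-convertible types


def window_infer(inputs, roi_size, sw_device=None, process_fn=None):
    """Stand-in for monai.inferers.sliding_window_inference."""


def split_kwargs(**kwargs):
    swi_params = inspect.signature(window_infer).parameters
    swi, base = {}, {}
    for key, val in kwargs.items():
        if key in swi_params:
            swi[key] = val  # forwarded later to the inference call
        elif isinstance(val, ALLOWED):
            base[key] = val  # convertible by Holoscan, forwarded to the base __init__
        # anything else is dropped, mirroring the warning-and-drop branch above
    return swi, base


print(split_kwargs(sw_device="cuda:0", name="seg_op", callback=lambda x: x))
# -> ({'sw_device': 'cuda:0'}, {'name': 'seg_op'})
```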
@@ -131,11 +151,8 @@ def __init__(
         post_transforms: Compose,
         app_context: AppContext,
         model_name: Optional[str] = "",
-        sw_batch_size: int = 4,
         overlap: Union[Sequence[float], float] = 0.25,
-        sw_device: Optional[Union[torch.device, str]] = None,
-        device: Optional[Union[torch.device, str]] = None,
-        process_fn: Optional[Callable] = None,
+        sw_batch_size: int = 4,
         inferer: Union[InfererType, str] = InfererType.SLIDING_WINDOW,
         model_path: Path = MODEL_LOCAL_PATH,
         **kwargs,
@@ -150,15 +167,9 @@ def __init__(
             post_transforms (Compose): MONAI Compose object used for post-transforms.
             app_context (AppContext): Object holding the I/O and model paths, and potentially loaded models.
             model_name (str, optional): Name of the model. Default to "" for single model app.
-            sw_batch_size (int): The batch size to run window slices. Defaults to 4.
-                Applicable for "SLIDING_WINDOW" only.
             overlap (Sequence[float], float): The amount of overlap between scans along each spatial dimension. Defaults to 0.25.
                 Applicable for "SLIDING_WINDOW" only.
-            sw_device (torch.device, str, optional): Device for the window data. Defaults to None.
-                Applicable for "SLIDING_WINDOW" only.
-            device: (torch.device, str, optional): Device for the stitched output prediction. Defaults to None.
-                Applicable for "SLIDING_WINDOW" only.
-            process_fn: (Callable, optional): process inference output and adjust the importance map per window. Defaults to None.
+            sw_batch_size (int): The batch size to run window slices. Defaults to 4.
                 Applicable for "SLIDING_WINDOW" only.
             inferer (InfererType, str): The type of inferer to use, "SIMPLE" or "SLIDING_WINDOW". Defaults to "SLIDING_WINDOW".
             model_path (Path): Path to the model file. Defaults to model/models.ts of current working dir.
@@ -179,12 +190,9 @@ def __init__(
         self.overlap = overlap
         self.sw_batch_size = sw_batch_size
         self.inferer = inferer
-        self._implicit_params = self.filter_sw_kwargs(**kwargs)  # Filter keyword args
-
-        # Sliding window inference args whose type Holoscan can't convert - define explicitly
-        self.sw_device = sw_device
-        self.device = device
-        self.process_fn = process_fn
+        self._filtered_swi_params, self._filtered_base_init_params = self.filter_sw_kwargs(
+            **kwargs
+        )  # Filter keyword args

         # Add this so that the local model path can be set from the calling app
         self.model_path = model_path
@@ -197,8 +205,8 @@ def __init__(

         self._model = self._get_model(self.app_context, self.model_path, self._model_name)

-        # Pass filtered kwargs
-        super().__init__(fragment, *args, **self._implicit_params)
+        # Pass filtered base init params
+        super().__init__(fragment, *args, **self._filtered_base_init_params)

     def _get_model(self, app_context: AppContext, model_path: Path, model_name: str):
         """Load the model with the given name from context or model path
@@ -283,39 +291,6 @@ def sw_batch_size(self, val: int):
             raise ValueError("sw_batch_size must be a positive integer.")
         self._sw_batch_size = val

-    @property
-    def sw_device(self):
-        """Device for the window data."""
-        return self._sw_device
-
-    @sw_device.setter
-    def sw_device(self, val: Optional[Union[torch.device, str]]):
-        if val is not None and not isinstance(val, (torch.device, str)):
-            raise TypeError(f"sw_device must be type torch.device | str | None, got {type(val).__name__}.")
-        self._sw_device = val
-
-    @property
-    def device(self):
-        """Device for the stitched output prediction."""
-        return self._device
-
-    @device.setter
-    def device(self, val: Optional[Union[torch.device, str]]):
-        if val is not None and not isinstance(val, (torch.device, str)):
-            raise TypeError(f"device must be type torch.device | str | None, got {type(val).__name__}.")
-        self._device = val
-
-    @property
-    def process_fn(self):
-        """Process inference output and adjust the importance map per window."""
-        return self._process_fn
-
-    @process_fn.setter
-    def process_fn(self, val: Optional[Callable]):
-        if val is not None and not callable(val):
-            raise TypeError(f"process_fn must be type Callable | None, got {type(val).__name__}.")
-        self._process_fn = val
-
     @property
     def inferer(self) -> Union[InfererType, str]:
         """The type of inferer to use"""
@@ -431,11 +406,8 @@ def compute_impl(self, input_image, context):
                             roi_size=self.roi_size,
                             overlap=self.overlap,
                             sw_batch_size=self.sw_batch_size,
-                            sw_device=self.sw_device,
-                            device=self.device,
-                            process_fn=self.process_fn,
                             predictor=self._model,
-                            **self._implicit_params,  # Additional sliding window arguments
+                            **self._filtered_swi_params,  # Additional sliding window arguments
                         )
                     elif self.inferer == InfererType.SIMPLE:
                         # Instantiates the SimpleInferer and directly uses its __call__ function
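Reviewer note: with this change, sliding-window-only arguments are supplied through `**kwargs` instead of dedicated `__init__` parameters. A hedged usage sketch (constructor arguments are abbreviated; `fragment`, the transforms, and `app_context` are assumed to come from the surrounding application, and `name` is assumed to be a kwarg the base Operator accepts):

```python
from pathlib import Path

seg_op = MonaiSegInferenceOperator(
    fragment,
    roi_size=(96, 96, 96),
    pre_transforms=pre_transforms,
    post_transforms=post_transforms,
    app_context=app_context,
    overlap=0.5,
    inferer=InfererType.SLIDING_WINDOW,
    model_path=Path("model/model.ts"),
    sw_device="cuda:0",  # routed by filter_sw_kwargs to sliding_window_inference
    device="cpu",        # routed to sliding_window_inference
    name="seg_inference_op",  # convertible kwarg, forwarded to the base __init__
)
```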