99# See the License for the specific language governing permissions and
1010# limitations under the License.
1111
12+ import inspect
1213import logging
1314import os
1415from pathlib import Path
1920
2021from monai .deploy .utils .importutil import optional_import
2122from monai .utils import StrEnum # Will use the built-in StrEnum when SDK requires Python 3.11.
23+ from monai .utils import BlendMode , PytorchPadMode
2224
2325MONAI_UTILS = "monai.utils"
2426torch , _ = optional_import ("torch" , "1.5" )
4749__all__ = ["MonaiSegInferenceOperator" , "InfererType" , "InMemImageReader" ]
4850
4951
50- class BlendModeType (StrEnum ):
51- """Represents the supported blend modes for sliding window inference."""
52-
53- CONSTANT = "constant"
54- GAUSSIAN = "gaussian"
55-
56-
5752class InfererType (StrEnum ):
5853 """Represents the supported types of the inferer, e.g. Simple and Sliding Window."""
5954
6055 SIMPLE = "simple"
6156 SLIDING_WINDOW = "sliding_window"
6257
6358
64- class PytorchPadModeType (StrEnum ):
65- """Represents the supported padding modes for sliding window inference."""
66-
67- CONSTANT = "constant"
68- REFLECT = "reflect"
69- REPLICATE = "replicate"
70- CIRCULAR = "circular"
59+ # define other StrEnum types
60+ BlendModeType = BlendMode
61+ PytorchPadModeType = PytorchPadMode
7162
7263
7364# @md.env(pip_packages=["monai>=1.0.0", "torch>=1.10.2", "numpy>=1.21"])
7465class MonaiSegInferenceOperator (InferenceOperator ):
75- """This segmentation operator uses MONAI transforms and Sliding Window Inference.
66+ """This segmentation operator uses MONAI transforms and performs Simple or Sliding Window Inference.
7667
7768 This operator performs pre-transforms on a input image, inference
7869 using a given model, and post-transforms. The segmentation image is saved
7970 as a named Image object in memory.
8071
8172 If specified in the post transforms, results may also be saved to disk.
8273
74+ This operator uses the MONAI inference utils functions for sliding window and simple inference,
75+ and thus input parameters need to be as expected by these functions.
76+
77+ Any additional sliding window arguments not explicitly defined in this operator can be passed via
78+ **kwargs for forwarding to 'sliding_window_inference'.
79+
8380 Named Input:
8481 image: Image object of the input image.
8582
@@ -90,6 +87,35 @@ class MonaiSegInferenceOperator(InferenceOperator):
9087 # For testing the app directly, the model should be at the following path.
9188 MODEL_LOCAL_PATH = Path (os .environ .get ("HOLOSCAN_MODEL_PATH" , Path .cwd () / "model/model.ts" ))
9289
90+ @staticmethod
91+ def filter_sw_kwargs (** kwargs ) -> Dict [str , Any ]:
92+ """
93+ Returns a dictionary of named parameters of the sliding_window_inference function that are:
94+ - Not explicitly defined in the __init__ of this class
95+ - Not explicitly used when calling sliding_window_inference
96+
97+ Args:
98+ **kwargs: extra arguments passed into __init__ beyond the explicitly defined args.
99+
100+ Returns:
101+ filtered_params: A filtered dictionary of arguments to be passed to sliding_window_inference.
102+ """
103+
104+ init_params = inspect .signature (MonaiSegInferenceOperator ).parameters
105+
106+ # inputs + predictor explicitly used when calling sliding_window_inference
107+ explicit_used = ["inputs" , "predictor" ]
108+
109+ filtered_params = {}
110+ for name , val in kwargs .items ():
111+ if name in init_params or name in explicit_used :
112+ # match log formatting
113+ logger = logging .getLogger (f"{ __name__ } .{ MonaiSegInferenceOperator .__name__ } " )
114+ logger .warning (f"{ name !r} is already explicity defined or used; ignoring input arg" )
115+ else :
116+ filtered_params [name ] = val
117+ return filtered_params
118+
93119 def __init__ (
94120 self ,
95121 fragment : Fragment ,
@@ -122,13 +148,14 @@ def __init__(
122148 Applicable for "SLIDING_WINDOW" only.
123149 sw_batch_size(int): The batch size to run window slices. Defaults to 4.
124150 Applicable for "SLIDING_WINDOW" only.
125- mode (BlendMode ): How to blend output of overlapping windows, "CONSTANT" or "GAUSSIAN". Defaults to "CONSTANT".
151+ mode (BlendModeType ): How to blend output of overlapping windows, "CONSTANT" or "GAUSSIAN". Defaults to "CONSTANT".
126152 Applicable for "SLIDING_WINDOW" only.
127- padding_mode (PytorchPadMode ): Padding mode for ``inputs``, when ``roi_size`` is larger than inputs,
153+ padding_mode (PytorchPadModeType ): Padding mode for ``inputs``, when ``roi_size`` is larger than inputs,
128154 "CONSTANT", "REFLECT", "REPLICATE", or "CIRCULAR". Defaults to "CONSTANT".
129155 Applicable for "SLIDING_WINDOW" only.
130156 inferer (InfererType): The type of inferer to use, "SIMPLE" or "SLIDING_WINDOW". Defaults to "SLIDING_WINDOW".
131157 model_path (Path): Path to the model file. Defaults to model/models.ts of current working dir.
158+ **kwargs: any other sliding window parameters to forward (e.g. `sigma_scale`, `cval`, etc.).
132159 """
133160
134161 self ._logger = logging .getLogger ("{}.{}" .format (__name__ , type (self ).__name__ ))
@@ -147,6 +174,7 @@ def __init__(
147174 self ._mode = mode
148175 self ._padding_mode = padding_mode
149176 self ._inferer = inferer
177+ self ._implicit_params = self .filter_sw_kwargs (** kwargs ) # Filter keyword args
150178
151179 # Add this so that the local model path can be set from the calling app
152180 self .model_path = model_path
@@ -370,6 +398,7 @@ def compute_impl(self, input_image, context):
370398 mode = self ._mode ,
371399 padding_mode = self ._padding_mode ,
372400 predictor = self .model ,
401+ ** self ._implicit_params , # additional sliding window arguments
373402 )
374403 elif self ._inferer == InfererType .SIMPLE :
375404 # Instantiates the SimpleInferer and directly uses its __call__ function
0 commit comments