Skip to content

Commit 7da6e97

Browse files
authored
MONAISegInferenceOperator Additional Arguments (#547)
* mode + padding_mode input args added Signed-off-by: bluna301 <[email protected]> * StrEnums from MONAI Core; kwargs filtering for sliding_window_inference forwarding Signed-off-by: bluna301 <[email protected]> * holoscan conversion check + explicit params added + setter cleanup Signed-off-by: bluna301 <[email protected]> * split kwargs into swi params & base init params Signed-off-by: bluna301 <[email protected]> * removed redundant continues Signed-off-by: bluna301 <[email protected]> --------- Signed-off-by: bluna301 <[email protected]>
1 parent 9f7a4ad commit 7da6e97

File tree

2 files changed

+128
-45
lines changed

2 files changed

+128
-45
lines changed

monai/deploy/core/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2021 MONAI Consortium
1+
# Copyright 2021-2025 MONAI Consortium
22
# Licensed under the Apache License, Version 2.0 (the "License");
33
# you may not use this file except in compliance with the License.
44
# You may obtain a copy of the License at
@@ -30,7 +30,7 @@
3030

3131
# Need to import explicit ones to quiet mypy complaints
3232
from holoscan.core import *
33-
from holoscan.core import Application, Condition, ConditionType, Fragment, Operator, OperatorSpec
33+
from holoscan.core import Application, Condition, ConditionType, Fragment, Operator, OperatorSpec, Resource
3434

3535
from .app_context import AppContext, init_app_context
3636
from .arg_parser import parse_args

monai/deploy/operators/monai_seg_inference_operator.py

Lines changed: 126 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2021-2023 MONAI Consortium
1+
# Copyright 2021-2025 MONAI Consortium
22
# Licensed under the Apache License, Version 2.0 (the "License");
33
# you may not use this file except in compliance with the License.
44
# You may obtain a copy of the License at
@@ -9,6 +9,7 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12+
import inspect
1213
import logging
1314
import os
1415
from pathlib import Path
@@ -40,7 +41,7 @@
4041
# Dynamic class is not handled so make it Any for now: https://github.com/python/mypy/issues/2477
4142
Compose: Any = Compose_
4243

43-
from monai.deploy.core import AppContext, ConditionType, Fragment, Image, OperatorSpec
44+
from monai.deploy.core import AppContext, Condition, ConditionType, Fragment, Image, OperatorSpec, Resource
4445

4546
from .inference_operator import InferenceOperator
4647

@@ -56,14 +57,20 @@ class InfererType(StrEnum):
5657

5758
# @md.env(pip_packages=["monai>=1.0.0", "torch>=1.10.2", "numpy>=1.21"])
5859
class MonaiSegInferenceOperator(InferenceOperator):
59-
"""This segmentation operator uses MONAI transforms and Sliding Window Inference.
60+
"""This segmentation operator uses MONAI transforms and performs Simple or Sliding Window Inference.
6061
6162
This operator performs pre-transforms on a input image, inference
6263
using a given model, and post-transforms. The segmentation image is saved
6364
as a named Image object in memory.
6465
6566
If specified in the post transforms, results may also be saved to disk.
6667
68+
This operator uses the MONAI inference utils functions for sliding window and simple inference,
69+
and thus input parameters need to be as expected by these functions.
70+
71+
Any additional sliding window arguments not explicitly defined in this operator can be passed via
72+
**kwargs for forwarding to 'sliding_window_inference'.
73+
6774
Named Input:
6875
image: Image object of the input image.
6976
@@ -74,6 +81,63 @@ class MonaiSegInferenceOperator(InferenceOperator):
7481
# For testing the app directly, the model should be at the following path.
7582
MODEL_LOCAL_PATH = Path(os.environ.get("HOLOSCAN_MODEL_PATH", Path.cwd() / "model/model.ts"))
7683

84+
@staticmethod
85+
def filter_sw_kwargs(**kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
86+
"""
87+
Filters the keyword arguments into a tuple of two dictionaries:
88+
89+
1. A dictionary of named parameters to pass to the sliding_window_inference function that:
90+
- Are not explicitly defined in the __init__ of this class
91+
- Are not explicitly used when calling sliding_window_inference
92+
93+
2. A dicionary of named parameters to pass to the base class __init__ of this class that:
94+
- Are not used when calling sliding_window_inference
95+
- Can be successfully converted from Python --> Holoscan's C++ layer
96+
97+
Args:
98+
**kwargs: extra arguments passed into __init__ beyond the explicitly defined args.
99+
100+
Returns:
101+
filtered_swi_params: A filtered dictionary of arguments to be passed to sliding_window_inference.
102+
filtered_base_init_params: A filtered dictionary of arguments to be passed to the base class __init__.
103+
"""
104+
105+
logger = logging.getLogger(f"{__name__}.{MonaiSegInferenceOperator.__name__}")
106+
107+
init_params = inspect.signature(MonaiSegInferenceOperator).parameters
108+
swi_params = inspect.signature(sliding_window_inference).parameters
109+
110+
# inputs + predictor explicitly used when calling sliding_window_inference
111+
explicit_used = {"inputs", "predictor"}
112+
113+
# Holoscan convertible types (not exhaustive)
114+
# This will be revisited when there is a better way to handle this.
115+
allowed_types = (str, int, float, bool, bytes, list, tuple, torch.Tensor, Condition, Resource)
116+
117+
filtered_swi_params = {}
118+
filtered_base_init_params = {}
119+
120+
for name, val in kwargs.items():
121+
# Drop explicitly defined kwargs
122+
if name in init_params or name in explicit_used:
123+
logger.warning(f"{name!r} is already explicitly defined or used; dropping kwarg.")
124+
# SWI params
125+
elif name in swi_params:
126+
filtered_swi_params[name] = val
127+
logger.debug(f"{name!r} used in sliding_window_inference; keeping kwarg for inference call.")
128+
# Drop kwargs that can't be converted by Holoscan
129+
elif not isinstance(val, allowed_types):
130+
logger.warning(
131+
f"{name!r} type of {type(val).__name__!r} is a non-convertible kwarg for Holoscan; dropping kwarg."
132+
)
133+
# Base __init__ params
134+
else:
135+
filtered_base_init_params[name] = val
136+
logger.debug(
137+
f"{name!r} type of {type(val).__name__!r} can be converted by Holoscan; keeping kwarg for base init."
138+
)
139+
return filtered_swi_params, filtered_base_init_params
140+
77141
def __init__(
78142
self,
79143
fragment: Fragment,
@@ -83,7 +147,7 @@ def __init__(
83147
post_transforms: Compose,
84148
app_context: AppContext,
85149
model_name: Optional[str] = "",
86-
overlap: float = 0.25,
150+
overlap: Union[Sequence[float], float] = 0.25,
87151
sw_batch_size: int = 4,
88152
inferer: Union[InfererType, str] = InfererType.SLIDING_WINDOW,
89153
model_path: Path = MODEL_LOCAL_PATH,
@@ -93,19 +157,19 @@ def __init__(
93157
94158
Args:
95159
fragment (Fragment): An instance of the Application class which is derived from Fragment.
96-
roi_size (Union[Sequence[int], int]): The window size to execute "SLIDING_WINDOW" evaluation.
97-
An optional input only to be passed for "SLIDING_WINDOW".
98-
If using a "SIMPLE" Inferer, this input is ignored.
160+
roi_size (Sequence[int], int, optional): The window size to execute "SLIDING_WINDOW" evaluation.
161+
Applicable for "SLIDING_WINDOW" only.
99162
pre_transforms (Compose): MONAI Compose object used for pre-transforms.
100163
post_transforms (Compose): MONAI Compose object used for post-transforms.
101164
app_context (AppContext): Object holding the I/O and model paths, and potentially loaded models.
102165
model_name (str, optional): Name of the model. Default to "" for single model app.
103-
overlap (float): The amount of overlap between scans along each spatial dimension. Defaults to 0.25.
166+
overlap (Sequence[float], float): The amount of overlap between scans along each spatial dimension. Defaults to 0.25.
167+
Applicable for "SLIDING_WINDOW" only.
168+
sw_batch_size (int): The batch size to run window slices. Defaults to 4.
104169
Applicable for "SLIDING_WINDOW" only.
105-
sw_batch_size(int): The batch size to run window slices. Defaults to 4.
106-
Applicable for "SLIDING_WINDOW" only.
107-
inferer (InfererType): The type of inferer to use, "SIMPLE" or "SLIDING_WINDOW". Defaults to "SLIDING_WINDOW".
170+
inferer (InfererType, str): The type of inferer to use, "SIMPLE" or "SLIDING_WINDOW". Defaults to "SLIDING_WINDOW".
108171
model_path (Path): Path to the model file. Defaults to model/models.ts of current working dir.
172+
**kwargs: any other sliding window parameters to forward (e.g. `mode`, `cval`, etc.).
109173
"""
110174

111175
self._logger = logging.getLogger("{}.{}".format(__name__, type(self).__name__))
@@ -115,26 +179,30 @@ def __init__(
115179
self._pred_dataset_key = "pred"
116180
self._input_image = None # Image will come in when compute is called.
117181
self._reader: Any = None
118-
self._roi_size = ensure_tuple(roi_size)
119-
self._pre_transform = pre_transforms
120-
self._post_transforms = post_transforms
182+
self.roi_size = ensure_tuple(roi_size)
183+
self.pre_transforms = pre_transforms
184+
self.post_transforms = post_transforms
121185
self._model_name = model_name.strip() if isinstance(model_name, str) else ""
122-
self._overlap = overlap
123-
self._sw_batch_size = sw_batch_size
124-
self._inferer = inferer
186+
self.overlap = overlap
187+
self.sw_batch_size = sw_batch_size
188+
self.inferer = inferer
189+
self._filtered_swi_params, self._filtered_base_init_params = self.filter_sw_kwargs(
190+
**kwargs
191+
) # Filter keyword args
125192

126193
# Add this so that the local model path can be set from the calling app
127194
self.model_path = model_path
128-
self.input_name_image = "image"
129-
self.output_name_seg = "seg_image"
195+
self._input_name_image = "image"
196+
self._output_name_seg = "seg_image"
130197

131198
# The execution context passed in on compute does not have the required model info, so need to
132199
# get and keep the model via the AppContext obj on construction.
133200
self.app_context = app_context
134201

135-
self.model = self._get_model(self.app_context, self.model_path, self._model_name)
202+
self._model = self._get_model(self.app_context, self.model_path, self._model_name)
136203

137-
super().__init__(fragment, *args, **kwargs)
204+
# Pass filtered base init params
205+
super().__init__(fragment, *args, **self._filtered_base_init_params)
138206

139207
def _get_model(self, app_context: AppContext, model_path: Path, model_name: str):
140208
"""Load the model with the given name from context or model path
@@ -159,8 +227,8 @@ def _get_model(self, app_context: AppContext, model_path: Path, model_name: str)
159227
return model
160228

161229
def setup(self, spec: OperatorSpec):
162-
spec.input(self.input_name_image)
163-
spec.output(self.output_name_seg).condition(ConditionType.NONE) # Downstream receiver optional.
230+
spec.input(self._input_name_image)
231+
spec.output(self._output_name_seg).condition(ConditionType.NONE) # Downstream receiver optional.
164232

165233
@property
166234
def roi_size(self):
@@ -199,9 +267,13 @@ def overlap(self):
199267
return self._overlap
200268

201269
@overlap.setter
202-
def overlap(self, val: float):
203-
if val < 0 or val > 1:
204-
raise ValueError("Overlap must be between 0 and 1.")
270+
def overlap(self, val: Union[Sequence[float], float]):
271+
if not isinstance(val, (Sequence, int, float)) or isinstance(val, str):
272+
raise TypeError(f"Overlap must be type Sequence[float] | float, got {type(val).__name__}.")
273+
elif isinstance(val, Sequence) and not all(isinstance(x, (int, float)) and 0 <= x < 1 for x in val):
274+
raise ValueError("Each overlap value must be >= 0 and < 1.")
275+
elif isinstance(val, (int, float)) and not (0 <= float(val) < 1):
276+
raise ValueError(f"Overlap must be >= 0 and < 1, got {val}.")
205277
self._overlap = val
206278

207279
@property
@@ -221,10 +293,20 @@ def inferer(self) -> Union[InfererType, str]:
221293
return self._inferer
222294

223295
@inferer.setter
224-
def inferer(self, val: InfererType):
225-
if not isinstance(val, InfererType):
226-
raise ValueError(f"Value must be of the correct type {InfererType}.")
227-
self._inferer = val
296+
def inferer(self, val: Union[InfererType, str]):
297+
if isinstance(val, InfererType):
298+
self._inferer = val
299+
return
300+
301+
if isinstance(val, str):
302+
s = val.strip().lower()
303+
valid = (InfererType.SIMPLE.value, InfererType.SLIDING_WINDOW.value)
304+
if s in valid:
305+
self._inferer = InfererType(s)
306+
return
307+
raise ValueError(f"inferer must be one of {valid}, got {val!r}.")
308+
309+
raise TypeError(f"inferer must be InfererType or str, got {type(val).__name__}.")
228310

229311
def _convert_dicom_metadata_datatype(self, metadata: Dict):
230312
"""Converts metadata in pydicom types to the corresponding native types.
@@ -279,10 +361,10 @@ def compute(self, op_input, op_output, context):
279361
else:
280362
self._executing = True
281363
try:
282-
input_image = op_input.receive(self.input_name_image)
364+
input_image = op_input.receive(self._input_name_image)
283365
if not input_image:
284366
raise ValueError("Input is None.")
285-
op_output.emit(self.compute_impl(input_image, context), self.output_name_seg)
367+
op_output.emit(self.compute_impl(input_image, context), self._output_name_seg)
286368
finally:
287369
# Reset state on completing this method execution.
288370
with self._lock:
@@ -297,12 +379,12 @@ def compute_impl(self, input_image, context):
297379
# Need to give a name to the image as in-mem Image obj has no name.
298380
img_name = str(input_img_metadata.get("SeriesInstanceUID", "Img_in_context"))
299381

300-
pre_transforms: Compose = self._pre_transform
301-
post_transforms: Compose = self._post_transforms
382+
pre_transforms: Compose = self.pre_transforms
383+
post_transforms: Compose = self.post_transforms
302384
self._reader = InMemImageReader(input_image)
303385

304-
pre_transforms = self._pre_transform if self._pre_transform else self.pre_process(self._reader)
305-
post_transforms = self._post_transforms if self._post_transforms else self.post_process(pre_transforms)
386+
pre_transforms = self.pre_transforms if self.pre_transforms else self.pre_process(self._reader)
387+
post_transforms = self.post_transforms if self.post_transforms else self.post_process(pre_transforms)
306388

307389
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
308390
dataset = Dataset(data=[{self._input_dataset_key: img_name}], transform=pre_transforms)
@@ -314,20 +396,21 @@ def compute_impl(self, input_image, context):
314396
for d in dataloader:
315397
images = d[self._input_dataset_key].to(device)
316398
self._logger.info(f"Input of {type(images)} shape: {images.shape}")
317-
if self._inferer == InfererType.SLIDING_WINDOW:
399+
if self.inferer == InfererType.SLIDING_WINDOW:
318400
d[self._pred_dataset_key] = sliding_window_inference(
319401
inputs=images,
320-
roi_size=self._roi_size,
321-
sw_batch_size=self.sw_batch_size,
402+
roi_size=self.roi_size,
322403
overlap=self.overlap,
323-
predictor=self.model,
404+
sw_batch_size=self.sw_batch_size,
405+
predictor=self._model,
406+
**self._filtered_swi_params, # Additional sliding window arguments
324407
)
325-
elif self._inferer == InfererType.SIMPLE:
408+
elif self.inferer == InfererType.SIMPLE:
326409
# Instantiates the SimpleInferer and directly uses its __call__ function
327-
d[self._pred_dataset_key] = simple_inference()(inputs=images, network=self.model)
410+
d[self._pred_dataset_key] = simple_inference()(inputs=images, network=self._model)
328411
else:
329412
raise ValueError(
330-
f"Unknown inferer: {self._inferer!r}. Available options are "
413+
f"Unknown inferer: {self.inferer!r}. Available options are "
331414
f"{InfererType.SLIDING_WINDOW!r} and {InfererType.SIMPLE!r}."
332415
)
333416

0 commit comments

Comments
 (0)