Skip to content

Commit cca0b15

Browse files
committed
split kwargs into swi params & base init params
Signed-off-by: bluna301 <[email protected]>
1 parent 5d9061d commit cca0b15

File tree

2 files changed

+38
-66
lines changed

2 files changed

+38
-66
lines changed

monai/deploy/core/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2021 MONAI Consortium
1+
# Copyright 2021-2025 MONAI Consortium
22
# Licensed under the Apache License, Version 2.0 (the "License");
33
# you may not use this file except in compliance with the License.
44
# You may obtain a copy of the License at

monai/deploy/operators/monai_seg_inference_operator.py

Lines changed: 37 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,15 @@
1414
import os
1515
from pathlib import Path
1616
from threading import Lock
17-
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
17+
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
1818

1919
import numpy as np
20-
import torch
2120

2221
from monai.deploy.utils.importutil import optional_import
2322
from monai.utils import StrEnum # Will use the built-in StrEnum when SDK requires Python 3.11.
2423

2524
MONAI_UTILS = "monai.utils"
25+
torch, _ = optional_import("torch", "1.5")
2626
np_str_obj_array_pattern, _ = optional_import("torch.utils.data._utils.collate", name="np_str_obj_array_pattern")
2727
Dataset, _ = optional_import("monai.data", name="Dataset")
2828
DataLoader, _ = optional_import("monai.data", name="DataLoader")
@@ -82,23 +82,30 @@ class MonaiSegInferenceOperator(InferenceOperator):
8282
MODEL_LOCAL_PATH = Path(os.environ.get("HOLOSCAN_MODEL_PATH", Path.cwd() / "model/model.ts"))
8383

8484
@staticmethod
85-
def filter_sw_kwargs(**kwargs) -> Dict[str, Any]:
85+
def filter_sw_kwargs(**kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
8686
"""
87-
Returns a dictionary of named parameters of the sliding_window_inference function that:
87+
Filters the keyword arguments into a tuple of two dictionaries:
88+
89+
1. A dictionary of named parameters to pass to the sliding_window_inference function that:
8890
- Are not explicitly defined in the __init__ of this class
8991
- Are not explicitly used when calling sliding_window_inference
92+
93+
2. A dicionary of named parameters to pass to the base class __init__ of this class that:
94+
- Are not used when calling sliding_window_inference
9095
- Can be successfully converted from Python --> Holoscan's C++ layer
9196
9297
Args:
9398
**kwargs: extra arguments passed into __init__ beyond the explicitly defined args.
9499
95100
Returns:
96-
filtered_params: A filtered dictionary of arguments to be passed to sliding_window_inference.
101+
filtered_swi_params: A filtered dictionary of arguments to be passed to sliding_window_inference.
102+
filtered_base_init_params: A filtered dictionary of arguments to be passed to the base class __init__.
97103
"""
98104

99105
logger = logging.getLogger(f"{__name__}.{MonaiSegInferenceOperator.__name__}")
100106

101107
init_params = inspect.signature(MonaiSegInferenceOperator).parameters
108+
swi_params = inspect.signature(sliding_window_inference).parameters
102109

103110
# inputs + predictor explicitly used when calling sliding_window_inference
104111
explicit_used = {"inputs", "predictor"}
@@ -107,20 +114,33 @@ def filter_sw_kwargs(**kwargs) -> Dict[str, Any]:
107114
# This will be revisited when there is a better way to handle this.
108115
allowed_types = (str, int, float, bool, bytes, list, tuple, torch.Tensor, Condition, Resource)
109116

110-
filtered_params = {}
117+
filtered_swi_params = {}
118+
filtered_base_init_params = {}
119+
111120
for name, val in kwargs.items():
112121
# Drop explicitly defined kwargs
113122
if name in init_params or name in explicit_used:
114123
logger.warning(f"{name!r} is already explicitly defined or used; dropping kwarg.")
115124
continue
125+
# SWI params
126+
elif name in swi_params:
127+
filtered_swi_params[name] = val
128+
logger.debug(f"{name!r} used in sliding_window_inference; keeping kwarg for inference call.")
129+
continue
116130
# Drop kwargs that can't be converted by Holoscan
117-
if not isinstance(val, allowed_types):
131+
elif not isinstance(val, allowed_types):
118132
logger.warning(
119133
f"{name!r} type of {type(val).__name__!r} is a non-convertible kwarg for Holoscan; dropping kwarg."
120134
)
121135
continue
122-
filtered_params[name] = val
123-
return filtered_params
136+
# Base __init__ params
137+
else:
138+
filtered_base_init_params[name] = val
139+
logger.debug(
140+
f"{name!r} type of {type(val).__name__!r} can be converted by Holoscan; keeping kwarg for base init."
141+
)
142+
continue
143+
return filtered_swi_params, filtered_base_init_params
124144

125145
def __init__(
126146
self,
@@ -131,11 +151,8 @@ def __init__(
131151
post_transforms: Compose,
132152
app_context: AppContext,
133153
model_name: Optional[str] = "",
134-
sw_batch_size: int = 4,
135154
overlap: Union[Sequence[float], float] = 0.25,
136-
sw_device: Optional[Union[torch.device, str]] = None,
137-
device: Optional[Union[torch.device, str]] = None,
138-
process_fn: Optional[Callable] = None,
155+
sw_batch_size: int = 4,
139156
inferer: Union[InfererType, str] = InfererType.SLIDING_WINDOW,
140157
model_path: Path = MODEL_LOCAL_PATH,
141158
**kwargs,
@@ -150,15 +167,9 @@ def __init__(
150167
post_transforms (Compose): MONAI Compose object used for post-transforms.
151168
app_context (AppContext): Object holding the I/O and model paths, and potentially loaded models.
152169
model_name (str, optional): Name of the model. Default to "" for single model app.
153-
sw_batch_size (int): The batch size to run window slices. Defaults to 4.
154-
Applicable for "SLIDING_WINDOW" only.
155170
overlap (Sequence[float], float): The amount of overlap between scans along each spatial dimension. Defaults to 0.25.
156171
Applicable for "SLIDING_WINDOW" only.
157-
sw_device (torch.device, str, optional): Device for the window data. Defaults to None.
158-
Applicable for "SLIDING_WINDOW" only.
159-
device: (torch.device, str, optional): Device for the stitched output prediction. Defaults to None.
160-
Applicable for "SLIDING_WINDOW" only.
161-
process_fn: (Callable, optional): process inference output and adjust the importance map per window. Defaults to None.
172+
sw_batch_size (int): The batch size to run window slices. Defaults to 4.
162173
Applicable for "SLIDING_WINDOW" only.
163174
inferer (InfererType, str): The type of inferer to use, "SIMPLE" or "SLIDING_WINDOW". Defaults to "SLIDING_WINDOW".
164175
model_path (Path): Path to the model file. Defaults to model/models.ts of current working dir.
@@ -179,12 +190,9 @@ def __init__(
179190
self.overlap = overlap
180191
self.sw_batch_size = sw_batch_size
181192
self.inferer = inferer
182-
self._implicit_params = self.filter_sw_kwargs(**kwargs) # Filter keyword args
183-
184-
# Sliding window inference args whose type Holoscan can't convert - define explicitly
185-
self.sw_device = sw_device
186-
self.device = device
187-
self.process_fn = process_fn
193+
self._filtered_swi_params, self._filtered_base_init_params = self.filter_sw_kwargs(
194+
**kwargs
195+
) # Filter keyword args
188196

189197
# Add this so that the local model path can be set from the calling app
190198
self.model_path = model_path
@@ -197,8 +205,8 @@ def __init__(
197205

198206
self._model = self._get_model(self.app_context, self.model_path, self._model_name)
199207

200-
# Pass filtered kwargs
201-
super().__init__(fragment, *args, **self._implicit_params)
208+
# Pass filtered base init params
209+
super().__init__(fragment, *args, **self._filtered_base_init_params)
202210

203211
def _get_model(self, app_context: AppContext, model_path: Path, model_name: str):
204212
"""Load the model with the given name from context or model path
@@ -283,39 +291,6 @@ def sw_batch_size(self, val: int):
283291
raise ValueError("sw_batch_size must be a positive integer.")
284292
self._sw_batch_size = val
285293

286-
@property
287-
def sw_device(self):
288-
"""Device for the window data."""
289-
return self._sw_device
290-
291-
@sw_device.setter
292-
def sw_device(self, val: Optional[Union[torch.device, str]]):
293-
if val is not None and not isinstance(val, (torch.device, str)):
294-
raise TypeError(f"sw_device must be type torch.device | str | None, got {type(val).__name__}.")
295-
self._sw_device = val
296-
297-
@property
298-
def device(self):
299-
"""Device for the stitched output prediction."""
300-
return self._device
301-
302-
@device.setter
303-
def device(self, val: Optional[Union[torch.device, str]]):
304-
if val is not None and not isinstance(val, (torch.device, str)):
305-
raise TypeError(f"device must be type torch.device | str | None, got {type(val).__name__}.")
306-
self._device = val
307-
308-
@property
309-
def process_fn(self):
310-
"""Process inference output and adjust the importance map per window."""
311-
return self._process_fn
312-
313-
@process_fn.setter
314-
def process_fn(self, val: Optional[Callable]):
315-
if val is not None and not callable(val):
316-
raise TypeError(f"process_fn must be type Callable | None, got {type(val).__name__}.")
317-
self._process_fn = val
318-
319294
@property
320295
def inferer(self) -> Union[InfererType, str]:
321296
"""The type of inferer to use"""
@@ -431,11 +406,8 @@ def compute_impl(self, input_image, context):
431406
roi_size=self.roi_size,
432407
overlap=self.overlap,
433408
sw_batch_size=self.sw_batch_size,
434-
sw_device=self.sw_device,
435-
device=self.device,
436-
process_fn=self.process_fn,
437409
predictor=self._model,
438-
**self._implicit_params, # Additional sliding window arguments
410+
**self._filtered_swi_params, # Additional sliding window arguments
439411
)
440412
elif self.inferer == InfererType.SIMPLE:
441413
# Instantiates the SimpleInferer and directly uses its __call__ function

0 commit comments

Comments
 (0)