Skip to content

Commit 6f72b0b

Browse files
committed
StrEnums from MONAI Core; kwargs filtering for sliding_window_inference forwarding
Signed-off-by: bluna301 <[email protected]>
1 parent b267d20 commit 6f72b0b

File tree

1 file changed

+46
-17
lines changed

1 file changed

+46
-17
lines changed

monai/deploy/operators/monai_seg_inference_operator.py

Lines changed: 46 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12+
import inspect
1213
import logging
1314
import os
1415
from pathlib import Path
@@ -19,6 +20,7 @@
1920

2021
from monai.deploy.utils.importutil import optional_import
2122
from monai.utils import StrEnum # Will use the built-in StrEnum when SDK requires Python 3.11.
23+
from monai.utils import BlendMode, PytorchPadMode
2224

2325
MONAI_UTILS = "monai.utils"
2426
torch, _ = optional_import("torch", "1.5")
@@ -47,39 +49,34 @@
4749
__all__ = ["MonaiSegInferenceOperator", "InfererType", "InMemImageReader"]
4850

4951

50-
class BlendModeType(StrEnum):
51-
"""Represents the supported blend modes for sliding window inference."""
52-
53-
CONSTANT = "constant"
54-
GAUSSIAN = "gaussian"
55-
56-
5752
class InfererType(StrEnum):
5853
"""Represents the supported types of the inferer, e.g. Simple and Sliding Window."""
5954

6055
SIMPLE = "simple"
6156
SLIDING_WINDOW = "sliding_window"
6257

6358

64-
class PytorchPadModeType(StrEnum):
65-
"""Represents the supported padding modes for sliding window inference."""
66-
67-
CONSTANT = "constant"
68-
REFLECT = "reflect"
69-
REPLICATE = "replicate"
70-
CIRCULAR = "circular"
59+
# define other StrEnum types
60+
BlendModeType = BlendMode
61+
PytorchPadModeType = PytorchPadMode
7162

7263

7364
# @md.env(pip_packages=["monai>=1.0.0", "torch>=1.10.2", "numpy>=1.21"])
7465
class MonaiSegInferenceOperator(InferenceOperator):
75-
"""This segmentation operator uses MONAI transforms and Sliding Window Inference.
66+
"""This segmentation operator uses MONAI transforms and performs Simple or Sliding Window Inference.
7667
7768
This operator performs pre-transforms on a input image, inference
7869
using a given model, and post-transforms. The segmentation image is saved
7970
as a named Image object in memory.
8071
8172
If specified in the post transforms, results may also be saved to disk.
8273
74+
This operator uses the MONAI inference utils functions for sliding window and simple inference,
75+
and thus input parameters need to be as expected by these functions.
76+
77+
Any additional sliding window arguments not explicitly defined in this operator can be passed via
78+
**kwargs for forwarding to 'sliding_window_inference'.
79+
8380
Named Input:
8481
image: Image object of the input image.
8582
@@ -90,6 +87,35 @@ class MonaiSegInferenceOperator(InferenceOperator):
9087
# For testing the app directly, the model should be at the following path.
9188
MODEL_LOCAL_PATH = Path(os.environ.get("HOLOSCAN_MODEL_PATH", Path.cwd() / "model/model.ts"))
9289

90+
@staticmethod
91+
def filter_sw_kwargs(**kwargs) -> Dict[str, Any]:
92+
"""
93+
Returns a dictionary of named parameters of the sliding_window_inference function that are:
94+
- Not explicitly defined in the __init__ of this class
95+
- Not explicitly used when calling sliding_window_inference
96+
97+
Args:
98+
**kwargs: extra arguments passed into __init__ beyond the explicitly defined args.
99+
100+
Returns:
101+
filtered_params: A filtered dictionary of arguments to be passed to sliding_window_inference.
102+
"""
103+
104+
init_params = inspect.signature(MonaiSegInferenceOperator).parameters
105+
106+
# inputs + predictor explicitly used when calling sliding_window_inference
107+
explicit_used = ["inputs", "predictor"]
108+
109+
filtered_params = {}
110+
for name, val in kwargs.items():
111+
if name in init_params or name in explicit_used:
112+
# match log formatting
113+
logger = logging.getLogger(f"{__name__}.{MonaiSegInferenceOperator.__name__}")
114+
logger.warning(f"{name!r} is already explicity defined or used; ignoring input arg")
115+
else:
116+
filtered_params[name] = val
117+
return filtered_params
118+
93119
def __init__(
94120
self,
95121
fragment: Fragment,
@@ -122,13 +148,14 @@ def __init__(
122148
Applicable for "SLIDING_WINDOW" only.
123149
sw_batch_size(int): The batch size to run window slices. Defaults to 4.
124150
Applicable for "SLIDING_WINDOW" only.
125-
mode (BlendMode): How to blend output of overlapping windows, "CONSTANT" or "GAUSSIAN". Defaults to "CONSTANT".
151+
mode (BlendModeType): How to blend output of overlapping windows, "CONSTANT" or "GAUSSIAN". Defaults to "CONSTANT".
126152
Applicable for "SLIDING_WINDOW" only.
127-
padding_mode (PytorchPadMode): Padding mode for ``inputs``, when ``roi_size`` is larger than inputs,
153+
padding_mode (PytorchPadModeType): Padding mode for ``inputs``, when ``roi_size`` is larger than inputs,
128154
"CONSTANT", "REFLECT", "REPLICATE", or "CIRCULAR". Defaults to "CONSTANT".
129155
Applicable for "SLIDING_WINDOW" only.
130156
inferer (InfererType): The type of inferer to use, "SIMPLE" or "SLIDING_WINDOW". Defaults to "SLIDING_WINDOW".
131157
model_path (Path): Path to the model file. Defaults to model/models.ts of current working dir.
158+
**kwargs: any other sliding window parameters to forward (e.g. `sigma_scale`, `cval`, etc.).
132159
"""
133160

134161
self._logger = logging.getLogger("{}.{}".format(__name__, type(self).__name__))
@@ -147,6 +174,7 @@ def __init__(
147174
self._mode = mode
148175
self._padding_mode = padding_mode
149176
self._inferer = inferer
177+
self._implicit_params = self.filter_sw_kwargs(**kwargs) # Filter keyword args
150178

151179
# Add this so that the local model path can be set from the calling app
152180
self.model_path = model_path
@@ -370,6 +398,7 @@ def compute_impl(self, input_image, context):
370398
mode=self._mode,
371399
padding_mode=self._padding_mode,
372400
predictor=self.model,
401+
**self._implicit_params, # additional sliding window arguments
373402
)
374403
elif self._inferer == InfererType.SIMPLE:
375404
# Instantiates the SimpleInferer and directly uses its __call__ function

0 commit comments

Comments
 (0)