Skip to content

Commit b267d20

Browse files
committed
mode + padding_mode input args added
Signed-off-by: bluna301 <[email protected]>
1 parent dec9305 commit b267d20

File tree

1 file changed

+51
-2
lines changed

1 file changed

+51
-2
lines changed

monai/deploy/operators/monai_seg_inference_operator.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2021-2023 MONAI Consortium
1+
# Copyright 2021-2025 MONAI Consortium
22
# Licensed under the Apache License, Version 2.0 (the "License");
33
# you may not use this file except in compliance with the License.
44
# You may obtain a copy of the License at
@@ -47,13 +47,29 @@
4747
__all__ = ["MonaiSegInferenceOperator", "InfererType", "InMemImageReader"]
4848

4949

50+
class BlendModeType(StrEnum):
51+
"""Represents the supported blend modes for sliding window inference."""
52+
53+
CONSTANT = "constant"
54+
GAUSSIAN = "gaussian"
55+
56+
5057
class InfererType(StrEnum):
5158
"""Represents the supported types of the inferer, e.g. Simple and Sliding Window."""
5259

5360
SIMPLE = "simple"
5461
SLIDING_WINDOW = "sliding_window"
5562

5663

64+
class PytorchPadModeType(StrEnum):
65+
"""Represents the supported padding modes for sliding window inference."""
66+
67+
CONSTANT = "constant"
68+
REFLECT = "reflect"
69+
REPLICATE = "replicate"
70+
CIRCULAR = "circular"
71+
72+
5773
# @md.env(pip_packages=["monai>=1.0.0", "torch>=1.10.2", "numpy>=1.21"])
5874
class MonaiSegInferenceOperator(InferenceOperator):
5975
"""This segmentation operator uses MONAI transforms and Sliding Window Inference.
@@ -85,6 +101,8 @@ def __init__(
85101
model_name: Optional[str] = "",
86102
overlap: float = 0.25,
87103
sw_batch_size: int = 4,
104+
mode: Union[BlendModeType, str] = BlendModeType.CONSTANT,
105+
padding_mode: Union[PytorchPadModeType, str] = PytorchPadModeType.CONSTANT,
88106
inferer: Union[InfererType, str] = InfererType.SLIDING_WINDOW,
89107
model_path: Path = MODEL_LOCAL_PATH,
90108
**kwargs,
@@ -103,7 +121,12 @@ def __init__(
103121
overlap (float): The amount of overlap between scans along each spatial dimension. Defaults to 0.25.
104122
Applicable for "SLIDING_WINDOW" only.
105123
sw_batch_size(int): The batch size to run window slices. Defaults to 4.
106-
Applicable for "SLIDING_WINDOW" only.
124+
Applicable for "SLIDING_WINDOW" only.
125+
mode (BlendMode): How to blend output of overlapping windows, "CONSTANT" or "GAUSSIAN". Defaults to "CONSTANT".
126+
Applicable for "SLIDING_WINDOW" only.
127+
padding_mode (PytorchPadMode): Padding mode for ``inputs``, when ``roi_size`` is larger than inputs,
128+
"CONSTANT", "REFLECT", "REPLICATE", or "CIRCULAR". Defaults to "CONSTANT".
129+
Applicable for "SLIDING_WINDOW" only.
107130
inferer (InfererType): The type of inferer to use, "SIMPLE" or "SLIDING_WINDOW". Defaults to "SLIDING_WINDOW".
108131
model_path (Path): Path to the model file. Defaults to model/models.ts of current working dir.
109132
"""
@@ -121,6 +144,8 @@ def __init__(
121144
self._model_name = model_name.strip() if isinstance(model_name, str) else ""
122145
self._overlap = overlap
123146
self._sw_batch_size = sw_batch_size
147+
self._mode = mode
148+
self._padding_mode = padding_mode
124149
self._inferer = inferer
125150

126151
# Add this so that the local model path can be set from the calling app
@@ -215,6 +240,28 @@ def sw_batch_size(self, val: int):
215240
raise ValueError("sw_batch_size must be a positive integer.")
216241
self._sw_batch_size = val
217242

243+
@property
244+
def mode(self) -> Union[BlendModeType, str]:
245+
"""The blend mode used during sliding window inference"""
246+
return self._mode
247+
248+
@mode.setter
249+
def mode(self, val: BlendModeType):
250+
if not isinstance(val, BlendModeType):
251+
raise ValueError(f"Value must be of the correct type {BlendModeType}.")
252+
self._mode = val
253+
254+
@property
255+
def padding_mode(self) -> Union[PytorchPadModeType, str]:
256+
"""The padding mode to use when padding input images for inference"""
257+
return self._padding_mode
258+
259+
@padding_mode.setter
260+
def padding_mode(self, val: PytorchPadModeType):
261+
if not isinstance(val, PytorchPadModeType):
262+
raise ValueError(f"Value must be of the correct type {PytorchPadModeType}.")
263+
self._padding_mode = val
264+
218265
@property
219266
def inferer(self) -> Union[InfererType, str]:
220267
"""The type of inferer to use"""
@@ -320,6 +367,8 @@ def compute_impl(self, input_image, context):
320367
roi_size=self._roi_size,
321368
sw_batch_size=self.sw_batch_size,
322369
overlap=self.overlap,
370+
mode=self._mode,
371+
padding_mode=self._padding_mode,
323372
predictor=self.model,
324373
)
325374
elif self._inferer == InfererType.SIMPLE:

0 commit comments

Comments
 (0)