1- # Copyright 2021-2023 MONAI Consortium
1+ # Copyright 2021-2025 MONAI Consortium
22# Licensed under the Apache License, Version 2.0 (the "License");
33# you may not use this file except in compliance with the License.
44# You may obtain a copy of the License at
4747__all__ = ["MonaiSegInferenceOperator" , "InfererType" , "InMemImageReader" ]
4848
4949
50+ class BlendModeType (StrEnum ):
51+ """Represents the supported blend modes for sliding window inference."""
52+
53+ CONSTANT = "constant"
54+ GAUSSIAN = "gaussian"
55+
56+
5057class InfererType (StrEnum ):
5158 """Represents the supported types of the inferer, e.g. Simple and Sliding Window."""
5259
5360 SIMPLE = "simple"
5461 SLIDING_WINDOW = "sliding_window"
5562
5663
64+ class PytorchPadModeType (StrEnum ):
65+ """Represents the supported padding modes for sliding window inference."""
66+
67+ CONSTANT = "constant"
68+ REFLECT = "reflect"
69+ REPLICATE = "replicate"
70+ CIRCULAR = "circular"
71+
72+
5773# @md.env(pip_packages=["monai>=1.0.0", "torch>=1.10.2", "numpy>=1.21"])
5874class MonaiSegInferenceOperator (InferenceOperator ):
5975 """This segmentation operator uses MONAI transforms and Sliding Window Inference.
@@ -85,6 +101,8 @@ def __init__(
85101 model_name : Optional [str ] = "" ,
86102 overlap : float = 0.25 ,
87103 sw_batch_size : int = 4 ,
104+ mode : Union [BlendModeType , str ] = BlendModeType .CONSTANT ,
105+ padding_mode : Union [PytorchPadModeType , str ] = PytorchPadModeType .CONSTANT ,
88106 inferer : Union [InfererType , str ] = InfererType .SLIDING_WINDOW ,
89107 model_path : Path = MODEL_LOCAL_PATH ,
90108 ** kwargs ,
@@ -103,7 +121,12 @@ def __init__(
103121 overlap (float): The amount of overlap between scans along each spatial dimension. Defaults to 0.25.
104122 Applicable for "SLIDING_WINDOW" only.
105123 sw_batch_size(int): The batch size to run window slices. Defaults to 4.
106- Applicable for "SLIDING_WINDOW" only.
124+ Applicable for "SLIDING_WINDOW" only.
125+ mode (BlendMode): How to blend output of overlapping windows, "CONSTANT" or "GAUSSIAN". Defaults to "CONSTANT".
126+ Applicable for "SLIDING_WINDOW" only.
127+ padding_mode (PytorchPadMode): Padding mode for ``inputs``, when ``roi_size`` is larger than inputs,
128+ "CONSTANT", "REFLECT", "REPLICATE", or "CIRCULAR". Defaults to "CONSTANT".
129+ Applicable for "SLIDING_WINDOW" only.
107130 inferer (InfererType): The type of inferer to use, "SIMPLE" or "SLIDING_WINDOW". Defaults to "SLIDING_WINDOW".
108131 model_path (Path): Path to the model file. Defaults to model/models.ts of current working dir.
109132 """
@@ -121,6 +144,8 @@ def __init__(
121144 self ._model_name = model_name .strip () if isinstance (model_name , str ) else ""
122145 self ._overlap = overlap
123146 self ._sw_batch_size = sw_batch_size
147+ self ._mode = mode
148+ self ._padding_mode = padding_mode
124149 self ._inferer = inferer
125150
126151 # Add this so that the local model path can be set from the calling app
@@ -215,6 +240,28 @@ def sw_batch_size(self, val: int):
215240 raise ValueError ("sw_batch_size must be a positive integer." )
216241 self ._sw_batch_size = val
217242
243+ @property
244+ def mode (self ) -> Union [BlendModeType , str ]:
245+ """The blend mode used during sliding window inference"""
246+ return self ._mode
247+
248+ @mode .setter
249+ def mode (self , val : BlendModeType ):
250+ if not isinstance (val , BlendModeType ):
251+ raise ValueError (f"Value must be of the correct type { BlendModeType } ." )
252+ self ._mode = val
253+
254+ @property
255+ def padding_mode (self ) -> Union [PytorchPadModeType , str ]:
256+ """The padding mode to use when padding input images for inference"""
257+ return self ._padding_mode
258+
259+ @padding_mode .setter
260+ def padding_mode (self , val : PytorchPadModeType ):
261+ if not isinstance (val , PytorchPadModeType ):
262+ raise ValueError (f"Value must be of the correct type { PytorchPadModeType } ." )
263+ self ._padding_mode = val
264+
218265 @property
219266 def inferer (self ) -> Union [InfererType , str ]:
220267 """The type of inferer to use"""
@@ -320,6 +367,8 @@ def compute_impl(self, input_image, context):
320367 roi_size = self ._roi_size ,
321368 sw_batch_size = self .sw_batch_size ,
322369 overlap = self .overlap ,
370+ mode = self ._mode ,
371+ padding_mode = self ._padding_mode ,
323372 predictor = self .model ,
324373 )
325374 elif self ._inferer == InfererType .SIMPLE :
0 commit comments