99# See the License for the specific language governing permissions and
1010# limitations under the License.
1111
12- from typing import Any , Dict , Tuple , Union
13-
14- from monai .deploy .core import Image
1512from monai .deploy .operators .monai_bundle_inference_operator import MonaiBundleInferenceOperator , get_bundle_config
1613from monai .deploy .utils .importutil import optional_import
17- from monai .transforms import ConcatItemsd , ResampleToMatch
14+ from typing import Any , Dict , Tuple , Union
15+ from monai .deploy .core import Image
1816
17+ MONAI_UTILS = "monai.utils"
18+ nibabel , _ = optional_import ("nibabel" , "3.2.1" )
1919torch , _ = optional_import ("torch" , "1.10.2" )
20+
21+ NdarrayOrTensor , _ = optional_import ("monai.config" , name = "NdarrayOrTensor" )
2022MetaTensor , _ = optional_import ("monai.data.meta_tensor" , name = "MetaTensor" )
23+ PostFix , _ = optional_import ("monai.utils.enums" , name = "PostFix" ) # For the default meta_key_postfix
24+ first , _ = optional_import ("monai.utils.misc" , name = "first" )
25+ ensure_tuple , _ = optional_import (MONAI_UTILS , name = "ensure_tuple" )
26+ convert_to_dst_type , _ = optional_import (MONAI_UTILS , name = "convert_to_dst_type" )
27+ Key , _ = optional_import (MONAI_UTILS , name = "ImageMetaKey" )
28+ MetaKeys , _ = optional_import (MONAI_UTILS , name = "MetaKeys" )
29+ SpaceKeys , _ = optional_import (MONAI_UTILS , name = "SpaceKeys" )
30+ Compose_ , _ = optional_import ("monai.transforms" , name = "Compose" )
31+ ConfigParser_ , _ = optional_import ("monai.bundle" , name = "ConfigParser" )
32+ MapTransform_ , _ = optional_import ("monai.transforms" , name = "MapTransform" )
33+ SimpleInferer , _ = optional_import ("monai.inferers" , name = "SimpleInferer" )
34+
35+ Compose : Any = Compose_
36+ MapTransform : Any = MapTransform_
37+ ConfigParser : Any = ConfigParser_
2138__all__ = ["MONetBundleInferenceOperator" ]
2239
2340
2441class MONetBundleInferenceOperator (MonaiBundleInferenceOperator ):
2542 """
26- A specialized operator for performing inference using the MONet bundle.
43+ A specialized operator for performing inference using the MONAI nnUNet bundle.
2744 This operator extends the `MonaiBundleInferenceOperator` to support nnUNet-specific
2845 configurations and prediction logic. It initializes the nnUNet predictor and provides
2946 a method for performing inference on input data.
30-
47+
3148 Attributes
3249 ----------
3350 _nnunet_predictor : torch.nn.Module
3451 The nnUNet predictor module used for inference.
35-
52+
3653 Methods
3754 -------
3855 _init_config(config_names)
@@ -48,31 +65,25 @@ def __init__(
4865 ** kwargs ,
4966 ):
5067
51- super ().__init__ (* args , ** kwargs )
52-
53- self ._nnunet_predictor : torch .nn .Module = None
5468
55- def _init_config (self , config_names ):
69+ super ().__init__ (* args , ** kwargs )
70+
71+ self ._nnunet_predictor : torch .nn .Module = None
72+
73+
74+ def _init_config (self , config_names ):
5675
5776 super ()._init_config (config_names )
58- parser = get_bundle_config (str (self ._bundle_path ), config_names )
77+ parser = get_bundle_config (str (self ._bundle_path ), config_names )
5978 self ._parser = parser
6079
6180 self ._nnunet_predictor = parser .get_parsed_content ("network_def" )
6281
6382 def predict (self , data : Any , * args , ** kwargs ) -> Union [Image , Any , Tuple [Any , ...], Dict [Any , Any ]]:
64- """Predicts output using the inferer. If multimodal data is provided as keyword arguments,
65- it concatenates the data with the main input data."""
83+ """Predicts output using the inferer."""
6684
6785 self ._nnunet_predictor .predictor .network = self ._model_network
68-
69- if len (kwargs ) > 0 :
70- multimodal_data = {"image" : data }
71- for key in kwargs .keys ():
72- if isinstance (kwargs [key ], MetaTensor ):
73- multimodal_data [key ] = ResampleToMatch (mode = "bilinear" )(kwargs [key ], img_dst = data
74- )
75- data = ConcatItemsd (keys = list (multimodal_data .keys ()),name = "image" )(multimodal_data )["image" ]
86+ #os.environ['nnUNet_def_n_proc'] = "1"
7687 if len (data .shape ) == 4 :
7788 data = data [None ]
7889 return self ._nnunet_predictor (data )
0 commit comments