3030from .inference_operator import InferenceOperator
3131
3232if TYPE_CHECKING :
33- import torch
33+ import torch as torch_typing
3434
3535MONAI_UTILS = "monai.utils"
3636nibabel , _ = optional_import ("nibabel" , "3.2.1" )
@@ -70,7 +70,9 @@ def _ensure_bundle_in_sys_path(bundle_path: Union[str, Path]) -> None:
7070 sys .path .insert (0 , bundle_root )
7171
7272
73- def _load_model_from_directory_bundle (bundle_path : Path , device : torch .device , parser : Any = None ) -> torch .nn .Module :
73+ def _load_model_from_directory_bundle (
74+ bundle_path : Path , device : "torch_typing.device" , parser : Any = None
75+ ) -> "torch_typing.nn.Module" :
7476 """Helper function to load model from a directory-based bundle.
7577
7678 Args:
@@ -79,7 +81,7 @@ def _load_model_from_directory_bundle(bundle_path: Path, device: torch.device, p
7981 parser: Optional ConfigParser for eager model loading
8082
8183 Returns:
82- torch .nn.Module: Loaded model network
84+ torch_typing .nn.Module: Loaded model network
8385
8486 Raises:
8587 IOError: If model files are not found
@@ -96,12 +98,12 @@ def _load_model_from_directory_bundle(bundle_path: Path, device: torch.device, p
9698 # Load model based on file type
9799 if model_path .suffix == ".ts" :
98100 # TorchScript bundle
99- return cast ("torch .nn.Module" , torch .jit .load (str (model_path ), map_location = device ).eval ())
101+ return cast ("torch_typing .nn.Module" , torch .jit .load (str (model_path ), map_location = device ).eval ())
100102 else :
101103 # .pt checkpoint: instantiate network from config and load state dict
102104 try :
103105 # Some .pt files may still be TorchScript; try jit first
104- return cast ("torch .nn.Module" , torch .jit .load (str (model_path ), map_location = device ).eval ())
106+ return cast ("torch_typing .nn.Module" , torch .jit .load (str (model_path ), map_location = device ).eval ())
105107 except Exception as ex :
106108 # Fallback to eager model with loaded weights
107109 if parser is None :
@@ -131,7 +133,7 @@ def _load_model_from_directory_bundle(bundle_path: Path, device: torch.device, p
131133 # Assume raw state dict
132134 state_dict = checkpoint
133135 network .load_state_dict (state_dict , strict = True )
134- return cast ("torch .nn.Module" , network .eval ())
136+ return cast ("torch_typing .nn.Module" , network .eval ())
135137
136138
137139def _read_directory_bundle_config (bundle_path_obj : Path , config_names : List [str ]) -> ConfigParser :
0 commit comments