Skip to content

Commit 5866be6

Browse files
committed
Refactor type hinting in MonaiBundleInferenceOperator
- Updated type hints for device and return values in model loading functions to use the alias `torch_typing` for improved clarity and consistency.
- Enhanced type safety by ensuring all model-related return types are correctly annotated.

Signed-off-by: Victor Chang <[email protected]>
1 parent 8e9914e commit 5866be6

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

monai/deploy/operators/monai_bundle_inference_operator.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from .inference_operator import InferenceOperator
3131

3232
if TYPE_CHECKING:
33-
import torch
33+
import torch as torch_typing
3434

3535
MONAI_UTILS = "monai.utils"
3636
nibabel, _ = optional_import("nibabel", "3.2.1")
@@ -70,7 +70,9 @@ def _ensure_bundle_in_sys_path(bundle_path: Union[str, Path]) -> None:
7070
sys.path.insert(0, bundle_root)
7171

7272

73-
def _load_model_from_directory_bundle(bundle_path: Path, device: torch.device, parser: Any = None) -> torch.nn.Module:
73+
def _load_model_from_directory_bundle(
74+
bundle_path: Path, device: "torch_typing.device", parser: Any = None
75+
) -> "torch_typing.nn.Module":
7476
"""Helper function to load model from a directory-based bundle.
7577
7678
Args:
@@ -79,7 +81,7 @@ def _load_model_from_directory_bundle(bundle_path: Path, device: torch.device, p
7981
parser: Optional ConfigParser for eager model loading
8082
8183
Returns:
82-
torch.nn.Module: Loaded model network
84+
torch_typing.nn.Module: Loaded model network
8385
8486
Raises:
8587
IOError: If model files are not found
@@ -96,12 +98,12 @@ def _load_model_from_directory_bundle(bundle_path: Path, device: torch.device, p
9698
# Load model based on file type
9799
if model_path.suffix == ".ts":
98100
# TorchScript bundle
99-
return cast("torch.nn.Module", torch.jit.load(str(model_path), map_location=device).eval())
101+
return cast("torch_typing.nn.Module", torch.jit.load(str(model_path), map_location=device).eval())
100102
else:
101103
# .pt checkpoint: instantiate network from config and load state dict
102104
try:
103105
# Some .pt files may still be TorchScript; try jit first
104-
return cast("torch.nn.Module", torch.jit.load(str(model_path), map_location=device).eval())
106+
return cast("torch_typing.nn.Module", torch.jit.load(str(model_path), map_location=device).eval())
105107
except Exception as ex:
106108
# Fallback to eager model with loaded weights
107109
if parser is None:
@@ -131,7 +133,7 @@ def _load_model_from_directory_bundle(bundle_path: Path, device: torch.device, p
131133
# Assume raw state dict
132134
state_dict = checkpoint
133135
network.load_state_dict(state_dict, strict=True)
134-
return cast("torch.nn.Module", network.eval())
136+
return cast("torch_typing.nn.Module", network.eval())
135137

136138

137139
def _read_directory_bundle_config(bundle_path_obj: Path, config_names: List[str]) -> ConfigParser:

0 commit comments

Comments (0)