Commit 8e9914e

Enhance type hinting in MonaiBundleInferenceOperator
- Added TYPE_CHECKING imports for torch to improve type hinting and static analysis.
- Utilized type casting for model loading return values to ensure proper type safety and clarity.

Signed-off-by: Victor Chang <[email protected]>
1 parent 33901aa commit 8e9914e
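
For context on the pattern this commit applies (not part of the original change): because torch is resolved through optional_import at runtime, static checkers see it as untyped, so guarding a real "import torch" behind TYPE_CHECKING and wrapping the returned models in cast("torch.nn.Module", ...) gives mypy/pyright a concrete type while leaving runtime behavior unchanged (cast is a no-op). The sketch below illustrates that pattern only; the optional_import stub and the load_torchscript helper are simplified, hypothetical stand-ins, not the operator's actual code.

    from typing import TYPE_CHECKING, cast

    import importlib

    if TYPE_CHECKING:
        # Static-analysis-only import; never executed at runtime, so torch is
        # still loaded solely through the optional runtime import below.
        import torch


    def optional_import(name: str, version: str = ""):
        # Simplified stand-in for the operator module's optional_import helper:
        # return the module (or None) plus a success flag, failing lazily.
        try:
            return importlib.import_module(name), True
        except ImportError:
            return None, False


    torch, _ = optional_import("torch", "1.10.2")


    def load_torchscript(model_path: str, device: "torch.device") -> "torch.nn.Module":
        # cast() returns its argument unchanged at runtime; the string type lets
        # checkers treat the result as torch.nn.Module without an eager import.
        return cast("torch.nn.Module", torch.jit.load(model_path, map_location=device).eval())
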


monai/deploy/operators/monai_bundle_inference_operator.py

Lines changed: 7 additions & 4 deletions
@@ -20,7 +20,7 @@
 from copy import deepcopy
 from pathlib import Path
 from threading import Lock
-from typing import Any, Dict, List, Optional, Tuple, Type, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union, cast
 
 import numpy as np
 
@@ -29,6 +29,9 @@
 
 from .inference_operator import InferenceOperator
 
+if TYPE_CHECKING:
+    import torch
+
 MONAI_UTILS = "monai.utils"
 nibabel, _ = optional_import("nibabel", "3.2.1")
 torch, _ = optional_import("torch", "1.10.2")
@@ -93,12 +96,12 @@ def _load_model_from_directory_bundle(bundle_path: Path, device: torch.device, p
     # Load model based on file type
     if model_path.suffix == ".ts":
         # TorchScript bundle
-        return torch.jit.load(str(model_path), map_location=device).eval()
+        return cast("torch.nn.Module", torch.jit.load(str(model_path), map_location=device).eval())
     else:
         # .pt checkpoint: instantiate network from config and load state dict
         try:
             # Some .pt files may still be TorchScript; try jit first
-            return torch.jit.load(str(model_path), map_location=device).eval()
+            return cast("torch.nn.Module", torch.jit.load(str(model_path), map_location=device).eval())
         except Exception as ex:
             # Fallback to eager model with loaded weights
             if parser is None:
@@ -128,7 +131,7 @@ def _load_model_from_directory_bundle(bundle_path: Path, device: torch.device, p
             # Assume raw state dict
             state_dict = checkpoint
         network.load_state_dict(state_dict, strict=True)
-        return network.eval()
+        return cast("torch.nn.Module", network.eval())
 
 
 def _read_directory_bundle_config(bundle_path_obj: Path, config_names: List[str]) -> ConfigParser:
