Skip to content

Commit 94eafcd

Browse files
committed
Refactor MonaiBundleInferenceOperator for improved path handling and image conversion
- Updated bundle path handling to use `Path` objects for consistency and clarity. - Simplified image conversion logic by consolidating conditional checks into a single line for better readability. - Ensured that directory checks and model loading operations utilize the updated path handling. Signed-off-by: Victor Chang <[email protected]>
1 parent 5866be6 commit 94eafcd

File tree

1 file changed

+10
-27
lines changed

1 file changed

+10
-27
lines changed

monai/deploy/operators/monai_bundle_inference_operator.py

Lines changed: 10 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -561,8 +561,10 @@ def _init_config(self, config_names):
561561
"""
562562

563563
# Ensure bundle root is on sys.path for directory-based bundles
564-
if self._bundle_path and self._bundle_path.is_dir():
565-
_ensure_bundle_in_sys_path(self._bundle_path)
564+
if self._bundle_path:
565+
bundle_path_obj = Path(self._bundle_path)
566+
if bundle_path_obj.is_dir():
567+
_ensure_bundle_in_sys_path(bundle_path_obj)
566568

567569
parser = get_bundle_config(str(self._bundle_path), config_names)
568570
self._parser = parser
@@ -697,7 +699,7 @@ def compute(self, op_input, op_output, context):
697699
if not self._init_completed:
698700
with self._lock:
699701
if not self._init_completed:
700-
self._bundle_path = self._model_network.path
702+
self._bundle_path = Path(self._model_network.path)
701703
logging.info(f"Parsing from bundle_path: {self._bundle_path}")
702704
self._init_config(self._bundle_config_names.config_names)
703705
self._init_completed = True
@@ -708,7 +710,8 @@ def compute(self, op_input, op_output, context):
708710
logging.debug(f"Model network not loaded. Trying to load from model path: {self._bundle_path}")
709711

710712
# Check if bundle_path is a directory
711-
if self._bundle_path.is_dir():
713+
bundle_path_obj = Path(self._bundle_path)
714+
if bundle_path_obj.is_dir():
712715
# Ensure device is set
713716
if not hasattr(self, "_device"):
714717
self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -720,10 +723,10 @@ def compute(self, op_input, op_output, context):
720723
self._init_completed = True
721724

722725
# Load model using helper function
723-
self._model_network = _load_model_from_directory_bundle(self._bundle_path, self._device, self._parser)
726+
self._model_network = _load_model_from_directory_bundle(bundle_path_obj, self._device, self._parser)
724727
else:
725728
# Original ZIP bundle handling
726-
self._model_network = torch.jit.load(self._bundle_path, map_location=self._device).eval()
729+
self._model_network = torch.jit.load(bundle_path_obj, map_location=self._device).eval()
727730
else:
728731
raise IOError("Model network is not load and model file not found.")
729732

@@ -963,27 +966,7 @@ def _send_output(self, value: Any, name: str, metadata: Dict, op_output, context
963966

964967
logging.debug(f"Output {name} numpy image shape: {value.shape}")
965968

966-
# Handle 2D masks and generic 2D tensors gracefully
967-
if value.ndim == 2:
968-
# Already HxW image; binarize/scale left to downstream operators
969-
out_img = value.astype(np.uint8)
970-
result: Any = Image(out_img, metadata=metadata)
971-
elif value.ndim == 3:
972-
# Could be (C, H, W) with C==1 or (H, W, C)
973-
if value.shape[0] == 1: # (1, H, W) -> (H, W)
974-
out_img = value[0].astype(np.uint8)
975-
result = Image(out_img, metadata=metadata)
976-
elif value.shape[-1] == 1: # (H, W, 1) -> (H, W)
977-
out_img = value[..., 0].astype(np.uint8)
978-
result = Image(out_img, metadata=metadata)
979-
else:
980-
# Fallback to original behavior for 3D volumetric layout assumptions
981-
out_img = np.swapaxes(np.squeeze(value, 0), 0, 2).astype(np.uint8)
982-
result = Image(out_img, metadata=metadata)
983-
else:
984-
# Keep existing behavior for higher-dimensional data (e.g., 3D volumes)
985-
out_img = np.swapaxes(np.squeeze(value, 0), 0, 2).astype(np.uint8)
986-
result = Image(out_img, metadata=metadata)
969+
result: Any = Image(np.swapaxes(np.squeeze(value, 0), 0, 2).astype(np.uint8), metadata=metadata)
987970
logging.debug(f"Converted Image shape: {result.asnumpy().shape}")
988971
elif otype == np.ndarray:
989972
result = np.asarray(value)

0 commit comments

Comments
 (0)