5858
5959def _ensure_bundle_in_sys_path (bundle_path : Union [str , Path ]) -> None :
6060 """Helper function to ensure bundle root is on sys.path for script imports.
61-
61+
6262 Args:
6363 bundle_path: Path to the bundle directory
6464 """
@@ -67,19 +67,17 @@ def _ensure_bundle_in_sys_path(bundle_path: Union[str, Path]) -> None:
6767 sys .path .insert (0 , bundle_root )
6868
6969
70- def _load_model_from_directory_bundle (
71- bundle_path : Path , device : torch .device , parser : Any = None
72- ) -> torch .nn .Module :
70+ def _load_model_from_directory_bundle (bundle_path : Path , device : torch .device , parser : Any = None ) -> torch .nn .Module :
7371 """Helper function to load model from a directory-based bundle.
74-
72+
7573 Args:
7674 bundle_path: Path to the bundle directory
7775 device: PyTorch device to load the model on
7876 parser: Optional ConfigParser for eager model loading
79-
77+
8078 Returns:
8179 torch.nn.Module: Loaded model network
82-
80+
8381 Raises:
8482 IOError: If model files are not found
8583 RuntimeError: If network cannot be instantiated from configs
@@ -105,22 +103,14 @@ def _load_model_from_directory_bundle(
105103 # Fallback to eager model with loaded weights
106104 if parser is None :
107105 raise RuntimeError ("Parser required for loading .pt checkpoint but not provided" )
108-
106+
109107 # Ensure bundle root is on sys.path so 'scripts.*' can be imported
110108 _ensure_bundle_in_sys_path (bundle_path )
111-
112- network = (
113- parser .get_parsed_content ("network" )
114- if parser .get ("network" ) is not None
115- else None
116- )
109+
110+ network = parser .get_parsed_content ("network" ) if parser .get ("network" ) is not None else None
117111 if network is None :
118112 # Backward compatibility: some bundles use "network_def" then to(device)
119- network = (
120- parser .get_parsed_content ("network_def" )
121- if parser .get ("network_def" ) is not None
122- else None
123- )
113+ network = parser .get_parsed_content ("network_def" ) if parser .get ("network_def" ) is not None else None
124114 if network is not None :
125115 network = network .to (device )
126116 if network is None :
@@ -143,11 +133,11 @@ def _load_model_from_directory_bundle(
143133
144134def _read_directory_bundle_config (bundle_path_obj : Path , config_names : List [str ]) -> ConfigParser :
145135 """Helper function to read bundle configuration from a directory-based bundle.
146-
136+
147137 Args:
148138 bundle_path_obj: Path object pointing to the bundle directory
149139 config_names: List of config names to read
150-
140+
151141 Returns:
152142 ConfigParser: Parsed configuration object
153143 """
@@ -182,7 +172,7 @@ def _read_directory_bundle_config(bundle_path_obj: Path, config_names: List[str]
182172
183173 parser .read_config (config_files )
184174 parser .parse ()
185-
175+
186176 return parser
187177
188178
@@ -725,9 +715,7 @@ def compute(self, op_input, op_output, context):
725715 self ._init_completed = True
726716
727717 # Load model using helper function
728- self ._model_network = _load_model_from_directory_bundle (
729- self ._bundle_path , self ._device , self ._parser
730- )
718+ self ._model_network = _load_model_from_directory_bundle (self ._bundle_path , self ._device , self ._parser )
731719 else :
732720 # Original ZIP bundle handling
733721 self ._model_network = torch .jit .load (self ._bundle_path , map_location = self ._device ).eval ()
@@ -893,7 +881,7 @@ def _receive_input(self, name: str, op_input, context):
893881 # Could be (W, H, D) for 3D models or (W, H, C) for 2D models
894882 if expected_spatial_dims == 2 :
895883 # This is a 2D model expecting (W, H, C) input
896- actual_channels = value .shape [- 1 ]
884+ actual_channels = value .shape [- 1 ]
897885 if expected_channels is not None and expected_channels != actual_channels :
898886 if expected_channels == 1 and actual_channels > 1 :
899887 logging .warning (
0 commit comments