11#!/usr/bin/env python3
2+ #
3+ # Copyright (C) 2024-2025 Intel Corporation
4+ # SPDX-License-Identifier: Apache-2.0
5+ #
6+
27"""
38PyTorch to OpenVINO Model Converter
49
1924import torch
2025import torch .nn as nn
2126
27+
2228class ModelConverter :
2329 """Handles conversion of PyTorch models to OpenVINO format."""
2430
@@ -61,10 +67,11 @@ def get_labels(self, label_set: str) -> Optional[str]:
6167 """
6268 if label_set == "IMAGENET1K_V1" :
6369 from torchvision .models ._meta import _IMAGENET_CATEGORIES
70+
6471 categories = _IMAGENET_CATEGORIES
6572 categories = [label .replace (" " , "_" ) for label in categories ]
6673 return " " .join (categories )
67-
74+
6875 return None
6976
7077 def download_weights (self , url : str , filename : Optional [str ] = None ) -> Path :
@@ -91,7 +98,10 @@ def download_weights(self, url: str, filename: Optional[str] = None) -> Path:
9198 self .logger .info (f"Saving to: { cached_file } " )
9299
93100 try :
94- urllib .request .urlretrieve (url , cached_file ) # nosemgrep: python.lang.security.audit.dynamic-urllib-use-detected.dynamic-urllib-use-detected
101+ urllib .request .urlretrieve ( # noqa: S310
102+ url ,
103+ cached_file ,
104+ ) # nosemgrep: python.lang.security.audit.dynamic-urllib-use-detected.dynamic-urllib-use-detected
95105 self .logger .info ("✓ Download complete" )
96106 return cached_file
97107 except Exception as e :
@@ -111,7 +121,9 @@ def load_model_class(self, class_path: str) -> type:
111121 try :
112122 module_path , class_name = class_path .rsplit ("." , 1 )
113123 self .logger .debug (f"Importing module: { module_path } " )
114- module = importlib .import_module (module_path ) # nosemgrep: python.lang.security.audit.non-literal-import.non-literal-import
124+ module = importlib .import_module (
125+ module_path ,
126+ ) # nosemgrep: python.lang.security.audit.non-literal-import.non-literal-import
115127 model_class = getattr (module , class_name )
116128 self .logger .debug (f"Loaded class: { class_name } " )
117129 return model_class
@@ -130,7 +142,11 @@ def load_checkpoint(self, checkpoint_path: Path) -> Dict[str, Any]:
130142 Checkpoint dictionary
131143 """
132144 try :
133- checkpoint = torch .load (checkpoint_path , map_location = "cpu" , weights_only = True ) # nosemgrep: trailofbits.python.pickles-in-pytorch.pickles-in-pytorch
145+ checkpoint = torch .load (
146+ checkpoint_path ,
147+ map_location = "cpu" ,
148+ weights_only = True ,
149+ ) # nosemgrep: trailofbits.python.pickles-in-pytorch.pickles-in-pytorch
134150 self .logger .debug (f"Loaded checkpoint from: { checkpoint_path } " )
135151 return checkpoint
136152 except Exception as e :
@@ -161,22 +177,21 @@ def create_model(
161177 model = checkpoint ["model" ]
162178 elif "state_dict" in checkpoint :
163179 # Cannot reconstruct architecture from state_dict alone
164- raise ValueError (
180+ error_msg = (
165181 "Checkpoint contains only state_dict. "
166182 "Please specify the model class instead of torch.nn.Module"
167183 )
184+ raise ValueError (error_msg )
168185 else :
169186 # Assume checkpoint is the model itself
170187 model = checkpoint
171-
188+
172189 if not isinstance (model , nn .Module ):
173- raise ValueError ("Checkpoint does not contain a valid model" )
190+ error_msg = "Checkpoint does not contain a valid model"
191+ raise ValueError (error_msg )
174192 else :
175193 # Instantiate model class
176- if model_params :
177- model = model_class (** model_params )
178- else :
179- model = model_class ()
194+ model = model_class (** model_params ) if model_params else model_class ()
180195
181196 # Load weights
182197 if "state_dict" in checkpoint :
@@ -205,7 +220,7 @@ def export_to_openvino(
205220 output_path : Path ,
206221 input_names : Optional [List [str ]] = None ,
207222 output_names : Optional [List [str ]] = None ,
208- metadata : Optional [Dict [tuple , str ]] = None
223+ metadata : Optional [Dict [tuple , str ]] = None ,
209224 ) -> Path :
210225 """
211226 Export PyTorch model to OpenVINO format.
@@ -232,11 +247,11 @@ def export_to_openvino(
232247
233248 # Reshape model to fixed input shape (remove dynamic dimensions)
234249 first_input = ov_model .input (0 )
235- input_name_for_reshape = list ( first_input .get_names ())[ 0 ] if first_input .get_names () else 0
236-
250+ input_name_for_reshape = next ( iter ( first_input .get_names ())) if first_input .get_names () else 0
251+
237252 self .logger .debug (f"Setting fixed input shape: { input_shape } " )
238253 ov_model .reshape ({input_name_for_reshape : input_shape })
239-
254+
240255 # Post-process the model
241256 ov_model = self ._postprocess_openvino_model (
242257 ov_model ,
@@ -355,7 +370,7 @@ def process_model_config(self, config: Dict[str, Any]) -> bool:
355370 if labels_config :
356371 labels = self .get_labels (labels_config )
357372 if labels :
358- metadata [( "model_info" , "labels" ) ] = labels
373+ metadata ["model_info" , "labels" ] = labels
359374 self .logger .info (f"Added { labels_config } labels to metadata" )
360375 else :
361376 self .logger .warning (f"Could not load labels for: { labels_config } " )
@@ -366,15 +381,16 @@ def process_model_config(self, config: Dict[str, Any]) -> bool:
366381 output_path = output_path ,
367382 input_names = input_names ,
368383 output_names = output_names ,
369- metadata = metadata
384+ metadata = metadata ,
370385 )
371386
372387 self .logger .info (f"✓ Successfully converted { model_short_name } " )
373388 return True
374389
375- except Exception as e :
390+ except ( ValueError , RuntimeError , ImportError , FileNotFoundError ) as e :
376391 self .logger .error (f"✗ Failed to process model { model_short_name } : { e } " )
377392 import traceback
393+
378394 self .logger .debug (traceback .format_exc ())
379395 return False
380396
@@ -394,7 +410,7 @@ def process_config_file(
394410 Tuple of (successful_count, failed_count)
395411 """
396412 try :
397- with open (config_path ) as f :
413+ with Path (config_path ). open ( ) as f :
398414 config = json .load (f )
399415 except Exception as e :
400416 self .logger .error (f"Failed to load configuration file: { e } " )
@@ -431,9 +447,9 @@ def process_config_file(
431447def list_models (config_path : Path ):
432448 """List all models in a configuration file."""
433449 try :
434- with open (config_path ) as f :
450+ with config_path . open () as f :
435451 config = json .load (f )
436- except Exception as e :
452+ except ( FileNotFoundError , json . JSONDecodeError , PermissionError ) as e :
437453 print (f"Error loading configuration: { e } " , file = sys .stderr )
438454 return
439455
@@ -566,7 +582,7 @@ def main():
566582 logger .info ("=" * 80 )
567583
568584 return 0 if failed == 0 else 1
569- except Exception as e :
585+ except ( ValueError , RuntimeError , ImportError , FileNotFoundError ) as e :
570586 logger .error (f"Failed to process model: { e } " )
571587 return 1
572588
0 commit comments