Skip to content

Commit 96ecfa0

Browse files
committed
fix
1 parent f826e3c commit 96ecfa0

File tree

2 files changed

+44
-22
lines changed

2 files changed

+44
-22
lines changed

tools/model_converter/README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,26 +53,31 @@ options:
5353
### Examples
5454

5555
**List all models in configuration:**
56+
5657
```bash
5758
python model_converter.py example_config.json --list
5859
```
5960

6061
**Convert all models:**
62+
6163
```bash
6264
python model_converter.py example_config.json -o ./converted_models
6365
```
6466

6567
**Convert a specific model:**
68+
6669
```bash
6770
python model_converter.py example_config.json -o ./converted_models --model resnet50
6871
```
6972

7073
**Use custom cache directory:**
74+
7175
```bash
7276
python model_converter.py example_config.json -o ./output -c ./my_cache
7377
```
7478

7579
**Enable verbose logging:**
80+
7681
```bash
7782
python model_converter.py example_config.json -o ./output -v
7883
```
@@ -103,6 +108,7 @@ The configuration file is a JSON file with the following structure:
103108
**Important**: The `model_type` field enables automatic model detection when using [Intel's model_api](https://github.com/openvinotoolkit/model_api). When specified, this metadata is embedded in the OpenVINO IR, allowing `Model.create_model()` to automatically select the correct model wrapper class.
104109

105110
Common `model_type` values:
111+
106112
- `"Classification"` - Image classification models
107113
- `"DetectionModel"` - Object detection models
108114
- `"YOLOX"` - YOLOX detection models

tools/model_converter/model_converter.py

Lines changed: 38 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
11
#!/usr/bin/env python3
2+
#
3+
# Copyright (C) 2024-2025 Intel Corporation
4+
# SPDX-License-Identifier: Apache-2.0
5+
#
6+
27
"""
38
PyTorch to OpenVINO Model Converter
49
@@ -19,6 +24,7 @@
1924
import torch
2025
import torch.nn as nn
2126

27+
2228
class ModelConverter:
2329
"""Handles conversion of PyTorch models to OpenVINO format."""
2430

@@ -61,10 +67,11 @@ def get_labels(self, label_set: str) -> Optional[str]:
6167
"""
6268
if label_set == "IMAGENET1K_V1":
6369
from torchvision.models._meta import _IMAGENET_CATEGORIES
70+
6471
categories = _IMAGENET_CATEGORIES
6572
categories = [label.replace(" ", "_") for label in categories]
6673
return " ".join(categories)
67-
74+
6875
return None
6976

7077
def download_weights(self, url: str, filename: Optional[str] = None) -> Path:
@@ -91,7 +98,10 @@ def download_weights(self, url: str, filename: Optional[str] = None) -> Path:
9198
self.logger.info(f"Saving to: {cached_file}")
9299

93100
try:
94-
urllib.request.urlretrieve(url, cached_file) # nosemgrep: python.lang.security.audit.dynamic-urllib-use-detected.dynamic-urllib-use-detected
101+
urllib.request.urlretrieve( # noqa: S310
102+
url,
103+
cached_file,
104+
) # nosemgrep: python.lang.security.audit.dynamic-urllib-use-detected.dynamic-urllib-use-detected
95105
self.logger.info("✓ Download complete")
96106
return cached_file
97107
except Exception as e:
@@ -111,7 +121,9 @@ def load_model_class(self, class_path: str) -> type:
111121
try:
112122
module_path, class_name = class_path.rsplit(".", 1)
113123
self.logger.debug(f"Importing module: {module_path}")
114-
module = importlib.import_module(module_path) # nosemgrep: python.lang.security.audit.non-literal-import.non-literal-import
124+
module = importlib.import_module(
125+
module_path,
126+
) # nosemgrep: python.lang.security.audit.non-literal-import.non-literal-import
115127
model_class = getattr(module, class_name)
116128
self.logger.debug(f"Loaded class: {class_name}")
117129
return model_class
@@ -130,7 +142,11 @@ def load_checkpoint(self, checkpoint_path: Path) -> Dict[str, Any]:
130142
Checkpoint dictionary
131143
"""
132144
try:
133-
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True) # nosemgrep: trailofbits.python.pickles-in-pytorch.pickles-in-pytorch
145+
checkpoint = torch.load(
146+
checkpoint_path,
147+
map_location="cpu",
148+
weights_only=True,
149+
) # nosemgrep: trailofbits.python.pickles-in-pytorch.pickles-in-pytorch
134150
self.logger.debug(f"Loaded checkpoint from: {checkpoint_path}")
135151
return checkpoint
136152
except Exception as e:
@@ -161,22 +177,21 @@ def create_model(
161177
model = checkpoint["model"]
162178
elif "state_dict" in checkpoint:
163179
# Cannot reconstruct architecture from state_dict alone
164-
raise ValueError(
180+
error_msg = (
165181
"Checkpoint contains only state_dict. "
166182
"Please specify the model class instead of torch.nn.Module"
167183
)
184+
raise ValueError(error_msg)
168185
else:
169186
# Assume checkpoint is the model itself
170187
model = checkpoint
171-
188+
172189
if not isinstance(model, nn.Module):
173-
raise ValueError("Checkpoint does not contain a valid model")
190+
error_msg = "Checkpoint does not contain a valid model"
191+
raise ValueError(error_msg)
174192
else:
175193
# Instantiate model class
176-
if model_params:
177-
model = model_class(**model_params)
178-
else:
179-
model = model_class()
194+
model = model_class(**model_params) if model_params else model_class()
180195

181196
# Load weights
182197
if "state_dict" in checkpoint:
@@ -205,7 +220,7 @@ def export_to_openvino(
205220
output_path: Path,
206221
input_names: Optional[List[str]] = None,
207222
output_names: Optional[List[str]] = None,
208-
metadata: Optional[Dict[tuple, str]] = None
223+
metadata: Optional[Dict[tuple, str]] = None,
209224
) -> Path:
210225
"""
211226
Export PyTorch model to OpenVINO format.
@@ -232,11 +247,11 @@ def export_to_openvino(
232247

233248
# Reshape model to fixed input shape (remove dynamic dimensions)
234249
first_input = ov_model.input(0)
235-
input_name_for_reshape = list(first_input.get_names())[0] if first_input.get_names() else 0
236-
250+
input_name_for_reshape = next(iter(first_input.get_names())) if first_input.get_names() else 0
251+
237252
self.logger.debug(f"Setting fixed input shape: {input_shape}")
238253
ov_model.reshape({input_name_for_reshape: input_shape})
239-
254+
240255
# Post-process the model
241256
ov_model = self._postprocess_openvino_model(
242257
ov_model,
@@ -355,7 +370,7 @@ def process_model_config(self, config: Dict[str, Any]) -> bool:
355370
if labels_config:
356371
labels = self.get_labels(labels_config)
357372
if labels:
358-
metadata[("model_info", "labels")] = labels
373+
metadata["model_info", "labels"] = labels
359374
self.logger.info(f"Added {labels_config} labels to metadata")
360375
else:
361376
self.logger.warning(f"Could not load labels for: {labels_config}")
@@ -366,15 +381,16 @@ def process_model_config(self, config: Dict[str, Any]) -> bool:
366381
output_path=output_path,
367382
input_names=input_names,
368383
output_names=output_names,
369-
metadata=metadata
384+
metadata=metadata,
370385
)
371386

372387
self.logger.info(f"✓ Successfully converted {model_short_name}")
373388
return True
374389

375-
except Exception as e:
390+
except (ValueError, RuntimeError, ImportError, FileNotFoundError) as e:
376391
self.logger.error(f"✗ Failed to process model {model_short_name}: {e}")
377392
import traceback
393+
378394
self.logger.debug(traceback.format_exc())
379395
return False
380396

@@ -394,7 +410,7 @@ def process_config_file(
394410
Tuple of (successful_count, failed_count)
395411
"""
396412
try:
397-
with open(config_path) as f:
413+
with Path(config_path).open() as f:
398414
config = json.load(f)
399415
except Exception as e:
400416
self.logger.error(f"Failed to load configuration file: {e}")
@@ -431,9 +447,9 @@ def process_config_file(
431447
def list_models(config_path: Path):
432448
"""List all models in a configuration file."""
433449
try:
434-
with open(config_path) as f:
450+
with config_path.open() as f:
435451
config = json.load(f)
436-
except Exception as e:
452+
except (FileNotFoundError, json.JSONDecodeError, PermissionError) as e:
437453
print(f"Error loading configuration: {e}", file=sys.stderr)
438454
return
439455

@@ -566,7 +582,7 @@ def main():
566582
logger.info("=" * 80)
567583

568584
return 0 if failed == 0 else 1
569-
except Exception as e:
585+
except (ValueError, RuntimeError, ImportError, FileNotFoundError) as e:
570586
logger.error(f"Failed to process model: {e}")
571587
return 1
572588

0 commit comments

Comments
 (0)