|
45 | 45 | get_node_names, |
46 | 46 | get_output_names, |
47 | 47 | get_output_shapes, |
48 | | - infer_shapes, |
49 | 48 | remove_node_training_mode, |
50 | 49 | ) |
51 | 50 | from modelopt.torch.quantization.export_onnx import configure_linear_module_onnx_quantizers |
52 | 51 | from modelopt.torch.utils import flatten_tree, standardize_named_model_args |
53 | 52 | from modelopt.torch.utils._pytree import TreeSpec |
54 | 53 |
|
55 | 54 | from ..utils.onnx_optimizer import Optimizer |
56 | | -from .onnx_utils import check_model_uses_external_data |
| 55 | +from .onnx_utils import _get_onnx_external_data_tensors, check_model_uses_external_data |
57 | 56 |
|
58 | 57 | ModelMetadata = dict[str, Any] |
59 | 58 | ModelType = Any |
@@ -84,8 +83,15 @@ def __init__(self, onnx_load_path: str) -> None: |
84 | 83 | self.onnx_load_path = os.path.abspath(onnx_load_path) |
85 | 84 | self.onnx_model = {} |
86 | 85 | self.model_name = "" |
| 86 | + onnx_model = onnx.load(self.onnx_load_path, load_external_data=False) |
87 | 87 |
|
88 | | - if has_external_data(onnx_load_path): |
| 88 | + # Check for external data |
| 89 | + external_data_format = False |
| 90 | + for initializer in onnx_model.graph.initializer: |
| 91 | + if initializer.external_data: |
| 92 | + external_data_format = True |
| 93 | + |
| 94 | + if external_data_format: |
89 | 95 | onnx_model_dir = os.path.dirname(self.onnx_load_path) |
90 | 96 | for onnx_model_file in os.listdir(onnx_model_dir): |
91 | 97 | with open(os.path.join(onnx_model_dir, onnx_model_file), "rb") as f: |
@@ -413,7 +419,9 @@ def get_onnx_bytes_and_metadata( |
413 | 419 | # Export onnx model from pytorch model |
414 | 420 | # As the maximum size of protobuf is 2GB, we cannot use io.BytesIO() buffer during export. |
415 | 421 | model_name = model.__class__.__name__ |
416 | | - onnx_path = tempfile.mkdtemp(prefix=f"modelopt_{model_name}_") |
| 422 | + onnx_build_folder = os.path.join(tempfile.gettempdir(), "modelopt_build/onnx/") |
| 423 | + onnx_path = os.path.join(onnx_build_folder, model_name) |
| 424 | + os.makedirs(onnx_path, exist_ok=True) |
417 | 425 | onnx_save_path = os.path.join(onnx_path, f"{model_name}.onnx") |
418 | 426 |
|
419 | 427 | # Configure quantizers if the model is quantized in NVFP4 or MXFP8 mode |
@@ -444,7 +452,7 @@ def get_onnx_bytes_and_metadata( |
444 | 452 | onnx_graph = onnx.load(onnx_save_path, load_external_data=True) |
445 | 453 |
|
446 | 454 | try: |
447 | | - onnx_graph = infer_shapes(onnx_graph) |
| 455 | + onnx_graph = onnx.shape_inference.infer_shapes(onnx_graph) |
448 | 456 | except Exception as e: |
449 | 457 | print(f"Shape inference failed: {e}") |
450 | 458 |
|
@@ -494,37 +502,28 @@ def get_onnx_bytes_and_metadata( |
494 | 502 |
|
495 | 503 | # If the onnx model contains external data, store the external tensors in one file and save the onnx model |
496 | 504 | if has_external_data(onnx_save_path): |
497 | | - tensor_paths = get_external_tensor_paths(onnx_path) |
| 505 | + tensor_paths = _get_onnx_external_data_tensors(onnx_opt_graph) |
498 | 506 | onnx.save_model( |
499 | 507 | onnx_opt_graph, |
500 | 508 | onnx_save_path, |
501 | 509 | save_as_external_data=True, |
502 | 510 | all_tensors_to_one_file=True, |
503 | 511 | location=f"{model_name}.onnx_data", |
504 | 512 | size_threshold=1024, |
505 | | - convert_attribute=False, |
506 | 513 | ) |
507 | | - for path in tensor_paths: |
508 | | - os.remove(path) |
| 514 | + for tensor in tensor_paths: |
| 515 | + tensor_path = os.path.join(onnx_path, tensor) |
| 516 | + os.remove(tensor_path) |
509 | 517 | else: |
510 | 518 | onnx.save_model(onnx_opt_graph, onnx_save_path) |
511 | 519 |
|
512 | 520 | onnx_bytes = OnnxBytes(onnx_save_path) |
513 | 521 |
|
514 | 522 | if remove_exported_model: |
515 | | - shutil.rmtree(onnx_path) |
| 523 | + shutil.rmtree(os.path.dirname(onnx_build_folder)) |
516 | 524 | return onnx_bytes.to_bytes(), model_metadata |
517 | 525 |
|
518 | 526 |
|
519 | | -def get_external_tensor_paths(model_dir: str) -> list[str]: |
520 | | - """Get the paths of the external data tensors in the model.""" |
521 | | - return [ |
522 | | - os.path.join(model_dir, file) |
523 | | - for file in os.listdir(model_dir) |
524 | | - if not file.endswith(".onnx") |
525 | | - ] |
526 | | - |
527 | | - |
528 | 527 | def has_external_data(onnx_model_path: str): |
529 | 528 | """Check if the onnx model has external data.""" |
530 | 529 | onnx_model = onnx.load(onnx_model_path, load_external_data=False) |
|
0 commit comments