@@ -45,14 +45,15 @@
     get_node_names,
     get_output_names,
     get_output_shapes,
+    infer_shapes,
     remove_node_training_mode,
 )
 from modelopt.torch.quantization.export_onnx import configure_linear_module_onnx_quantizers
 from modelopt.torch.utils import flatten_tree, standardize_named_model_args
 from modelopt.torch.utils._pytree import TreeSpec
 
 from ..utils.onnx_optimizer import Optimizer
-from .onnx_utils import _get_onnx_external_data_tensors, check_model_uses_external_data
+from .onnx_utils import check_model_uses_external_data
 
 ModelMetadata = dict[str, Any]
 ModelType = Any
@@ -83,15 +84,8 @@ def __init__(self, onnx_load_path: str) -> None:
         self.onnx_load_path = os.path.abspath(onnx_load_path)
         self.onnx_model = {}
         self.model_name = ""
-        onnx_model = onnx.load(self.onnx_load_path, load_external_data=False)
 
-        # Check for external data
-        external_data_format = False
-        for initializer in onnx_model.graph.initializer:
-            if initializer.external_data:
-                external_data_format = True
-
-        if external_data_format:
+        if has_external_data(onnx_load_path):
             onnx_model_dir = os.path.dirname(self.onnx_load_path)
             for onnx_model_file in os.listdir(onnx_model_dir):
                 with open(os.path.join(onnx_model_dir, onnx_model_file), "rb") as f:
@@ -419,9 +413,7 @@ def get_onnx_bytes_and_metadata(
     # Export onnx model from pytorch model
     # As the maximum size of protobuf is 2GB, we cannot use io.BytesIO() buffer during export.
     model_name = model.__class__.__name__
-    onnx_build_folder = os.path.join(tempfile.gettempdir(), "modelopt_build/onnx/")
-    onnx_path = os.path.join(onnx_build_folder, model_name)
-    os.makedirs(onnx_path, exist_ok=True)
+    onnx_path = tempfile.mkdtemp(prefix=f"modelopt_{model_name}_")
     onnx_save_path = os.path.join(onnx_path, f"{model_name}.onnx")
 
     # Configure quantizers if the model is quantized in NVFP4 or MXFP8 mode
@@ -452,7 +444,7 @@ def get_onnx_bytes_and_metadata(
     onnx_graph = onnx.load(onnx_save_path, load_external_data=True)
 
     try:
-        onnx_graph = onnx.shape_inference.infer_shapes(onnx_graph)
+        onnx_graph = infer_shapes(onnx_graph)
     except Exception as e:
         print(f"Shape inference failed: {e}")
 
@@ -502,28 +494,37 @@ def get_onnx_bytes_and_metadata(
 
     # If the onnx model contains external data store the external tensors in one file and save the onnx model
     if has_external_data(onnx_save_path):
-        tensor_paths = _get_onnx_external_data_tensors(onnx_opt_graph)
+        tensor_paths = get_external_tensor_paths(onnx_path)
         onnx.save_model(
             onnx_opt_graph,
             onnx_save_path,
             save_as_external_data=True,
             all_tensors_to_one_file=True,
             location=f"{model_name}.onnx_data",
             size_threshold=1024,
+            convert_attribute=False,
         )
-        for tensor in tensor_paths:
-            tensor_path = os.path.join(onnx_path, tensor)
-            os.remove(tensor_path)
+        for path in tensor_paths:
+            os.remove(path)
     else:
         onnx.save_model(onnx_opt_graph, onnx_save_path)
 
     onnx_bytes = OnnxBytes(onnx_save_path)
 
     if remove_exported_model:
-        shutil.rmtree(os.path.dirname(onnx_build_folder))
+        shutil.rmtree(onnx_path)
     return onnx_bytes.to_bytes(), model_metadata
 
 
+def get_external_tensor_paths(model_dir: str) -> list[str]:
+    """Get the paths of the external data tensors in the model."""
+    return [
+        os.path.join(model_dir, file)
+        for file in os.listdir(model_dir)
+        if not file.endswith(".onnx")
+    ]
+
+
 def has_external_data(onnx_model_path: str):
     """Check if the onnx model has external data."""
     onnx_model = onnx.load(onnx_model_path, load_external_data=False)
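For reference, a minimal standalone sketch of the consolidation flow these changes converge on. It mirrors the two helpers added above and wraps the save-and-cleanup steps in a hypothetical `consolidate_external_data` function; that wrapper name and the `export_dir`/`save_path` placeholders are illustrative, not part of the modelopt API. Only standard `os`, `shutil`, `tempfile`, and `onnx` calls are used.

```python
import os
import shutil
import tempfile

import onnx


def has_external_data(onnx_model_path: str) -> bool:
    # Load only the graph structure; external tensor payloads stay on disk.
    model = onnx.load(onnx_model_path, load_external_data=False)
    return any(init.external_data for init in model.graph.initializer)


def get_external_tensor_paths(model_dir: str) -> list[str]:
    # Treat every non-.onnx file in the export directory as external tensor data.
    return [
        os.path.join(model_dir, name)
        for name in os.listdir(model_dir)
        if not name.endswith(".onnx")
    ]


def consolidate_external_data(save_path: str) -> None:
    # Re-save an exported .onnx file so all external tensors live in one sidecar
    # file, then delete the per-tensor files left behind by the original export.
    model_dir = os.path.dirname(save_path)
    if not has_external_data(save_path):
        return
    stale_paths = get_external_tensor_paths(model_dir)  # collect before re-saving
    graph = onnx.load(save_path, load_external_data=True)
    onnx.save_model(
        graph,
        save_path,
        save_as_external_data=True,
        all_tensors_to_one_file=True,  # consolidate weights into a single file
        location=os.path.basename(save_path) + "_data",
        size_threshold=1024,  # only tensors above 1 KiB are stored externally
        convert_attribute=False,  # keep attribute tensors inside the graph proto
    )
    for path in stale_paths:
        os.remove(path)


# Usage: export into a fresh temp dir, consolidate, then clean up.
export_dir = tempfile.mkdtemp(prefix="modelopt_example_")
save_path = os.path.join(export_dir, "model.onnx")
# ... an ONNX export step would write the model to save_path here ...
if os.path.exists(save_path):
    consolidate_external_data(save_path)
shutil.rmtree(export_dir)  # mirrors remove_exported_model=True
```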