Skip to content

Commit b5223b1

Browse files
committed
Update get_onnx_bytes_and_metadata
Signed-off-by: ajrasane <[email protected]>
1 parent 89d66d8 commit b5223b1

File tree

1 file changed

+19
-18
lines changed

1 file changed

+19
-18
lines changed

modelopt/torch/_deploy/utils/torch_onnx.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,15 @@
4545
get_node_names,
4646
get_output_names,
4747
get_output_shapes,
48+
infer_shapes,
4849
remove_node_training_mode,
4950
)
5051
from modelopt.torch.quantization.export_onnx import configure_linear_module_onnx_quantizers
5152
from modelopt.torch.utils import flatten_tree, standardize_named_model_args
5253
from modelopt.torch.utils._pytree import TreeSpec
5354

5455
from ..utils.onnx_optimizer import Optimizer
55-
from .onnx_utils import _get_onnx_external_data_tensors, check_model_uses_external_data
56+
from .onnx_utils import check_model_uses_external_data
5657

5758
ModelMetadata = dict[str, Any]
5859
ModelType = Any
@@ -83,15 +84,8 @@ def __init__(self, onnx_load_path: str) -> None:
8384
self.onnx_load_path = os.path.abspath(onnx_load_path)
8485
self.onnx_model = {}
8586
self.model_name = ""
86-
onnx_model = onnx.load(self.onnx_load_path, load_external_data=False)
8787

88-
# Check for external data
89-
external_data_format = False
90-
for initializer in onnx_model.graph.initializer:
91-
if initializer.external_data:
92-
external_data_format = True
93-
94-
if external_data_format:
88+
if has_external_data(onnx_load_path):
9589
onnx_model_dir = os.path.dirname(self.onnx_load_path)
9690
for onnx_model_file in os.listdir(onnx_model_dir):
9791
with open(os.path.join(onnx_model_dir, onnx_model_file), "rb") as f:
@@ -419,9 +413,7 @@ def get_onnx_bytes_and_metadata(
419413
# Export onnx model from pytorch model
420414
# As the maximum size of protobuf is 2GB, we cannot use io.BytesIO() buffer during export.
421415
model_name = model.__class__.__name__
422-
onnx_build_folder = os.path.join(tempfile.gettempdir(), "modelopt_build/onnx/")
423-
onnx_path = os.path.join(onnx_build_folder, model_name)
424-
os.makedirs(onnx_path, exist_ok=True)
416+
onnx_path = tempfile.mkdtemp(prefix=f"modelopt_{model_name}_")
425417
onnx_save_path = os.path.join(onnx_path, f"{model_name}.onnx")
426418

427419
# Configure quantizers if the model is quantized in NVFP4 or MXFP8 mode
@@ -452,7 +444,7 @@ def get_onnx_bytes_and_metadata(
452444
onnx_graph = onnx.load(onnx_save_path, load_external_data=True)
453445

454446
try:
455-
onnx_graph = onnx.shape_inference.infer_shapes(onnx_graph)
447+
onnx_graph = infer_shapes(onnx_graph)
456448
except Exception as e:
457449
print(f"Shape inference failed: {e}")
458450

@@ -502,28 +494,37 @@ def get_onnx_bytes_and_metadata(
502494

503495
# If the onnx model contains external data store the external tensors in one file and save the onnx model
504496
if has_external_data(onnx_save_path):
505-
tensor_paths = _get_onnx_external_data_tensors(onnx_opt_graph)
497+
tensor_paths = get_external_tensor_paths(onnx_path)
506498
onnx.save_model(
507499
onnx_opt_graph,
508500
onnx_save_path,
509501
save_as_external_data=True,
510502
all_tensors_to_one_file=True,
511503
location=f"{model_name}.onnx_data",
512504
size_threshold=1024,
505+
convert_attribute=False,
513506
)
514-
for tensor in tensor_paths:
515-
tensor_path = os.path.join(onnx_path, tensor)
516-
os.remove(tensor_path)
507+
for path in tensor_paths:
508+
os.remove(path)
517509
else:
518510
onnx.save_model(onnx_opt_graph, onnx_save_path)
519511

520512
onnx_bytes = OnnxBytes(onnx_save_path)
521513

522514
if remove_exported_model:
523-
shutil.rmtree(os.path.dirname(onnx_build_folder))
515+
shutil.rmtree(onnx_path)
524516
return onnx_bytes.to_bytes(), model_metadata
525517

526518

519+
def get_external_tensor_paths(model_dir: str) -> list[str]:
520+
"""Get the paths of the external data tensors in the model."""
521+
return [
522+
os.path.join(model_dir, file)
523+
for file in os.listdir(model_dir)
524+
if not file.endswith(".onnx")
525+
]
526+
527+
527528
def has_external_data(onnx_model_path: str):
528529
"""Check if the onnx model has external data."""
529530
onnx_model = onnx.load(onnx_model_path, load_external_data=False)

0 commit comments

Comments
 (0)