Merged
91 changes: 85 additions & 6 deletions examples/diffusers/quantization/diffusion_trt.py
@@ -23,6 +23,7 @@
update_dynamic_axes,
)
from quantize import ModelType, PipelineManager
from tqdm import tqdm

import modelopt.torch.opt as mto
from modelopt.torch._deploy._runtime import RuntimeRegistry
@@ -58,6 +59,59 @@ def generate_image(pipe, prompt, image_name):
print(f"Image generated saved as {image_name}")


def benchmark_model(
pipe, prompt, num_warmup=10, num_runs=50, num_inference_steps=20, model_dtype="Half"
):
"""Benchmark the backbone model inference time."""
backbone = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet

backbone_times = []
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

def forward_pre_hook(_module, _input):
start_event.record()

def forward_hook(_module, _input, _output):
end_event.record()
torch.cuda.synchronize()
Collaborator: Do you feel it would be more valuable to show the GPU time using CUDA events instead of the CPU time? With the GPU time you don't need to call an explicit sync.

Contributor (Author): Done

Collaborator: You can use event synchronize.

(A minimal sketch of this event-based timing pattern follows the benchmark_model diff below.)

backbone_times.append(start_event.elapsed_time(end_event))

pre_handle = backbone.register_forward_pre_hook(forward_pre_hook)
post_handle = backbone.register_forward_hook(forward_hook)

try:
print(f"Starting warmup: {num_warmup} runs")
for _ in tqdm(range(num_warmup), desc="Warmup"):
with torch.amp.autocast("cuda", dtype=dtype_map[model_dtype]):
_ = pipe(
prompt,
output_type="pil",
num_inference_steps=num_inference_steps,
generator=torch.Generator("cuda").manual_seed(42),
)

backbone_times.clear()

print(f"Starting benchmark: {num_runs} runs")
for _ in tqdm(range(num_runs), desc="Benchmark"):
with torch.amp.autocast("cuda", dtype=dtype_map[model_dtype]):
_ = pipe(
prompt,
output_type="pil",
num_inference_steps=num_inference_steps,
generator=torch.Generator("cuda").manual_seed(42),
)
finally:
pre_handle.remove()
post_handle.remove()

total_backbone_time = sum(backbone_times)
avg_latency = total_backbone_time / (num_runs * num_inference_steps)
print(f"Inference latency of the torch backbone: {avg_latency:.2f} ms")
return avg_latency
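
The review thread above suggests measuring GPU time with CUDA events and an event-level sync. As a minimal sketch (not part of this PR), assuming a generic torch.nn.Module backbone and hypothetical input shapes:

# Sketch: GPU timing with CUDA events; end_event.synchronize() waits only for the
# recorded event rather than synchronizing the whole device.
import torch

def time_forward_gpu(module: torch.nn.Module, x: torch.Tensor) -> float:
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()  # recorded on the current CUDA stream
    _ = module(x)
    end_event.record()
    end_event.synchronize()  # event-level sync, not torch.cuda.synchronize()
    return start_event.elapsed_time(end_event)  # elapsed GPU time in milliseconds

# Hypothetical usage:
# backbone = torch.nn.Linear(1024, 1024).cuda().half()
# x = torch.randn(8, 1024, device="cuda", dtype=torch.half)
# print(f"{time_forward_gpu(backbone, x):.2f} ms")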


def main():
parser = argparse.ArgumentParser()
parser.add_argument(
@@ -92,15 +146,24 @@ def main():
"--onnx-load-path", type=str, default="", help="Path to load the ONNX model"
)
parser.add_argument(
"--trt-engine-load-path", type=str, default=None, help="Path to load the TRT engine"
"--trt-engine-load-path", type=str, default=None, help="Path to load the TensorRT engine"
)
parser.add_argument(
"--dq-only", action="store_true", help="Converts the ONNX model to a dq_only model"
)
parser.add_argument(
"--torch", action="store_true", help="Generate an image using the torch pipeline"
"--torch",
action="store_true",
help="Use the torch pipeline for image generation or benchmarking",
)
parser.add_argument("--save-image-as", type=str, default=None, help="Name of the image to save")
parser.add_argument(
"--benchmark", action="store_true", help="Benchmark the model backbone inference time"
)
parser.add_argument(
"--torch-compile", action="store_true", help="Use torch.compile() on the backbone model"
)
parser.add_argument("--skip-image", action="store_true", help="Skip image generation")
args = parser.parse_args()

image_name = args.save_image_as if args.save_image_as else f"{args.model}.png"
@@ -125,13 +188,25 @@ def main():
if args.restore_from:
mto.restore(backbone, args.restore_from)

if args.torch_compile:
assert args.model_dtype in ["BFloat16", "Float", "Half"], (
"torch.compile() only supports BFloat16 and Float"
)
print("Compiling backbone with torch.compile()...")
backbone = torch.compile(backbone, mode="max-autotune")

if args.torch:
if hasattr(pipe, "transformer"):
pipe.transformer = backbone
elif hasattr(pipe, "unet"):
pipe.unet = backbone
pipe.to("cuda")
generate_image(pipe, args.prompt, image_name)

if args.benchmark:
benchmark_model(pipe, args.prompt, model_dtype=args.model_dtype)

if not args.skip_image:
generate_image(pipe, args.prompt, image_name)
return

backbone.to("cuda")
@@ -211,10 +286,14 @@ def main():
raise ValueError("Pipeline does not have a transformer or unet backbone")
pipe.to("cuda")

generate_image(pipe, args.prompt, image_name)
print(f"Image generated using {args.model} model saved as {image_name}")
if not args.skip_image:
generate_image(pipe, args.prompt, image_name)
print(f"Image generated using {args.model} model saved as {image_name}")

print(f"Inference latency of the backbone of the pipeline is {device_model.get_latency()} ms")
if args.benchmark:
print(
f"Inference latency of the TensorRT optimized backbone: {device_model.get_latency()} ms"
)


if __name__ == "__main__":
37 changes: 19 additions & 18 deletions modelopt/torch/_deploy/utils/torch_onnx.py
@@ -45,14 +45,15 @@
get_node_names,
get_output_names,
get_output_shapes,
infer_shapes,
remove_node_training_mode,
)
from modelopt.torch.quantization.export_onnx import configure_linear_module_onnx_quantizers
from modelopt.torch.utils import flatten_tree, standardize_named_model_args
from modelopt.torch.utils._pytree import TreeSpec

from ..utils.onnx_optimizer import Optimizer
from .onnx_utils import _get_onnx_external_data_tensors, check_model_uses_external_data
from .onnx_utils import check_model_uses_external_data

ModelMetadata = dict[str, Any]
ModelType = Any
@@ -83,15 +84,8 @@ def __init__(self, onnx_load_path: str) -> None:
self.onnx_load_path = os.path.abspath(onnx_load_path)
self.onnx_model = {}
self.model_name = ""
onnx_model = onnx.load(self.onnx_load_path, load_external_data=False)

# Check for external data
external_data_format = False
for initializer in onnx_model.graph.initializer:
if initializer.external_data:
external_data_format = True

if external_data_format:
if has_external_data(onnx_load_path):
onnx_model_dir = os.path.dirname(self.onnx_load_path)
for onnx_model_file in os.listdir(onnx_model_dir):
with open(os.path.join(onnx_model_dir, onnx_model_file), "rb") as f:
@@ -419,9 +413,7 @@ def get_onnx_bytes_and_metadata(
# Export onnx model from pytorch model
# As the maximum size of protobuf is 2GB, we cannot use io.BytesIO() buffer during export.
model_name = model.__class__.__name__
onnx_build_folder = os.path.join(tempfile.gettempdir(), "modelopt_build/onnx/")
onnx_path = os.path.join(onnx_build_folder, model_name)
os.makedirs(onnx_path, exist_ok=True)
onnx_path = tempfile.mkdtemp(prefix=f"modelopt_{model_name}_")
onnx_save_path = os.path.join(onnx_path, f"{model_name}.onnx")

Comment on lines +416 to 418
Contributor
🛠️ Refactor suggestion | 🟠 Major

Ensure temp directory cleanup on exceptions to prevent GB-sized leaks.

If any step between export and OnnxBytes() raises, onnx_path persists. Use try/finally (or TemporaryDirectory) to always rmtree with ignore_errors=True.

Example pattern:

-    onnx_save_path = os.path.join(onnx_path, f"{model_name}.onnx")
+    onnx_save_path = os.path.join(onnx_path, f"{model_name}.onnx")
+    try:
+        # ... export, optimize, save, build onnx_bytes ...
+        onnx_bytes = OnnxBytes(onnx_save_path)
+        result_bytes = onnx_bytes.to_bytes()
+    finally:
+        if remove_exported_model:
+            shutil.rmtree(onnx_path, ignore_errors=True)
-    onnx_bytes = OnnxBytes(onnx_save_path)
-    if remove_exported_model:
-        shutil.rmtree(onnx_path)
-    return onnx_bytes.to_bytes(), model_metadata
+    return result_bytes, model_metadata

Also applies to: 514-516

🤖 Prompt for AI Agents
In modelopt/torch/_deploy/utils/torch_onnx.py around lines 416-418 (and likewise 514-516), the temporary directory created with tempfile.mkdtemp can be leaked if an exception occurs; wrap the creation+export+OnnxBytes sequence in a try/finally (or replace mkdtemp with tempfile.TemporaryDirectory as a context manager) so that in the finally block you call shutil.rmtree(onnx_path, ignore_errors=True); ensure you produce the ONNX bytes (or otherwise read any files needed) before the cleanup so the function still returns the expected data.
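
For illustration, a minimal sketch (not part of this PR) of the TemporaryDirectory variant mentioned above; export_model_to_onnx and read_onnx_bytes are hypothetical stand-ins for the real export/OnnxBytes sequence in get_onnx_bytes_and_metadata:

# Sketch: scope the export directory with tempfile.TemporaryDirectory so it is
# removed even when an intermediate step raises.
import os
import tempfile

def export_and_collect_bytes(model, model_name: str) -> bytes:
    with tempfile.TemporaryDirectory(prefix=f"modelopt_{model_name}_") as onnx_path:
        onnx_save_path = os.path.join(onnx_path, f"{model_name}.onnx")
        export_model_to_onnx(model, onnx_save_path)   # hypothetical helper; may raise
        return read_onnx_bytes(onnx_save_path)        # read before the directory is cleaned up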

# Configure quantizers if the model is quantized in NVFP4 or MXFP8 mode
@@ -452,7 +444,7 @@ def get_onnx_bytes_and_metadata(
onnx_graph = onnx.load(onnx_save_path, load_external_data=True)

try:
onnx_graph = onnx.shape_inference.infer_shapes(onnx_graph)
onnx_graph = infer_shapes(onnx_graph)
except Exception as e:
print(f"Shape inference failed: {e}")

@@ -502,28 +494,37 @@ def get_onnx_bytes_and_metadata(

# If the onnx model contains external data store the external tensors in one file and save the onnx model
if has_external_data(onnx_save_path):
tensor_paths = _get_onnx_external_data_tensors(onnx_opt_graph)
tensor_paths = get_external_tensor_paths(onnx_path)
onnx.save_model(
onnx_opt_graph,
onnx_save_path,
save_as_external_data=True,
all_tensors_to_one_file=True,
location=f"{model_name}.onnx_data",
size_threshold=1024,
convert_attribute=False,
)
for tensor in tensor_paths:
tensor_path = os.path.join(onnx_path, tensor)
os.remove(tensor_path)
for path in tensor_paths:
os.remove(path)
else:
onnx.save_model(onnx_opt_graph, onnx_save_path)

Comment on lines 495 to 511
Contributor
⚠️ Potential issue | 🔴 Critical

Decide external-data save based on the optimized ModelProto (not the pre-export file).

Current check uses has_external_data(onnx_save_path), which reflects the pre-optimization export. If optimization changes size/layout, you can incorrectly save a >2GB model without external data (protobuf limit), or conversely over-constrain saving. Compute this from onnx_opt_graph instead and fall back to ByteSize guard.

Apply this diff:

-    # If the onnx model contains external data store the external tensors in one file and save the onnx model
-    if has_external_data(onnx_save_path):
-        tensor_paths = get_external_tensor_paths(onnx_path)
-        onnx.save_model(
-            onnx_opt_graph,
-            onnx_save_path,
-            save_as_external_data=True,
-            all_tensors_to_one_file=True,
-            location=f"{model_name}.onnx_data",
-            size_threshold=1024,
-            convert_attribute=False,
-        )
-        for path in tensor_paths:
-            os.remove(path)
-    else:
-        onnx.save_model(onnx_opt_graph, onnx_save_path)
+    # Decide external-data save from the optimized graph to avoid 2GB protobuf issues.
+    needs_external = (
+        check_model_uses_external_data(onnx_opt_graph) or onnx_opt_graph.ByteSize() > TWO_GB
+    )
+    if needs_external:
+        old_tensor_paths = get_external_tensor_paths_from_model(onnx_save_path)
+        onnx.save_model(
+            onnx_opt_graph,
+            onnx_save_path,
+            save_as_external_data=True,
+            all_tensors_to_one_file=True,
+            location=f"{model_name}.onnx_data",
+            size_threshold=1024,
+            convert_attribute=False,
+        )
+        from contextlib import suppress  # safe local import if not at top
+        for path in old_tensor_paths:
+            with suppress(FileNotFoundError):
+                os.remove(path)
+    else:
+        onnx.save_model(onnx_opt_graph, onnx_save_path)

Committable suggestion skipped: line range outside the PR's diff.
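
A minimal sketch (not part of this PR) of the size-based decision proposed above; TWO_GB and the argument values mirror the 2 GB protobuf limit discussed in the comment and are illustrative:

# Sketch: choose external-data saving from the optimized ModelProto's own size.
import onnx

TWO_GB = 2 * 1024**3  # protobuf serialization limit

def save_optimized_model(onnx_opt_graph: onnx.ModelProto, onnx_save_path: str, model_name: str) -> None:
    if onnx_opt_graph.ByteSize() > TWO_GB:
        onnx.save_model(
            onnx_opt_graph,
            onnx_save_path,
            save_as_external_data=True,
            all_tensors_to_one_file=True,
            location=f"{model_name}.onnx_data",
            size_threshold=1024,
        )
    else:
        onnx.save_model(onnx_opt_graph, onnx_save_path)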

onnx_bytes = OnnxBytes(onnx_save_path)

if remove_exported_model:
shutil.rmtree(os.path.dirname(onnx_build_folder))
shutil.rmtree(onnx_path)
return onnx_bytes.to_bytes(), model_metadata


def get_external_tensor_paths(model_dir: str) -> list[str]:
"""Get the paths of the external data tensors in the model."""
return [
os.path.join(model_dir, file)
for file in os.listdir(model_dir)
if not file.endswith(".onnx")
]


def has_external_data(onnx_model_path: str):
"""Check if the onnx model has external data."""
onnx_model = onnx.load(onnx_model_path, load_external_data=False)