
Commit 476b59f

ajrasane authored and Edwardf0t1 committed
Add option to benchmark pipeline in diffusion_trt.py (#457)
Signed-off-by: ajrasane <[email protected]>
Signed-off-by: Zhiyu Cheng <[email protected]>
1 parent b8dbfc0 commit 476b59f

2 files changed: +104 -24 lines changed


examples/diffusers/quantization/diffusion_trt.py

Lines changed: 85 additions & 6 deletions
@@ -23,6 +23,7 @@
     update_dynamic_axes,
 )
 from quantize import ModelType, PipelineManager
+from tqdm import tqdm

 import modelopt.torch.opt as mto
 from modelopt.torch._deploy._runtime import RuntimeRegistry
@@ -58,6 +59,59 @@ def generate_image(pipe, prompt, image_name):
     print(f"Image generated saved as {image_name}")


+def benchmark_model(
+    pipe, prompt, num_warmup=10, num_runs=50, num_inference_steps=20, model_dtype="Half"
+):
+    """Benchmark the backbone model inference time."""
+    backbone = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
+
+    backbone_times = []
+    start_event = torch.cuda.Event(enable_timing=True)
+    end_event = torch.cuda.Event(enable_timing=True)
+
+    def forward_pre_hook(_module, _input):
+        start_event.record()
+
+    def forward_hook(_module, _input, _output):
+        end_event.record()
+        torch.cuda.synchronize()
+        backbone_times.append(start_event.elapsed_time(end_event))
+
+    pre_handle = backbone.register_forward_pre_hook(forward_pre_hook)
+    post_handle = backbone.register_forward_hook(forward_hook)
+
+    try:
+        print(f"Starting warmup: {num_warmup} runs")
+        for _ in tqdm(range(num_warmup), desc="Warmup"):
+            with torch.amp.autocast("cuda", dtype=dtype_map[model_dtype]):
+                _ = pipe(
+                    prompt,
+                    output_type="pil",
+                    num_inference_steps=num_inference_steps,
+                    generator=torch.Generator("cuda").manual_seed(42),
+                )
+
+        backbone_times.clear()
+
+        print(f"Starting benchmark: {num_runs} runs")
+        for _ in tqdm(range(num_runs), desc="Benchmark"):
+            with torch.amp.autocast("cuda", dtype=dtype_map[model_dtype]):
+                _ = pipe(
+                    prompt,
+                    output_type="pil",
+                    num_inference_steps=num_inference_steps,
+                    generator=torch.Generator("cuda").manual_seed(42),
+                )
+    finally:
+        pre_handle.remove()
+        post_handle.remove()
+
+    total_backbone_time = sum(backbone_times)
+    avg_latency = total_backbone_time / (num_runs * num_inference_steps)
+    print(f"Inference latency of the torch backbone: {avg_latency:.2f} ms")
+    return avg_latency
+
+
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument(
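The benchmark above measures only the backbone's forward passes by recording CUDA events from forward hooks, so text-encoder, scheduler, and VAE time are excluded, and the warmup results are discarded via backbone_times.clear(). A minimal standalone sketch of the same timing technique, assuming any nn.Module already on a CUDA device (the linear layer, shapes, and iteration count below are placeholders, not part of this commit):

import torch
import torch.nn as nn

# Placeholder module and input; any CUDA nn.Module can be timed this way.
module = nn.Linear(1024, 1024).cuda()
x = torch.randn(8, 1024, device="cuda")

times = []
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

def pre_hook(_module, _inputs):
    start.record()

def post_hook(_module, _inputs, _output):
    end.record()
    torch.cuda.synchronize()  # wait for the recorded kernels to finish
    times.append(start.elapsed_time(end))  # elapsed time in milliseconds

pre = module.register_forward_pre_hook(pre_hook)
post = module.register_forward_hook(post_hook)
try:
    for _ in range(10):
        _ = module(x)
finally:
    pre.remove()
    post.remove()

print(f"Average forward latency: {sum(times) / len(times):.2f} ms")
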
@@ -92,15 +146,24 @@ def main():
         "--onnx-load-path", type=str, default="", help="Path to load the ONNX model"
     )
     parser.add_argument(
-        "--trt-engine-load-path", type=str, default=None, help="Path to load the TRT engine"
+        "--trt-engine-load-path", type=str, default=None, help="Path to load the TensorRT engine"
     )
     parser.add_argument(
         "--dq-only", action="store_true", help="Converts the ONNX model to a dq_only model"
     )
     parser.add_argument(
-        "--torch", action="store_true", help="Generate an image using the torch pipeline"
+        "--torch",
+        action="store_true",
+        help="Use the torch pipeline for image generation or benchmarking",
     )
     parser.add_argument("--save-image-as", type=str, default=None, help="Name of the image to save")
+    parser.add_argument(
+        "--benchmark", action="store_true", help="Benchmark the model backbone inference time"
+    )
+    parser.add_argument(
+        "--torch-compile", action="store_true", help="Use torch.compile() on the backbone model"
+    )
+    parser.add_argument("--skip-image", action="store_true", help="Skip image generation")
     args = parser.parse_args()

     image_name = args.save_image_as if args.save_image_as else f"{args.model}.png"
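Taken together, the new flags allow benchmarking without producing an image. A hypothetical invocation (model name and any other required arguments elided) might look like: python diffusion_trt.py --model <model> --torch --benchmark --skip-image. Adding --torch-compile additionally wraps the backbone in torch.compile() before timing.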
@@ -125,13 +188,25 @@ def main():
     if args.restore_from:
         mto.restore(backbone, args.restore_from)

+    if args.torch_compile:
+        assert args.model_dtype in ["BFloat16", "Float", "Half"], (
+            "torch.compile() only supports BFloat16 and Float"
+        )
+        print("Compiling backbone with torch.compile()...")
+        backbone = torch.compile(backbone, mode="max-autotune")
+
     if args.torch:
         if hasattr(pipe, "transformer"):
             pipe.transformer = backbone
         elif hasattr(pipe, "unet"):
             pipe.unet = backbone
         pipe.to("cuda")
-        generate_image(pipe, args.prompt, image_name)
+
+        if args.benchmark:
+            benchmark_model(pipe, args.prompt, model_dtype=args.model_dtype)
+
+        if not args.skip_image:
+            generate_image(pipe, args.prompt, image_name)
         return

     backbone.to("cuda")
@@ -211,10 +286,14 @@ def main():
         raise ValueError("Pipeline does not have a transformer or unet backbone")
     pipe.to("cuda")

-    generate_image(pipe, args.prompt, image_name)
-    print(f"Image generated using {args.model} model saved as {image_name}")
+    if not args.skip_image:
+        generate_image(pipe, args.prompt, image_name)
+        print(f"Image generated using {args.model} model saved as {image_name}")

-    print(f"Inference latency of the backbone of the pipeline is {device_model.get_latency()} ms")
+    if args.benchmark:
+        print(
+            f"Inference latency of the TensorRT optimized backbone: {device_model.get_latency()} ms"
+        )


 if __name__ == "__main__":

modelopt/torch/_deploy/utils/torch_onnx.py

Lines changed: 19 additions & 18 deletions
@@ -45,14 +45,15 @@
     get_node_names,
     get_output_names,
     get_output_shapes,
+    infer_shapes,
     remove_node_training_mode,
 )
 from modelopt.torch.quantization.export_onnx import configure_linear_module_onnx_quantizers
 from modelopt.torch.utils import flatten_tree, standardize_named_model_args
 from modelopt.torch.utils._pytree import TreeSpec

 from ..utils.onnx_optimizer import Optimizer
-from .onnx_utils import _get_onnx_external_data_tensors, check_model_uses_external_data
+from .onnx_utils import check_model_uses_external_data

 ModelMetadata = dict[str, Any]
 ModelType = Any
@@ -83,15 +84,8 @@ def __init__(self, onnx_load_path: str) -> None:
         self.onnx_load_path = os.path.abspath(onnx_load_path)
         self.onnx_model = {}
         self.model_name = ""
-        onnx_model = onnx.load(self.onnx_load_path, load_external_data=False)

-        # Check for external data
-        external_data_format = False
-        for initializer in onnx_model.graph.initializer:
-            if initializer.external_data:
-                external_data_format = True
-
-        if external_data_format:
+        if has_external_data(onnx_load_path):
             onnx_model_dir = os.path.dirname(self.onnx_load_path)
             for onnx_model_file in os.listdir(onnx_model_dir):
                 with open(os.path.join(onnx_model_dir, onnx_model_file), "rb") as f:
@@ -419,9 +413,7 @@ def get_onnx_bytes_and_metadata(
     # Export onnx model from pytorch model
     # As the maximum size of protobuf is 2GB, we cannot use io.BytesIO() buffer during export.
     model_name = model.__class__.__name__
-    onnx_build_folder = os.path.join(tempfile.gettempdir(), "modelopt_build/onnx/")
-    onnx_path = os.path.join(onnx_build_folder, model_name)
-    os.makedirs(onnx_path, exist_ok=True)
+    onnx_path = tempfile.mkdtemp(prefix=f"modelopt_{model_name}_")
     onnx_save_path = os.path.join(onnx_path, f"{model_name}.onnx")

     # Configure quantizers if the model is quantized in NVFP4 or MXFP8 mode
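Switching from a fixed modelopt_build/onnx/<model_name> folder to tempfile.mkdtemp() gives each export a uniquely named temporary directory, so concurrent exports of the same model no longer collide, and the later shutil.rmtree(onnx_path) cleanup removes only that directory.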
@@ -452,7 +444,7 @@ def get_onnx_bytes_and_metadata(
     onnx_graph = onnx.load(onnx_save_path, load_external_data=True)

     try:
-        onnx_graph = onnx.shape_inference.infer_shapes(onnx_graph)
+        onnx_graph = infer_shapes(onnx_graph)
     except Exception as e:
         print(f"Shape inference failed: {e}")
@@ -502,28 +494,37 @@ def get_onnx_bytes_and_metadata(

     # If the onnx model contains external data store the external tensors in one file and save the onnx model
     if has_external_data(onnx_save_path):
-        tensor_paths = _get_onnx_external_data_tensors(onnx_opt_graph)
+        tensor_paths = get_external_tensor_paths(onnx_path)
         onnx.save_model(
             onnx_opt_graph,
             onnx_save_path,
             save_as_external_data=True,
             all_tensors_to_one_file=True,
             location=f"{model_name}.onnx_data",
             size_threshold=1024,
+            convert_attribute=False,
         )
-        for tensor in tensor_paths:
-            tensor_path = os.path.join(onnx_path, tensor)
-            os.remove(tensor_path)
+        for path in tensor_paths:
+            os.remove(path)
     else:
         onnx.save_model(onnx_opt_graph, onnx_save_path)

     onnx_bytes = OnnxBytes(onnx_save_path)

     if remove_exported_model:
-        shutil.rmtree(os.path.dirname(onnx_build_folder))
+        shutil.rmtree(onnx_path)
     return onnx_bytes.to_bytes(), model_metadata


+def get_external_tensor_paths(model_dir: str) -> list[str]:
+    """Get the paths of the external data tensors in the model."""
+    return [
+        os.path.join(model_dir, file)
+        for file in os.listdir(model_dir)
+        if not file.endswith(".onnx")
+    ]
+
+
 def has_external_data(onnx_model_path: str):
     """Check if the onnx model has external data."""
     onnx_model = onnx.load(onnx_model_path, load_external_data=False)