Add option to benchmark pipeline in diffusion_trt.py #457
diffusion_trt.py
@@ -14,6 +14,7 @@
 # limitations under the License.

 import argparse
+import time

 import torch
 from onnx_utils.export import (
@@ -23,6 +24,7 @@
     update_dynamic_axes,
 )
 from quantize import ModelType, PipelineManager
+from tqdm import tqdm

 import modelopt.torch.opt as mto
 from modelopt.torch._deploy._runtime import RuntimeRegistry
@@ -58,6 +60,54 @@ def generate_image(pipe, prompt, image_name):
     print(f"Image generated saved as {image_name}")


+def benchmark_model(pipe, prompt, num_warmup=10, num_runs=50, num_inference_steps=20):
+    """Benchmark the backbone model inference time."""
+    backbone = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
+
+    backbone_times = []
+
+    def forward_pre_hook(module, input):
+        torch.cuda.synchronize()
+        module._start_time = time.time()
+
+    def forward_hook(module, input, output):
+        torch.cuda.synchronize()
Review comment: you can use event synchronize.
+        module._end_time = time.time()
+        backbone_times.append((module._end_time - module._start_time) * 1000)  # Convert to ms
+
+    pre_handle = backbone.register_forward_pre_hook(forward_pre_hook)
+    post_handle = backbone.register_forward_hook(forward_hook)
+
+    try:
+        print(f"Starting warmup: {num_warmup} runs")
+        for _ in tqdm(range(num_warmup), desc="Warmup"):
+            _ = pipe(
+                prompt,
+                output_type="pil",
+                num_inference_steps=num_inference_steps,
+                generator=torch.Generator("cuda").manual_seed(42),
+            )
+
+        backbone_times.clear()
+
+        print(f"Starting benchmark: {num_runs} runs")
+        for _ in tqdm(range(num_runs), desc="Benchmark"):
+            _ = pipe(
+                prompt,
+                output_type="pil",
+                num_inference_steps=num_inference_steps,
+                generator=torch.Generator("cuda").manual_seed(42),
+            )
+    finally:
+        pre_handle.remove()
+        post_handle.remove()
+
+    total_backbone_time = sum(backbone_times)
+    avg_latency = total_backbone_time / (num_runs * num_inference_steps)
+    print(f"Inference latency of the torch backbone: {avg_latency:.2f} ms")
+    return avg_latency
+
+
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument(
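For orientation, a minimal sketch of how the new helper could be driven outside the script. The pipeline loading, model ID, and prompt below are illustrative assumptions and are not part of the diff; benchmark_model is the function added above.

```python
import torch
from diffusers import DiffusionPipeline

from diffusion_trt import benchmark_model  # assumes the helper is importable from the script

# Illustrative only: load a Diffusers pipeline with a placeholder model ID.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # placeholder, not from the PR
    torch_dtype=torch.float16,
).to("cuda")

# 10 untimed warmup calls, then 50 timed calls; the hooks accumulate one
# measurement per backbone forward pass, i.e. per denoising step.
avg_ms = benchmark_model(pipe, "A photo of a cat", num_warmup=10, num_runs=50)
print(f"Average backbone latency: {avg_ms:.2f} ms per step")
```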
@@ -92,15 +142,21 @@ def main():
         "--onnx-load-path", type=str, default="", help="Path to load the ONNX model"
     )
     parser.add_argument(
-        "--trt-engine-load-path", type=str, default=None, help="Path to load the TRT engine"
+        "--trt-engine-load-path", type=str, default=None, help="Path to load the TensorRT engine"
     )
     parser.add_argument(
         "--dq-only", action="store_true", help="Converts the ONNX model to a dq_only model"
     )
     parser.add_argument(
-        "--torch", action="store_true", help="Generate an image using the torch pipeline"
+        "--torch",
+        action="store_true",
+        help="Use the torch pipeline for image generation or benchmarking",
     )
     parser.add_argument("--save-image-as", type=str, default=None, help="Name of the image to save")
+    parser.add_argument(
+        "--benchmark", action="store_true", help="Benchmark the model backbone inference time"
+    )
+    parser.add_argument("--skip-image", action="store_true", help="Skip image generation")
     args = parser.parse_args()

     image_name = args.save_image_as if args.save_image_as else f"{args.model}.png"
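As a rough illustration of how the new flags compose (the exact invocation and the --model value are assumptions, not taken from the diff):

```bash
# Benchmark the torch backbone without generating an image (assumed invocation)
python diffusion_trt.py --model <model> --torch --benchmark --skip-image

# Benchmark the TensorRT-optimized backbone and also save an image
python diffusion_trt.py --model <model> --benchmark --save-image-as sample.png
```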
@@ -131,7 +187,12 @@ def main():
         elif hasattr(pipe, "unet"):
             pipe.unet = backbone
         pipe.to("cuda")
-        generate_image(pipe, args.prompt, image_name)
+
+        if args.benchmark:
+            benchmark_model(pipe, args.prompt)
+
+        if not args.skip_image:
+            generate_image(pipe, args.prompt, image_name)
         return

     backbone.to("cuda")
@@ -211,10 +272,14 @@ def main():
         raise ValueError("Pipeline does not have a transformer or unet backbone")
     pipe.to("cuda")

-    generate_image(pipe, args.prompt, image_name)
-    print(f"Image generated using {args.model} model saved as {image_name}")
+    if not args.skip_image:
+        generate_image(pipe, args.prompt, image_name)
+        print(f"Image generated using {args.model} model saved as {image_name}")

-    print(f"Inference latency of the backbone of the pipeline is {device_model.get_latency()} ms")
+    if args.benchmark:
+        print(
+            f"Inference latency of the TensorRT optimized backbone: {device_model.get_latency()} ms"
+        )


 if __name__ == "__main__":
Review comment: do you feel it would be more valuable to show the GPU time using CUDA events instead of the CPU time? With the GPU time you don't need to call an explicit sync.

Author reply: Done
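Following up on that thread, here is a minimal sketch of what event-based timing of the backbone hooks could look like. It illustrates the reviewer's suggestion under stated assumptions and is not the code that was actually merged; the hook names mirror the function above.

```python
import torch

# Sketch: time the backbone with CUDA events instead of host-side time.time().
# Events measure elapsed GPU time, so the hooks themselves need no explicit
# torch.cuda.synchronize(); one synchronize at the end suffices before reading.
backbone_events = []  # (start_event, end_event) pairs, one per backbone call

def forward_pre_hook(module, args):
    start = torch.cuda.Event(enable_timing=True)
    start.record()  # enqueue the start marker on the current CUDA stream
    module._start_event = start

def forward_hook(module, args, output):
    end = torch.cuda.Event(enable_timing=True)
    end.record()  # enqueue the end marker after the backbone forward pass
    backbone_events.append((module._start_event, end))

# ... register the hooks and run the warmup/benchmark loops as in benchmark_model() ...

torch.cuda.synchronize()  # wait for all recorded events to complete
backbone_times = [start.elapsed_time(end) for start, end in backbone_events]  # already in ms
```

Note that Event.elapsed_time() already returns milliseconds, so the explicit conversion used in the time.time() version is not needed.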