-
Notifications
You must be signed in to change notification settings - Fork 190
Add option to benchmark pipeline in diffusion_trt.py #457
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 5 commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
ecaaf76
Add option to benchmark pipeline in diffusion_trt.py
ajrasane 6184567
Benchmark the backbone only
ajrasane fbafb67
Add a flag for skipping image generation
ajrasane 33afe32
Update logging
ajrasane 89d66d8
Measure GPU time
ajrasane b5223b1
Update get_onnx_bytes_and_metadata
ajrasane 779be79
Add flag for torch.compile()
ajrasane File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,6 +23,7 @@ | |
| update_dynamic_axes, | ||
| ) | ||
| from quantize import ModelType, PipelineManager | ||
| from tqdm import tqdm | ||
|
|
||
| import modelopt.torch.opt as mto | ||
| from modelopt.torch._deploy._runtime import RuntimeRegistry | ||
|
|
@@ -58,6 +59,55 @@ def generate_image(pipe, prompt, image_name): | |
| print(f"Image generated saved as {image_name}") | ||
|
|
||
|
|
||
| def benchmark_model(pipe, prompt, num_warmup=10, num_runs=50, num_inference_steps=20): | ||
| """Benchmark the backbone model inference time.""" | ||
| backbone = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet | ||
|
|
||
| backbone_times = [] | ||
| start_event = torch.cuda.Event(enable_timing=True) | ||
| end_event = torch.cuda.Event(enable_timing=True) | ||
|
|
||
| def forward_pre_hook(_module, _input): | ||
| start_event.record() | ||
|
|
||
| def forward_hook(_module, _input, _output): | ||
| end_event.record() | ||
| torch.cuda.synchronize() | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you can use even synchronize. |
||
| backbone_times.append(start_event.elapsed_time(end_event)) | ||
|
|
||
| pre_handle = backbone.register_forward_pre_hook(forward_pre_hook) | ||
| post_handle = backbone.register_forward_hook(forward_hook) | ||
|
|
||
| try: | ||
| print(f"Starting warmup: {num_warmup} runs") | ||
| for _ in tqdm(range(num_warmup), desc="Warmup"): | ||
| _ = pipe( | ||
| prompt, | ||
| output_type="pil", | ||
| num_inference_steps=num_inference_steps, | ||
| generator=torch.Generator("cuda").manual_seed(42), | ||
| ) | ||
|
|
||
| backbone_times.clear() | ||
|
|
||
| print(f"Starting benchmark: {num_runs} runs") | ||
| for _ in tqdm(range(num_runs), desc="Benchmark"): | ||
| _ = pipe( | ||
| prompt, | ||
| output_type="pil", | ||
| num_inference_steps=num_inference_steps, | ||
| generator=torch.Generator("cuda").manual_seed(42), | ||
| ) | ||
| finally: | ||
| pre_handle.remove() | ||
| post_handle.remove() | ||
|
|
||
| total_backbone_time = sum(backbone_times) | ||
| avg_latency = total_backbone_time / (num_runs * num_inference_steps) | ||
| print(f"Inference latency of the torch backbone: {avg_latency:.2f} ms") | ||
| return avg_latency | ||
|
|
||
|
|
||
| def main(): | ||
| parser = argparse.ArgumentParser() | ||
| parser.add_argument( | ||
|
|
@@ -92,15 +142,21 @@ def main(): | |
| "--onnx-load-path", type=str, default="", help="Path to load the ONNX model" | ||
| ) | ||
| parser.add_argument( | ||
| "--trt-engine-load-path", type=str, default=None, help="Path to load the TRT engine" | ||
| "--trt-engine-load-path", type=str, default=None, help="Path to load the TensorRT engine" | ||
| ) | ||
| parser.add_argument( | ||
| "--dq-only", action="store_true", help="Converts the ONNX model to a dq_only model" | ||
| ) | ||
| parser.add_argument( | ||
| "--torch", action="store_true", help="Generate an image using the torch pipeline" | ||
| "--torch", | ||
| action="store_true", | ||
| help="Use the torch pipeline for image generation or benchmarking", | ||
| ) | ||
| parser.add_argument("--save-image-as", type=str, default=None, help="Name of the image to save") | ||
| parser.add_argument( | ||
| "--benchmark", action="store_true", help="Benchmark the model backbone inference time" | ||
| ) | ||
| parser.add_argument("--skip-image", action="store_true", help="Skip image generation") | ||
| args = parser.parse_args() | ||
|
|
||
| image_name = args.save_image_as if args.save_image_as else f"{args.model}.png" | ||
|
|
@@ -131,7 +187,12 @@ def main(): | |
| elif hasattr(pipe, "unet"): | ||
| pipe.unet = backbone | ||
| pipe.to("cuda") | ||
| generate_image(pipe, args.prompt, image_name) | ||
|
|
||
| if args.benchmark: | ||
| benchmark_model(pipe, args.prompt) | ||
|
|
||
| if not args.skip_image: | ||
| generate_image(pipe, args.prompt, image_name) | ||
| return | ||
|
|
||
| backbone.to("cuda") | ||
|
|
@@ -211,10 +272,14 @@ def main(): | |
| raise ValueError("Pipeline does not have a transformer or unet backbone") | ||
| pipe.to("cuda") | ||
|
|
||
| generate_image(pipe, args.prompt, image_name) | ||
| print(f"Image generated using {args.model} model saved as {image_name}") | ||
| if not args.skip_image: | ||
| generate_image(pipe, args.prompt, image_name) | ||
| print(f"Image generated using {args.model} model saved as {image_name}") | ||
|
|
||
| print(f"Inference latency of the backbone of the pipeline is {device_model.get_latency()} ms") | ||
| if args.benchmark: | ||
| print( | ||
| f"Inference latency of the TensorRT optimized backbone: {device_model.get_latency()} ms" | ||
| ) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
|
|
||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do you feel it will be more valuable to show the GPU time using cuda event instead of the CPU time?
With the GPU time you don't need to call explicit sync
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done