@@ -14,6 +14,7 @@
 # limitations under the License.

 import argparse
+import time

 import torch
 from onnx_utils.export import (
@@ -58,6 +59,34 @@ def generate_image(pipe, prompt, image_name):
     print(f"Image generated saved as {image_name}")


+def benchmark_model(pipe, prompt, num_warmup=3, num_runs=10):
+    """Benchmark the model inference time."""
+    # Warmup runs
+    for _ in range(num_warmup):
+        _ = pipe(
+            prompt,
+            output_type="pil",
+            num_inference_steps=30,
+            generator=torch.Generator("cuda").manual_seed(42),
+        )
+
+    # Benchmark runs
+    torch.cuda.synchronize()
+    start = time.time()
+    for _ in range(num_runs):
+        _ = pipe(
+            prompt,
+            output_type="pil",
+            num_inference_steps=30,
+            generator=torch.Generator("cuda").manual_seed(42),
+        )
+    torch.cuda.synchronize()
+    end = time.time()
+
+    avg_latency = (end - start) / num_runs * 1000  # Convert to ms
+    return avg_latency
+
+
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -101,6 +130,9 @@ def main():
         "--torch", action="store_true", help="Generate an image using the torch pipeline"
     )
     parser.add_argument("--save-image-as", type=str, default=None, help="Name of the image to save")
+    parser.add_argument(
+        "--benchmark", action="store_true", help="Benchmark the model inference time"
+    )
     args = parser.parse_args()

     image_name = args.save_image_as if args.save_image_as else f"{args.model}.png"
@@ -131,6 +163,12 @@ def main():
         elif hasattr(pipe, "unet"):
             pipe.unet = backbone
         pipe.to("cuda")
+
+        if args.benchmark:
+            # Benchmark the torch model
+            torch_latency = benchmark_model(pipe, args.prompt)
+            print(f"Inference latency of the torch pipeline is {torch_latency:.2f} ms")
+
         generate_image(pipe, args.prompt, image_name)
         return

@@ -214,7 +252,10 @@ def main():
     generate_image(pipe, args.prompt, image_name)
     print(f"Image generated using {args.model} model saved as {image_name}")

-    print(f"Inference latency of the backbone of the pipeline is {device_model.get_latency()} ms")
+    if args.benchmark:
+        print(
+            f"Inference latency of the backbone of the pipeline is {device_model.get_latency()} ms"
+        )


 if __name__ == "__main__":
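
The new `benchmark_model` helper follows the standard CUDA timing recipe: a few warmup iterations, then `torch.cuda.synchronize()` immediately before and after the timed loop, so that asynchronously queued GPU work does not leak into the wall-clock measurement. A minimal, self-contained sketch of the same pattern is below; the `benchmark_cuda_op` helper and the matmul workload are illustrative stand-ins, not part of this PR.

import time

import torch


def benchmark_cuda_op(fn, num_warmup=3, num_runs=10):
    """Average latency of a CUDA workload in ms, timed like benchmark_model:
    warm up first, then synchronize around the timed loop so queued kernels
    don't skew the wall-clock measurement."""
    for _ in range(num_warmup):
        fn()  # warmup: absorbs lazy init, autotuning, cache fills

    torch.cuda.synchronize()  # drain pending GPU work before starting the clock
    start = time.time()
    for _ in range(num_runs):
        fn()
    torch.cuda.synchronize()  # wait for the last run to actually finish
    return (time.time() - start) / num_runs * 1000  # seconds -> ms


if __name__ == "__main__":
    # Hypothetical stand-in workload: a large matmul on the GPU.
    a = torch.randn(4096, 4096, device="cuda")
    b = torch.randn(4096, 4096, device="cuda")
    print(f"avg latency: {benchmark_cuda_op(lambda: a @ b):.2f} ms")

Because CUDA kernel launches are asynchronous, omitting the synchronize calls would time only the enqueueing of work rather than its execution, and the warmup runs absorb one-time costs such as lazy initialization and kernel autotuning before measurement begins.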