Commit ecaaf76

ajrasane committed
Add option to benchmark pipeline in diffusion_trt.py
Signed-off-by: ajrasane <[email protected]>
1 parent f8a9353 commit ecaaf76

File tree

1 file changed: +42 -1 lines changed


examples/diffusers/quantization/diffusion_trt.py

Lines changed: 42 additions & 1 deletion
@@ -14,6 +14,7 @@
 # limitations under the License.

 import argparse
+import time

 import torch
 from onnx_utils.export import (
@@ -58,6 +59,34 @@ def generate_image(pipe, prompt, image_name):
     print(f"Image generated saved as {image_name}")


+def benchmark_model(pipe, prompt, num_warmup=3, num_runs=10):
+    """Benchmark the model inference time."""
+    # Warmup runs
+    for _ in range(num_warmup):
+        _ = pipe(
+            prompt,
+            output_type="pil",
+            num_inference_steps=30,
+            generator=torch.Generator("cuda").manual_seed(42),
+        )
+
+    # Benchmark runs
+    torch.cuda.synchronize()
+    start = time.time()
+    for _ in range(num_runs):
+        _ = pipe(
+            prompt,
+            output_type="pil",
+            num_inference_steps=30,
+            generator=torch.Generator("cuda").manual_seed(42),
+        )
+    torch.cuda.synchronize()
+    end = time.time()
+
+    avg_latency = (end - start) / num_runs * 1000  # Convert to ms
+    return avg_latency
+
+
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument(
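
Aside (not part of this commit): benchmark_model measures end-to-end latency with host-side time.time() bracketed by torch.cuda.synchronize() calls, which is appropriate for whole-pipeline timing. When GPU-side timing with less host jitter is wanted, CUDA events are a common alternative; a minimal sketch, where the helper name time_gpu is illustrative:

import torch

def time_gpu(fn, num_runs=10):
    # CUDA events record timestamps on the GPU stream itself,
    # avoiding host-side scheduling jitter.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(num_runs):
        fn()
    end.record()
    torch.cuda.synchronize()  # wait for the end event to complete
    return start.elapsed_time(end) / num_runs  # average latency in ms
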
@@ -101,6 +130,9 @@ def main():
         "--torch", action="store_true", help="Generate an image using the torch pipeline"
     )
     parser.add_argument("--save-image-as", type=str, default=None, help="Name of the image to save")
+    parser.add_argument(
+        "--benchmark", action="store_true", help="Benchmark the model inference time"
+    )
     args = parser.parse_args()

     image_name = args.save_image_as if args.save_image_as else f"{args.model}.png"
@@ -131,6 +163,12 @@ def main():
         elif hasattr(pipe, "unet"):
             pipe.unet = backbone
         pipe.to("cuda")
+
+        if args.benchmark:
+            # Benchmark the torch model
+            torch_latency = benchmark_model(pipe, args.prompt)
+            print(f"Inference latency of the torch pipeline is {torch_latency:.2f} ms")
+
         generate_image(pipe, args.prompt, image_name)
         return

@@ -214,7 +252,10 @@ def main():
     generate_image(pipe, args.prompt, image_name)
     print(f"Image generated using {args.model} model saved as {image_name}")

-    print(f"Inference latency of the backbone of the pipeline is {device_model.get_latency()} ms")
+    if args.benchmark:
+        print(
+            f"Inference latency of the backbone of the pipeline is {device_model.get_latency()} ms"
+        )


 if __name__ == "__main__":
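
For reference, a hypothetical invocation of the updated script (the --torch and --benchmark flags appear in this diff; the --model and --prompt values below are placeholders, assuming those existing flags):

python diffusion_trt.py --model <model-name> --prompt "a photo of an astronaut" --torch --benchmark

With --benchmark, the torch path reports the average pipeline latency from benchmark_model, and the TensorRT path's device_model.get_latency() readout, previously printed unconditionally, is now gated behind the same flag.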
