Commit fbafb67

Add a flag for skipping image generation
Signed-off-by: ajrasane <[email protected]>
1 parent 6184567 commit fbafb67
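
With this change, a benchmark-only run can bypass image generation entirely, e.g.:

    python examples/diffusers/quantization/diffusion_trt.py --model <model> --prompt "<prompt>" --torch --benchmark --skip-image

Only --torch, --benchmark, and --skip-image are confirmed by the diff below; the --model and --prompt flags are inferred from the script's use of args.model and args.prompt.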

File tree

1 file changed: +15 additions, -10 deletions

1 file changed

+15
-10
lines changed

examples/diffusers/quantization/diffusion_trt.py

Lines changed: 15 additions & 10 deletions
@@ -59,7 +59,7 @@ def generate_image(pipe, prompt, image_name):
     print(f"Image generated saved as {image_name}")
 
 
-def benchmark_model(pipe, prompt, num_warmup=3, num_runs=10):
+def benchmark_model(pipe, prompt, num_warmup=3, num_runs=10, num_inference_steps=10):
     """Benchmark the backbone model inference time."""
     backbone = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
 
@@ -82,7 +82,7 @@ def forward_hook(module, input, output):
             _ = pipe(
                 prompt,
                 output_type="pil",
-                num_inference_steps=10,
+                num_inference_steps=num_inference_steps,
                 generator=torch.Generator("cuda").manual_seed(42),
             )
 
@@ -92,15 +92,15 @@ def forward_hook(module, input, output):
             _ = pipe(
                 prompt,
                 output_type="pil",
-                num_inference_steps=10,
+                num_inference_steps=num_inference_steps,
                 generator=torch.Generator("cuda").manual_seed(42),
             )
     finally:
         pre_handle.remove()
         post_handle.remove()
 
     total_backbone_time = sum(backbone_times)
-    avg_latency = total_backbone_time / num_runs
+    avg_latency = total_backbone_time / (num_runs * num_inference_steps)
     return avg_latency
 
 
@@ -138,18 +138,21 @@ def main():
         "--onnx-load-path", type=str, default="", help="Path to load the ONNX model"
     )
     parser.add_argument(
-        "--trt-engine-load-path", type=str, default=None, help="Path to load the TRT engine"
+        "--trt-engine-load-path", type=str, default=None, help="Path to load the TensorRT engine"
     )
     parser.add_argument(
         "--dq-only", action="store_true", help="Converts the ONNX model to a dq_only model"
     )
     parser.add_argument(
-        "--torch", action="store_true", help="Generate an image using the torch pipeline"
+        "--torch",
+        action="store_true",
+        help="Use the torch pipeline for image generation or benchmarking",
     )
     parser.add_argument("--save-image-as", type=str, default=None, help="Name of the image to save")
     parser.add_argument(
-        "--benchmark", action="store_true", help="Benchmark the model inference time"
+        "--benchmark", action="store_true", help="Benchmark the model backbone inference time"
     )
+    parser.add_argument("--skip-image", action="store_true", help="Skip image generation")
     args = parser.parse_args()
 
     image_name = args.save_image_as if args.save_image_as else f"{args.model}.png"
@@ -186,7 +189,8 @@ def main():
         torch_latency = benchmark_model(pipe, args.prompt)
         print(f"Inference latency of the torch pipeline is {torch_latency:.2f} ms")
 
-        generate_image(pipe, args.prompt, image_name)
+        if not args.skip_image:
+            generate_image(pipe, args.prompt, image_name)
         return
 
     backbone.to("cuda")
@@ -266,8 +270,9 @@ def main():
         raise ValueError("Pipeline does not have a transformer or unet backbone")
     pipe.to("cuda")
 
-    generate_image(pipe, args.prompt, image_name)
-    print(f"Image generated using {args.model} model saved as {image_name}")
+    if not args.skip_image:
+        generate_image(pipe, args.prompt, image_name)
+        print(f"Image generated using {args.model} model saved as {image_name}")
 
     if args.benchmark:
         print(
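
For readers reproducing the measurement: the benchmark times the backbone (transformer or UNet) with a pair of forward hooks, and since the backbone runs once per denoising step, this commit divides the accumulated time by num_runs * num_inference_steps so the reported figure is average latency per denoising step rather than per full pipeline run. The sketch below is a minimal reconstruction under that reading; the hook bodies, CUDA synchronization, and millisecond conversion are assumptions, since only the changed lines appear in this diff.

    import time

    import torch


    def benchmark_model_sketch(pipe, prompt, num_warmup=3, num_runs=10, num_inference_steps=10):
        """Hypothetical sketch: time each backbone forward pass via hooks."""
        backbone = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
        backbone_times = []  # per-forward-pass latencies in milliseconds
        start = {}

        def pre_hook(module, args):
            torch.cuda.synchronize()  # assumption: sync so wall-clock timing is accurate
            start["t"] = time.perf_counter()

        def post_hook(module, args, output):
            torch.cuda.synchronize()
            backbone_times.append((time.perf_counter() - start["t"]) * 1000)

        pre_handle = backbone.register_forward_pre_hook(pre_hook)
        post_handle = backbone.register_forward_hook(post_hook)
        try:
            for _ in range(num_warmup):
                pipe(prompt, output_type="pil", num_inference_steps=num_inference_steps,
                     generator=torch.Generator("cuda").manual_seed(42))
            backbone_times.clear()  # discard warmup measurements
            for _ in range(num_runs):
                pipe(prompt, output_type="pil", num_inference_steps=num_inference_steps,
                     generator=torch.Generator("cuda").manual_seed(42))
        finally:
            pre_handle.remove()
            post_handle.remove()

        # num_runs * num_inference_steps forward passes were timed, so this
        # returns the average per-step backbone latency, matching the commit.
        return sum(backbone_times) / (num_runs * num_inference_steps)

Dividing by the number of steps makes results comparable across runs that use different num_inference_steps values, which the new parameter now allows.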
