@@ -59,7 +59,7 @@ def generate_image(pipe, prompt, image_name):
     print(f"Image generated and saved as {image_name}")
 
 
-def benchmark_model(pipe, prompt, num_warmup=3, num_runs=10):
+def benchmark_model(pipe, prompt, num_warmup=3, num_runs=10, num_inference_steps=10):
     """Benchmark the backbone model inference time."""
     backbone = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
 
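Threading `num_inference_steps` through the signature lets callers benchmark at the step count they actually deploy with, instead of the hard-coded 10. A hypothetical call (the parameter values are illustrative):

```python
# Per-step backbone latency, averaged over 5 runs of 20 denoising steps each.
latency_ms = benchmark_model(pipe, prompt, num_runs=5, num_inference_steps=20)
```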
@@ -82,7 +82,7 @@ def forward_hook(module, input, output):
             _ = pipe(
                 prompt,
                 output_type="pil",
-                num_inference_steps=10,
+                num_inference_steps=num_inference_steps,
                 generator=torch.Generator("cuda").manual_seed(42),
             )
 
@@ -92,15 +92,15 @@ def forward_hook(module, input, output):
             _ = pipe(
                 prompt,
                 output_type="pil",
-                num_inference_steps=10,
+                num_inference_steps=num_inference_steps,
                 generator=torch.Generator("cuda").manual_seed(42),
             )
     finally:
         pre_handle.remove()
         post_handle.remove()
 
     total_backbone_time = sum(backbone_times)
-    avg_latency = total_backbone_time / num_runs
+    avg_latency = total_backbone_time / (num_runs * num_inference_steps)
     return avg_latency
 
 
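The corrected denominator matters because the forward hooks fire once per backbone call, and a diffusion pipeline invokes its backbone once per denoising step: `backbone_times` therefore accumulates `num_runs * num_inference_steps` samples, so dividing by `num_runs` alone reported per-pipeline-call time rather than per-step latency. A minimal, self-contained sketch of the same timing pattern (placeholder module and names, not the script's exact code):

```python
import time

import torch

backbone = torch.nn.Linear(16, 16).cuda()  # stand-in for pipe.transformer / pipe.unet
num_runs, num_inference_steps = 10, 10
backbone_times = []

def pre_hook(module, args):
    torch.cuda.synchronize()  # ensure prior GPU work is done before starting the clock
    module._start = time.perf_counter()

def post_hook(module, args, output):
    torch.cuda.synchronize()  # wait for this forward pass to finish on the GPU
    backbone_times.append((time.perf_counter() - module._start) * 1000)  # ms

pre_handle = backbone.register_forward_pre_hook(pre_hook)
post_handle = backbone.register_forward_hook(post_hook)
try:
    x = torch.randn(1, 16, device="cuda")
    for _ in range(num_runs):
        for _ in range(num_inference_steps):  # one backbone forward per denoising step
            _ = backbone(x)
finally:
    pre_handle.remove()
    post_handle.remove()

# num_runs * num_inference_steps samples were recorded, so divide by both.
avg_latency = sum(backbone_times) / (num_runs * num_inference_steps)
print(f"Average per-step backbone latency: {avg_latency:.2f} ms")
```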
@@ -138,18 +138,21 @@ def main():
         "--onnx-load-path", type=str, default="", help="Path to load the ONNX model"
     )
     parser.add_argument(
-        "--trt-engine-load-path", type=str, default=None, help="Path to load the TRT engine"
+        "--trt-engine-load-path", type=str, default=None, help="Path to load the TensorRT engine"
     )
     parser.add_argument(
         "--dq-only", action="store_true", help="Converts the ONNX model to a dq_only model"
     )
     parser.add_argument(
-        "--torch", action="store_true", help="Generate an image using the torch pipeline"
+        "--torch",
+        action="store_true",
+        help="Use the torch pipeline for image generation or benchmarking",
     )
     parser.add_argument("--save-image-as", type=str, default=None, help="Name of the image to save")
     parser.add_argument(
-        "--benchmark", action="store_true", help="Benchmark the model inference time"
+        "--benchmark", action="store_true", help="Benchmark the model backbone inference time"
     )
+    parser.add_argument("--skip-image", action="store_true", help="Skip image generation")
     args = parser.parse_args()
 
     image_name = args.save_image_as if args.save_image_as else f"{args.model}.png"
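Together with `--benchmark`, the new `--skip-image` flag lets users measure latency without paying for a full image decode and save. A hypothetical invocation (the script name is a placeholder, not taken from the patch):

```bash
python diffusion_trt.py --model <model> --torch --benchmark --skip-image
```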
@@ -186,7 +189,8 @@ def main():
             torch_latency = benchmark_model(pipe, args.prompt)
             print(f"Inference latency of the torch pipeline is {torch_latency:.2f} ms")
 
-        generate_image(pipe, args.prompt, image_name)
+        if not args.skip_image:
+            generate_image(pipe, args.prompt, image_name)
         return
 
     backbone.to("cuda")
@@ -266,8 +270,9 @@ def main():
         raise ValueError("Pipeline does not have a transformer or unet backbone")
     pipe.to("cuda")
 
-    generate_image(pipe, args.prompt, image_name)
-    print(f"Image generated using {args.model} model saved as {image_name}")
+    if not args.skip_image:
+        generate_image(pipe, args.prompt, image_name)
+        print(f"Image generated using {args.model} model saved as {image_name}")
 
     if args.benchmark:
         print(