     update_dynamic_axes,
 )
 from quantize import ModelType, PipelineManager
-from tqdm import tqdm

 import modelopt.torch.opt as mto
 from modelopt.torch._deploy._runtime import RuntimeRegistry
@@ -59,59 +58,6 @@ def generate_image(pipe, prompt, image_name):
     print(f"Image generated saved as {image_name}")


-def benchmark_model(
-    pipe, prompt, num_warmup=10, num_runs=50, num_inference_steps=20, model_dtype="Half"
-):
-    """Benchmark the backbone model inference time."""
-    backbone = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
-
-    backbone_times = []
-    start_event = torch.cuda.Event(enable_timing=True)
-    end_event = torch.cuda.Event(enable_timing=True)
-
-    def forward_pre_hook(_module, _input):
-        start_event.record()
-
-    def forward_hook(_module, _input, _output):
-        end_event.record()
-        torch.cuda.synchronize()
-        backbone_times.append(start_event.elapsed_time(end_event))
-
-    pre_handle = backbone.register_forward_pre_hook(forward_pre_hook)
-    post_handle = backbone.register_forward_hook(forward_hook)
-
-    try:
-        print(f"Starting warmup: {num_warmup} runs")
-        for _ in tqdm(range(num_warmup), desc="Warmup"):
-            with torch.amp.autocast("cuda", dtype=dtype_map[model_dtype]):
-                _ = pipe(
-                    prompt,
-                    output_type="pil",
-                    num_inference_steps=num_inference_steps,
-                    generator=torch.Generator("cuda").manual_seed(42),
-                )
-
-        backbone_times.clear()
-
-        print(f"Starting benchmark: {num_runs} runs")
-        for _ in tqdm(range(num_runs), desc="Benchmark"):
-            with torch.amp.autocast("cuda", dtype=dtype_map[model_dtype]):
-                _ = pipe(
-                    prompt,
-                    output_type="pil",
-                    num_inference_steps=num_inference_steps,
-                    generator=torch.Generator("cuda").manual_seed(42),
-                )
-    finally:
-        pre_handle.remove()
-        post_handle.remove()
-
-    total_backbone_time = sum(backbone_times)
-    avg_latency = total_backbone_time / (num_runs * num_inference_steps)
-    print(f"Inference latency of the torch backbone: {avg_latency:.2f} ms")
-    return avg_latency
-
-
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -146,24 +92,15 @@ def main():
14692 "--onnx-load-path" , type = str , default = "" , help = "Path to load the ONNX model"
14793 )
14894 parser .add_argument (
149- "--trt-engine-load-path" , type = str , default = None , help = "Path to load the TensorRT engine"
95+ "--trt-engine-load-path" , type = str , default = None , help = "Path to load the TRT engine"
15096 )
15197 parser .add_argument (
15298 "--dq-only" , action = "store_true" , help = "Converts the ONNX model to a dq_only model"
15399 )
154100 parser .add_argument (
155- "--torch" ,
156- action = "store_true" ,
157- help = "Use the torch pipeline for image generation or benchmarking" ,
101+ "--torch" , action = "store_true" , help = "Generate an image using the torch pipeline"
158102 )
159103 parser .add_argument ("--save-image-as" , type = str , default = None , help = "Name of the image to save" )
160- parser .add_argument (
161- "--benchmark" , action = "store_true" , help = "Benchmark the model backbone inference time"
162- )
163- parser .add_argument (
164- "--torch-compile" , action = "store_true" , help = "Use torch.compile() on the backbone model"
165- )
166- parser .add_argument ("--skip-image" , action = "store_true" , help = "Skip image generation" )
167104 args = parser .parse_args ()
168105
169106 image_name = args .save_image_as if args .save_image_as else f"{ args .model } .png"
@@ -188,25 +125,13 @@ def main():
     if args.restore_from:
         mto.restore(backbone, args.restore_from)

-    if args.torch_compile:
-        assert args.model_dtype in ["BFloat16", "Float", "Half"], (
-            "torch.compile() only supports BFloat16 and Float"
-        )
-        print("Compiling backbone with torch.compile()...")
-        backbone = torch.compile(backbone, mode="max-autotune")
-
     if args.torch:
         if hasattr(pipe, "transformer"):
             pipe.transformer = backbone
         elif hasattr(pipe, "unet"):
             pipe.unet = backbone
         pipe.to("cuda")
-
-        if args.benchmark:
-            benchmark_model(pipe, args.prompt, model_dtype=args.model_dtype)
-
-        if not args.skip_image:
-            generate_image(pipe, args.prompt, image_name)
+        generate_image(pipe, args.prompt, image_name)
         return

     backbone.to("cuda")
@@ -286,14 +211,10 @@ def main():
         raise ValueError("Pipeline does not have a transformer or unet backbone")
     pipe.to("cuda")

-    if not args.skip_image:
-        generate_image(pipe, args.prompt, image_name)
-        print(f"Image generated using {args.model} model saved as {image_name}")
+    generate_image(pipe, args.prompt, image_name)
+    print(f"Image generated using {args.model} model saved as {image_name}")

-    if args.benchmark:
-        print(
-            f"Inference latency of the TensorRT optimized backbone: {device_model.get_latency()} ms"
-        )
+    print(f"Inference latency of the backbone of the pipeline is {device_model.get_latency()} ms")


 if __name__ == "__main__":
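
Note (not part of the diff): with benchmark_model and the --benchmark/--torch-compile/--skip-image flags removed, the script only reports latency through device_model.get_latency() on the TensorRT path. If a torch-side number is still needed, the CUDA-event/forward-hook pattern the deleted function used can be run standalone. The sketch below is illustrative only; the model id, prompt, and run/step counts are assumptions, not values from the script.

# Standalone sketch of the removed hook-based backbone timing (assumed model id and prompt).
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
backbone = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet

times = []
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

def _pre(_module, _input):
    start.record()  # mark the start of each backbone forward pass

def _post(_module, _input, _output):
    end.record()
    torch.cuda.synchronize()
    times.append(start.elapsed_time(end))  # elapsed time in milliseconds

pre_handle = backbone.register_forward_pre_hook(_pre)
post_handle = backbone.register_forward_hook(_post)
try:
    num_runs, steps = 10, 20  # no warmup pass here, unlike the removed benchmark_model
    for _ in range(num_runs):
        pipe("a photo of an astronaut riding a horse", num_inference_steps=steps)
finally:
    pre_handle.remove()
    post_handle.remove()

print(f"Average backbone latency: {sum(times) / (num_runs * steps):.2f} ms")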