     update_dynamic_axes,
 )
 from quantize import ModelType, PipelineManager
+from tqdm import tqdm
 
 import modelopt.torch.opt as mto
 from modelopt.torch._deploy._runtime import RuntimeRegistry
@@ -58,6 +59,59 @@ def generate_image(pipe, prompt, image_name):
     print(f"Image generated saved as {image_name}")
 
 
+def benchmark_model(
+    pipe, prompt, num_warmup=10, num_runs=50, num_inference_steps=20, model_dtype="Half"
+):
+    """Benchmark the backbone model inference time."""
+    backbone = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
+
+    backbone_times = []
+    start_event = torch.cuda.Event(enable_timing=True)
+    end_event = torch.cuda.Event(enable_timing=True)
+
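+    # Forward hooks bracket each backbone call with CUDA events, so only the
+    # denoiser forward time is measured, not the rest of the pipeline.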
+    def forward_pre_hook(_module, _input):
+        start_event.record()
+
+    def forward_hook(_module, _input, _output):
+        end_event.record()
+        torch.cuda.synchronize()
+        backbone_times.append(start_event.elapsed_time(end_event))
+
+    pre_handle = backbone.register_forward_pre_hook(forward_pre_hook)
+    post_handle = backbone.register_forward_hook(forward_hook)
+
+    try:
+        print(f"Starting warmup: {num_warmup} runs")
+        for _ in tqdm(range(num_warmup), desc="Warmup"):
+            with torch.amp.autocast("cuda", dtype=dtype_map[model_dtype]):
+                _ = pipe(
+                    prompt,
+                    output_type="pil",
+                    num_inference_steps=num_inference_steps,
+                    generator=torch.Generator("cuda").manual_seed(42),
+                )
+
+        backbone_times.clear()
+
+        print(f"Starting benchmark: {num_runs} runs")
+        for _ in tqdm(range(num_runs), desc="Benchmark"):
+            with torch.amp.autocast("cuda", dtype=dtype_map[model_dtype]):
+                _ = pipe(
+                    prompt,
+                    output_type="pil",
+                    num_inference_steps=num_inference_steps,
+                    generator=torch.Generator("cuda").manual_seed(42),
+                )
+    finally:
+        pre_handle.remove()
+        post_handle.remove()
+
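+    # The average assumes one backbone forward call per denoising step, i.e.
+    # num_runs * num_inference_steps timed calls in total.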
+    total_backbone_time = sum(backbone_times)
+    avg_latency = total_backbone_time / (num_runs * num_inference_steps)
+    print(f"Inference latency of the torch backbone: {avg_latency:.2f} ms")
+    return avg_latency
+
+
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -92,15 +146,24 @@ def main():
92146 "--onnx-load-path" , type = str , default = "" , help = "Path to load the ONNX model"
93147 )
94148 parser .add_argument (
95- "--trt-engine-load-path" , type = str , default = None , help = "Path to load the TRT engine"
149+ "--trt-engine-load-path" , type = str , default = None , help = "Path to load the TensorRT engine"
96150 )
97151 parser .add_argument (
98152 "--dq-only" , action = "store_true" , help = "Converts the ONNX model to a dq_only model"
99153 )
100154 parser .add_argument (
101- "--torch" , action = "store_true" , help = "Generate an image using the torch pipeline"
155+ "--torch" ,
156+ action = "store_true" ,
157+ help = "Use the torch pipeline for image generation or benchmarking" ,
102158 )
103159 parser .add_argument ("--save-image-as" , type = str , default = None , help = "Name of the image to save" )
160+ parser .add_argument (
161+ "--benchmark" , action = "store_true" , help = "Benchmark the model backbone inference time"
162+ )
163+ parser .add_argument (
164+ "--torch-compile" , action = "store_true" , help = "Use torch.compile() on the backbone model"
165+ )
166+ parser .add_argument ("--skip-image" , action = "store_true" , help = "Skip image generation" )
104167 args = parser .parse_args ()
105168
106169 image_name = args .save_image_as if args .save_image_as else f"{ args .model } .png"
@@ -125,13 +188,25 @@ def main():
     if args.restore_from:
         mto.restore(backbone, args.restore_from)
 
+    if args.torch_compile:
+        assert args.model_dtype in ["BFloat16", "Float", "Half"], (
+            "torch.compile() only supports BFloat16, Float, and Half"
+        )
+        print("Compiling backbone with torch.compile()...")
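+        # mode="max-autotune" spends extra compile time searching for faster kernels.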
+        backbone = torch.compile(backbone, mode="max-autotune")
+
     if args.torch:
         if hasattr(pipe, "transformer"):
             pipe.transformer = backbone
         elif hasattr(pipe, "unet"):
             pipe.unet = backbone
         pipe.to("cuda")
-        generate_image(pipe, args.prompt, image_name)
+
+        if args.benchmark:
+            benchmark_model(pipe, args.prompt, model_dtype=args.model_dtype)
+
+        if not args.skip_image:
+            generate_image(pipe, args.prompt, image_name)
         return
 
     backbone.to("cuda")
@@ -211,10 +286,14 @@ def main():
         raise ValueError("Pipeline does not have a transformer or unet backbone")
     pipe.to("cuda")
 
-    generate_image(pipe, args.prompt, image_name)
-    print(f"Image generated using {args.model} model saved as {image_name}")
+    if not args.skip_image:
+        generate_image(pipe, args.prompt, image_name)
+        print(f"Image generated using {args.model} model saved as {image_name}")
 
-    print(f"Inference latency of the backbone of the pipeline is {device_model.get_latency()} ms")
+    if args.benchmark:
+        print(
+            f"Inference latency of the TensorRT optimized backbone: {device_model.get_latency()} ms"
+        )
 
 
 if __name__ == "__main__":
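
For reference, a rough sketch of how the new flags could be combined (the script name and the --model value are placeholders, not taken from this diff):

    # Benchmark the torch backbone without generating an image
    python <script>.py --model <model> --torch --benchmark --skip-image

    # Same, with the backbone compiled via torch.compile()
    python <script>.py --model <model> --torch --torch-compile --benchmark --skip-image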