     update_dynamic_axes,
 )
 from quantize import ModelType, PipelineManager
+from tqdm import tqdm
 
 import modelopt.torch.opt as mto
 from modelopt.torch._deploy._runtime import RuntimeRegistry
@@ -58,6 +59,59 @@ def generate_image(pipe, prompt, image_name):
     print(f"Image generated saved as {image_name}")
 
 
+def benchmark_model(
+    pipe, prompt, num_warmup=10, num_runs=50, num_inference_steps=20, model_dtype="Half"
+):
+    """Benchmark the backbone model inference time."""
+    backbone = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
+
+    backbone_times = []
+    start_event = torch.cuda.Event(enable_timing=True)
+    end_event = torch.cuda.Event(enable_timing=True)
+
+    def forward_pre_hook(_module, _input):
+        start_event.record()
+
+    def forward_hook(_module, _input, _output):
+        end_event.record()
+        torch.cuda.synchronize()
+        backbone_times.append(start_event.elapsed_time(end_event))
+
+    pre_handle = backbone.register_forward_pre_hook(forward_pre_hook)
+    post_handle = backbone.register_forward_hook(forward_hook)
+
+    try:
+        print(f"Starting warmup: {num_warmup} runs")
+        for _ in tqdm(range(num_warmup), desc="Warmup"):
+            with torch.amp.autocast("cuda", dtype=dtype_map[model_dtype]):
+                _ = pipe(
+                    prompt,
+                    output_type="pil",
+                    num_inference_steps=num_inference_steps,
+                    generator=torch.Generator("cuda").manual_seed(42),
+                )
+
+        backbone_times.clear()
+
+        print(f"Starting benchmark: {num_runs} runs")
+        for _ in tqdm(range(num_runs), desc="Benchmark"):
+            with torch.amp.autocast("cuda", dtype=dtype_map[model_dtype]):
+                _ = pipe(
+                    prompt,
+                    output_type="pil",
+                    num_inference_steps=num_inference_steps,
+                    generator=torch.Generator("cuda").manual_seed(42),
+                )
+    finally:
+        pre_handle.remove()
+        post_handle.remove()
+
+    total_backbone_time = sum(backbone_times)
+    avg_latency = total_backbone_time / (num_runs * num_inference_steps)
+    print(f"Inference latency of the torch backbone: {avg_latency:.2f} ms")
+    return avg_latency
+
+
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -92,15 +146,24 @@ def main():
         "--onnx-load-path", type=str, default="", help="Path to load the ONNX model"
     )
     parser.add_argument(
-        "--trt-engine-load-path", type=str, default=None, help="Path to load the TRT engine"
+        "--trt-engine-load-path", type=str, default=None, help="Path to load the TensorRT engine"
    )
     parser.add_argument(
         "--dq-only", action="store_true", help="Converts the ONNX model to a dq_only model"
     )
     parser.add_argument(
-        "--torch", action="store_true", help="Generate an image using the torch pipeline"
+        "--torch",
+        action="store_true",
+        help="Use the torch pipeline for image generation or benchmarking",
     )
     parser.add_argument("--save-image-as", type=str, default=None, help="Name of the image to save")
+    parser.add_argument(
+        "--benchmark", action="store_true", help="Benchmark the model backbone inference time"
+    )
+    parser.add_argument(
+        "--torch-compile", action="store_true", help="Use torch.compile() on the backbone model"
+    )
+    parser.add_argument("--skip-image", action="store_true", help="Skip image generation")
     args = parser.parse_args()
 
     image_name = args.save_image_as if args.save_image_as else f"{args.model}.png"
@@ -125,13 +188,25 @@ def main():
     if args.restore_from:
         mto.restore(backbone, args.restore_from)
 
+    if args.torch_compile:
+        assert args.model_dtype in ["BFloat16", "Float", "Half"], (
+            "torch.compile() only supports BFloat16, Float, and Half"
+        )
+        print("Compiling backbone with torch.compile()...")
+        backbone = torch.compile(backbone, mode="max-autotune")
+
     if args.torch:
         if hasattr(pipe, "transformer"):
             pipe.transformer = backbone
         elif hasattr(pipe, "unet"):
             pipe.unet = backbone
         pipe.to("cuda")
-        generate_image(pipe, args.prompt, image_name)
+
+        if args.benchmark:
+            benchmark_model(pipe, args.prompt, model_dtype=args.model_dtype)
+
+        if not args.skip_image:
+            generate_image(pipe, args.prompt, image_name)
         return
 
     backbone.to("cuda")
@@ -211,10 +286,14 @@ def main():
         raise ValueError("Pipeline does not have a transformer or unet backbone")
     pipe.to("cuda")
 
-    generate_image(pipe, args.prompt, image_name)
-    print(f"Image generated using {args.model} model saved as {image_name}")
+    if not args.skip_image:
+        generate_image(pipe, args.prompt, image_name)
+        print(f"Image generated using {args.model} model saved as {image_name}")
 
-    print(f"Inference latency of the backbone of the pipeline is {device_model.get_latency()} ms")
+    if args.benchmark:
+        print(
+            f"Inference latency of the TensorRT optimized backbone: {device_model.get_latency()} ms"
+        )
 
 
 if __name__ == "__main__":
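
Not part of the diff: a minimal sketch of how the new benchmark_model helper could be driven directly from Python rather than through the script's CLI. The diffusers pipeline, model ID, and prompt below are illustrative placeholders, and dtype_map is assumed to be the module-level name-to-torch.dtype mapping the script already uses.

# Illustrative only -- not part of this change. Assumes benchmark_model() and the
# script's dtype_map are defined/importable in the current module.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # placeholder model ID
    torch_dtype=torch.float16,
)
pipe.to("cuda")

# The hooks inside benchmark_model() attach to pipe.transformer when present,
# otherwise pipe.unet, so the same call covers DiT- and UNet-based pipelines.
latency_ms = benchmark_model(
    pipe,
    "A photo of an astronaut riding a horse",
    num_warmup=5,
    num_runs=20,
    num_inference_steps=20,
    model_dtype="Half",
)
print(f"Average backbone latency per denoising step: {latency_ms:.2f} ms")

Because each pipeline call invokes the backbone once per denoising step, the returned value is the mean time per backbone forward pass, which is directly comparable to the TensorRT figure printed from device_model.get_latency() at the end of main().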