@@ -59,7 +59,9 @@ def generate_image(pipe, prompt, image_name):
     print(f"Image generated saved as {image_name}")
 
 
-def benchmark_model(pipe, prompt, num_warmup=10, num_runs=50, num_inference_steps=20):
+def benchmark_model(
+    pipe, prompt, num_warmup=10, num_runs=50, num_inference_steps=20, model_dtype="Half"
+):
     """Benchmark the backbone model inference time."""
     backbone = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
@@ -81,23 +83,25 @@ def forward_hook(_module, _input, _output):
     try:
         print(f"Starting warmup: {num_warmup} runs")
         for _ in tqdm(range(num_warmup), desc="Warmup"):
-            _ = pipe(
-                prompt,
-                output_type="pil",
-                num_inference_steps=num_inference_steps,
-                generator=torch.Generator("cuda").manual_seed(42),
-            )
+            with torch.amp.autocast("cuda", dtype=dtype_map[model_dtype]):
+                _ = pipe(
+                    prompt,
+                    output_type="pil",
+                    num_inference_steps=num_inference_steps,
+                    generator=torch.Generator("cuda").manual_seed(42),
+                )
 
         backbone_times.clear()
 
         print(f"Starting benchmark: {num_runs} runs")
         for _ in tqdm(range(num_runs), desc="Benchmark"):
-            _ = pipe(
-                prompt,
-                output_type="pil",
-                num_inference_steps=num_inference_steps,
-                generator=torch.Generator("cuda").manual_seed(42),
-            )
+            with torch.amp.autocast("cuda", dtype=dtype_map[model_dtype]):
+                _ = pipe(
+                    prompt,
+                    output_type="pil",
+                    num_inference_steps=num_inference_steps,
+                    generator=torch.Generator("cuda").manual_seed(42),
+                )
     finally:
         pre_handle.remove()
         post_handle.remove()
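For orientation: the pre_handle/post_handle pair removed in the finally block are forward hooks registered on the backbone above this hunk (the forward_hook name is visible in the hunk header), and backbone_times collects one latency per backbone call so the warmup entries can be cleared before the measured runs. A rough sketch of that registration pattern, assuming synchronize-based wall-clock timing and the backbone variable from the top of benchmark_model; the actual hook bodies are not part of this diff:

    import time
    import torch

    backbone_times = []  # one entry per backbone forward call
    _start = [0.0]

    def pre_forward_hook(_module, _input):
        torch.cuda.synchronize()  # make the start timestamp honest
        _start[0] = time.perf_counter()

    def forward_hook(_module, _input, _output):
        torch.cuda.synchronize()  # wait for the backbone's kernels to finish
        backbone_times.append(time.perf_counter() - _start[0])

    pre_handle = backbone.register_forward_pre_hook(pre_forward_hook)
    post_handle = backbone.register_forward_hook(forward_hook)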
@@ -156,6 +160,9 @@ def main():
     parser.add_argument(
         "--benchmark", action="store_true", help="Benchmark the model backbone inference time"
     )
+    parser.add_argument(
+        "--torch-compile", action="store_true", help="Use torch.compile() on the backbone model"
+    )
     parser.add_argument("--skip-image", action="store_true", help="Skip image generation")
     args = parser.parse_args()
 
@@ -181,6 +188,13 @@ def main():
     if args.restore_from:
         mto.restore(backbone, args.restore_from)
 
+    if args.torch_compile:
+        assert args.model_dtype in ["BFloat16", "Float"], (
+            "torch.compile() only supports BFloat16 and Float"
+        )
+        print("Compiling backbone with torch.compile()...")
+        backbone = torch.compile(backbone)
+
     if args.torch:
         if hasattr(pipe, "transformer"):
             pipe.transformer = backbone
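One behavioral note on this guard: torch.compile wraps the backbone lazily, so compilation actually happens on the first forward pass, meaning the first warmup iteration in benchmark_model absorbs the compile cost before any timed run. The call above uses the default compile mode; if a longer first call is acceptable, torch.compile also takes a tuning mode (a sketch, not part of this change):

    # Optional variant: trade a longer first-call compile for faster
    # steady-state inference; mode is a standard torch.compile argument.
    backbone = torch.compile(backbone, mode="max-autotune")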
@@ -189,7 +203,7 @@ def main():
         pipe.to("cuda")
 
     if args.benchmark:
-        benchmark_model(pipe, args.prompt)
+        benchmark_model(pipe, args.prompt, model_dtype=args.model_dtype)
 
     if not args.skip_image:
         generate_image(pipe, args.prompt, image_name)
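Taken together, the flags added here compose with the existing ones roughly like this; the script name is a placeholder and the --model-dtype spelling is inferred from args.model_dtype, while --benchmark, --torch-compile, and --skip-image appear verbatim in the hunks above:

    # Hypothetical invocation: benchmark a compiled BFloat16 backbone
    # without generating a final image.
    python diffusion_pipeline.py \
        --benchmark \
        --torch-compile \
        --model-dtype BFloat16 \
        --skip-image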