@@ -214,7 +214,7 @@ def use_compile(pipeline):
         pipeline.vae.decode, mode="max-autotune", fullgraph=True
     )

-    # warmup for a few iterations
+    # warmup for a few iterations (`num_inference_steps` shouldn't matter)
    for _ in range(3):
         pipeline(
             "dummy prompt to trigger torch compilation",
@@ -233,21 +233,23 @@ def download_hosted_file(filename, output_path):
     hf_hub_download(REPO_NAME, filename, local_dir=os.path.dirname(output_path))


-def use_export_aoti(pipeline, cache_dir, serialize=False):
+def use_export_aoti(pipeline, cache_dir, serialize=False, is_timestep_distilled=True):
     # create cache dir if needed
     pathlib.Path(cache_dir).mkdir(parents=True, exist_ok=True)

     def _example_tensor(*shape):
         return torch.randn(*shape, device="cuda", dtype=torch.bfloat16)

     # === Transformer compile / export ===
+    seq_length = 256 if is_timestep_distilled else 512
+    # these shapes are for 1024x1024 resolution.
     transformer_kwargs = {
         "hidden_states": _example_tensor(1, 4096, 64),
         "timestep": torch.tensor([1.], device="cuda", dtype=torch.bfloat16),
-        "guidance": None,
+        "guidance": None if is_timestep_distilled else torch.tensor([1.], device="cuda", dtype=torch.bfloat16),
         "pooled_projections": _example_tensor(1, 768),
-        "encoder_hidden_states": _example_tensor(1, 512, 4096),
-        "txt_ids": _example_tensor(512, 3),
+        "encoder_hidden_states": _example_tensor(1, seq_length, 4096),
+        "txt_ids": _example_tensor(seq_length, 3),
         "img_ids": _example_tensor(4096, 3),
         "joint_attention_kwargs": {},
         "return_dict": False,
@@ -291,9 +293,7 @@ def _example_tensor(*shape):
     # hack to get around export's limitations
     pipeline.vae.forward = pipeline.vae.decode

-    vae_decode_kwargs = {
-        "return_dict": False,
-    }
+    vae_decode_kwargs = {"return_dict": False}

     # Possibly serialize model out
     decoder_package_path = os.path.join(cache_dir, "exported_decoder.pt2")
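The serialize/load logic that follows this hunk (unchanged here) exports `pipeline.vae` via the `forward = decode` alias above and packages it with AOTI. A hedged sketch, assuming 1024x1024 generation, where FLUX's 16-channel, 8x-downsampling VAE sees latents of shape (1, 16, 128, 128):

```python
# Sketch only; the latent shape and API usage are assumptions, not the diff's code.
from torch.export import export
from torch._inductor import aoti_compile_and_package, aoti_load_package

if serialize:
    exported_decoder = export(
        pipeline.vae, args=(_example_tensor(1, 16, 128, 128),), kwargs=vae_decode_kwargs
    )
    aoti_compile_and_package(exported_decoder, package_path=decoder_package_path)
pipeline.vae.decode = aoti_load_package(decoder_package_path)
```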
@@ -334,7 +334,7 @@ def _example_tensor(*shape):


 def optimize(pipeline, args):
-    pipeline.set_progress_bar_config(disable=True)
+    is_timestep_distilled = args.ckpt == "black-forest-labs/FLUX.1-schnell"

     # fuse QKV projections in Transformer and VAE
     if not args.disable_fused_projections:
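Keying `is_timestep_distilled` off a hard-coded checkpoint name works for the official FLUX.1-schnell repo but misses mirrors and local copies. An alternative, sketched here as an assumption about diffusers' `FluxTransformer2DModel` config, is to check whether the loaded transformer expects a guidance embedding (schnell, being timestep-distilled, does not):

```python
# Alternative sketch: infer distillation from the model config rather than the
# checkpoint name; `guidance_embeds` is False for schnell, True for dev.
is_timestep_distilled = not pipeline.transformer.config.guidance_embeds
```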
@@ -376,7 +376,9 @@ def optimize(pipeline, args):
         pipeline = use_compile(pipeline)
     elif args.compile_export_mode == "export_aoti":
         # NB: Using a cached export + AOTI model is not supported yet
-        pipeline = use_export_aoti(pipeline, cache_dir=args.cache_dir, serialize=True)
+        pipeline = use_export_aoti(
+            pipeline, cache_dir=args.cache_dir, serialize=True, is_timestep_distilled=is_timestep_distilled
+        )
     elif args.compile_export_mode == "disabled":
         pass
     else:
@@ -390,5 +392,6 @@ def optimize(pipeline, args):
 def load_pipeline(args):
     load_dtype = torch.float32 if args.disable_bf16 else torch.bfloat16
     pipeline = FluxPipeline.from_pretrained(args.ckpt, torch_dtype=load_dtype).to(args.device)
+    pipeline.set_progress_bar_config(disable=True)
     pipeline = optimize(pipeline, args)
     return pipeline