 import torch
 import torch.nn.functional as F
 from diffusers import FluxPipeline
-from torch._inductor.package import load_package
+from torch._inductor.package import load_package as inductor_load_package
 from typing import List, Optional, Tuple
+import inspect


 @torch.library.custom_op("flash::flash_attn_func", mutates_args=())
@@ -36,23 +37,28 @@ def flash_attn_func(
     import flash_attn_interface

     dtype = torch.float8_e4m3fn
+
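+    # Only pass the kwargs that the installed flash_attn_interface.flash_attn_func
+    # actually accepts; the signature can differ across flash-attn 3 versions.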
+    sig = inspect.signature(flash_attn_interface.flash_attn_func)
+    accepted = set(sig.parameters)
+    all_kwargs = {
+        "softmax_scale": softmax_scale,
+        "causal": causal,
+        "qv": qv,
+        "q_descale": q_descale,
+        "k_descale": k_descale,
+        "v_descale": v_descale,
+        "window_size": window_size,
+        "sink_token_length": sink_token_length,
+        "softcap": softcap,
+        "num_splits": num_splits,
+        "pack_gqa": pack_gqa,
+        "deterministic": deterministic,
+        "sm_margin": sm_margin,
+    }
+    kwargs = {k: v for k, v in all_kwargs.items() if k in accepted}
+
     outputs = flash_attn_interface.flash_attn_func(
-        q.to(dtype),
-        k.to(dtype),
-        v.to(dtype),
-        softmax_scale=softmax_scale,
-        causal=causal,
-        qv=qv,
-        q_descale=q_descale,
-        k_descale=k_descale,
-        v_descale=v_descale,
-        window_size=window_size,
-        sink_token_length=sink_token_length,
-        softcap=softcap,
-        num_splits=num_splits,
-        pack_gqa=pack_gqa,
-        deterministic=deterministic,
-        sm_margin=sm_margin,
+        q.to(dtype), k.to(dtype), v.to(dtype), **kwargs,
     )
     return outputs[0]

@@ -214,7 +220,7 @@ def use_compile(pipeline):
         pipeline.vae.decode, mode="max-autotune", fullgraph=True
     )

-    # warmup for a few iterations
+    # warmup for a few iterations (`num_inference_steps` shouldn't matter)
     for _ in range(3):
         pipeline(
             "dummy prompt to trigger torch compilation",
@@ -233,28 +239,40 @@ def download_hosted_file(filename, output_path):
     hf_hub_download(REPO_NAME, filename, local_dir=os.path.dirname(output_path))


-def use_export_aoti(pipeline, cache_dir, serialize=False):
+def load_package(package_path):
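+    # Download the serialized .pt2 package from the Hub if it isn't cached locally;
+    # run_single_threaded=True avoids thread-synchronization conflicts with CUDA graphs.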
+    if not os.path.exists(package_path):
+        download_hosted_file(os.path.basename(package_path), package_path)
+
+    loaded_package = inductor_load_package(package_path, run_single_threaded=True)
+    return loaded_package
+
+
+def use_export_aoti(pipeline, cache_dir, serialize=False, is_timestep_distilled=True):
     # create cache dir if needed
     pathlib.Path(cache_dir).mkdir(parents=True, exist_ok=True)

     def _example_tensor(*shape):
         return torch.randn(*shape, device="cuda", dtype=torch.bfloat16)

     # === Transformer compile / export ===
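+    # schnell (the timestep-distilled checkpoint) uses a 256-token text sequence; dev uses 512.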
+    seq_length = 256 if is_timestep_distilled else 512
+    # these shapes are for 1024x1024 resolution.
     transformer_kwargs = {
         "hidden_states": _example_tensor(1, 4096, 64),
         "timestep": torch.tensor([1.], device="cuda", dtype=torch.bfloat16),
-        "guidance": None,
+        "guidance": None if is_timestep_distilled else torch.tensor([1.], device="cuda", dtype=torch.bfloat16),
         "pooled_projections": _example_tensor(1, 768),
-        "encoder_hidden_states": _example_tensor(1, 512, 4096),
-        "txt_ids": _example_tensor(512, 3),
+        "encoder_hidden_states": _example_tensor(1, seq_length, 4096),
+        "txt_ids": _example_tensor(seq_length, 3),
         "img_ids": _example_tensor(4096, 3),
         "joint_attention_kwargs": {},
         "return_dict": False,
     }

     # Possibly serialize model out
-    transformer_package_path = os.path.join(cache_dir, "exported_transformer.pt2")
+    transformer_package_path = os.path.join(
+        cache_dir, "exported_transformer.pt2" if is_timestep_distilled else "exported_dev_transformer.pt2"
+    )
     if serialize:
         # Apply export
         exported_transformer: torch.export.ExportedProgram = torch.export.export(
@@ -268,12 +286,7 @@ def _example_tensor(*shape):
             inductor_configs={"max_autotune": True, "triton.cudagraphs": True},
         )
     # download serialized model if needed
-    if not os.path.exists(transformer_package_path):
-        download_hosted_file(os.path.basename(transformer_package_path), transformer_package_path)
-
-    loaded_transformer = load_package(
-        transformer_package_path, run_single_threaded=True
-    )
+    loaded_transformer = load_package(transformer_package_path)

     # warmup before cudagraphing
     with torch.no_grad():
@@ -291,12 +304,12 @@ def _example_tensor(*shape):
     # hack to get around export's limitations
     pipeline.vae.forward = pipeline.vae.decode

-    vae_decode_kwargs = {
-        "return_dict": False,
-    }
+    vae_decode_kwargs = {"return_dict": False}

     # Possibly serialize model out
-    decoder_package_path = os.path.join(cache_dir, "exported_decoder.pt2")
+    decoder_package_path = os.path.join(
+        cache_dir, "exported_decoder.pt2" if is_timestep_distilled else "exported_dev_decoder.pt2"
+    )
     if serialize:
         # Apply export
         exported_decoder: torch.export.ExportedProgram = torch.export.export(
@@ -310,10 +323,7 @@ def _example_tensor(*shape):
             inductor_configs={"max_autotune": True, "triton.cudagraphs": True},
         )
     # download serialized model if needed
-    if not os.path.exists(decoder_package_path):
-        download_hosted_file(os.path.basename(decoder_package_path), decoder_package_path)
-
-    loaded_decoder = load_package(decoder_package_path, run_single_threaded=True)
+    loaded_decoder = load_package(decoder_package_path)

     # warmup before cudagraphing
     with torch.no_grad():
@@ -334,7 +344,7 @@ def _example_tensor(*shape):


 def optimize(pipeline, args):
-    pipeline.set_progress_bar_config(disable=True)
+    is_timestep_distilled = not pipeline.transformer.config.guidance_embeds
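+    # schnell is timestep-distilled and has no guidance embeddings; dev (guidance-distilled) does.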
 
     # fuse QKV projections in Transformer and VAE
     if not args.disable_fused_projections:
@@ -375,10 +385,12 @@ def optimize(pipeline, args):
     if args.compile_export_mode == "compile":
         pipeline = use_compile(pipeline)
     elif args.compile_export_mode == "export_aoti":
+        # NB: Using a cached export + AOTI model is not supported yet
         pipeline = use_export_aoti(
-            pipeline,
-            cache_dir=args.cache_dir,
-            serialize=(not args.use_cached_model),
+            pipeline,
+            cache_dir=args.cache_dir,
+            serialize=(not args.use_cached_model),
+            is_timestep_distilled=is_timestep_distilled
         )
     elif args.compile_export_mode == "disabled":
         pass
@@ -393,5 +405,6 @@ def optimize(pipeline, args):
 def load_pipeline(args):
     load_dtype = torch.float32 if args.disable_bf16 else torch.bfloat16
     pipeline = FluxPipeline.from_pretrained(args.ckpt, torch_dtype=load_dtype).to(args.device)
+    pipeline.set_progress_bar_config(disable=True)
     pipeline = optimize(pipeline, args)
     return pipeline