@@ -37,7 +37,7 @@ def flash_attn_func(
3737 import flash_attn_interface
3838
3939 dtype = torch .float8_e4m3fn
40-
40+
4141 sig = inspect .signature (flash_attn_interface .flash_attn_func )
4242 accepted = set (sig .parameters )
4343 all_kwargs = {
@@ -56,7 +56,7 @@ def flash_attn_func(
5656 "sm_margin" : sm_margin ,
5757 }
5858 kwargs = {k : v for k , v in all_kwargs .items () if k in accepted }
59-
59+
6060 outputs = flash_attn_interface .flash_attn_func (
6161 q .to (dtype ), k .to (dtype ), v .to (dtype ), ** kwargs ,
6262 )
@@ -385,11 +385,10 @@ def optimize(pipeline, args):
385385 if args .compile_export_mode == "compile" :
386386 pipeline = use_compile (pipeline )
387387 elif args .compile_export_mode == "export_aoti" :
388- # NB: Using a cached export + AOTI model is not supported yet
389388 pipeline = use_export_aoti (
390- pipeline ,
391- cache_dir = args .cache_dir ,
392- serialize = (not args .use_cached_model ),
389+ pipeline ,
390+ cache_dir = args .cache_dir ,
391+ serialize = (not args .use_cached_model ),
393392 is_timestep_distilled = is_timestep_distilled
394393 )
395394 elif args .compile_export_mode == "disabled" :
0 commit comments