File tree Expand file tree Collapse file tree 1 file changed +13
-1
lines changed Expand file tree Collapse file tree 1 file changed +13
-1
lines changed Original file line number Diff line number Diff line change @@ -1227,10 +1227,22 @@ def _get_source_transforms( # noqa
12271227 if args .expand_rope_table :
12281228 transforms .append (materialze_broadcast_of_rope_freq_cis )
12291229
1230+ use_attention_mask_for_custom_sdpa = False
1231+ if isinstance (args , argparse .Namespace ):
1232+ if getattr (args , "use_custom_sdpa_with_attention_mask" , None ):
1233+ use_attention_mask_for_custom_sdpa = True
1234+
12301235 if args .use_sdpa_with_kv_cache :
12311236 transforms .append (replace_kv_cache_with_custom_kv_cache )
12321237 # todo: do this optionally
1233- transforms .append (replace_sdpa_with_custom_op )
1238+ # If an attention mask is requested instead of causal attention,
1239+ # register a partial of the transform with use_attention_mask=True.
1240+ if use_attention_mask_for_custom_sdpa :
1241+ transforms .append (
1242+ partial (replace_sdpa_with_custom_op , use_attention_mask = True )
1243+ )
1244+ else :
1245+ transforms .append (replace_sdpa_with_custom_op )
12341246
12351247 if args .quantize_kv_cache :
12361248 assert args .use_kv_cache , "quantize_kv_cache requires use_kv_cache=True"
You can’t perform that action at this time.
0 commit comments