1 file changed: +13 -1 lines changed
@@ -1222,10 +1222,22 @@ def _get_source_transforms( # noqa
12221222 if args .expand_rope_table :
12231223 transforms .append (materialze_broadcast_of_rope_freq_cis )
12241224
1225+ use_attention_mask_for_custom_sdpa = False
1226+ if isinstance (args , argparse .Namespace ):
1227+ if getattr (args , "use_custom_sdpa_with_attention_mask" , None ):
1228+ use_attention_mask_for_custom_sdpa = True
1229+
12251230 if args .use_sdpa_with_kv_cache :
12261231 transforms .append (replace_kv_cache_with_custom_kv_cache )
12271232 # todo: do this optionally
1228- transforms .append (replace_sdpa_with_custom_op )
1233+ # If an attention mask should be used instead of causal attention,
1234+ # wrap the transform in a partial that sets use_attention_mask=True.
1235+ if use_attention_mask_for_custom_sdpa :
1236+ transforms .append (
1237+ partial (replace_sdpa_with_custom_op , use_attention_mask = True )
1238+ )
1239+ else :
1240+ transforms .append (replace_sdpa_with_custom_op )
12291241
12301242 if args .quantize_kv_cache :
12311243 assert args .use_kv_cache , "quantize_kv_cache requires use_kv_cache=True"
You can’t perform that action at this time.
0 commit comments