File tree Expand file tree Collapse file tree 1 file changed +13
-1
lines changed Expand file tree Collapse file tree 1 file changed +13
-1
lines changed Original file line number Diff line number Diff line change @@ -1227,10 +1227,22 @@ def _get_source_transforms(  # noqa
12271227    if  args .expand_rope_table :
12281228        transforms .append (materialze_broadcast_of_rope_freq_cis )
12291229
1230+     use_attention_mask_for_custom_sdpa  =  False 
1231+     if  isinstance (args , argparse .Namespace ):
1232+         if  getattr (args , "use_custom_sdpa_with_attention_mask" , None ):
1233+             use_attention_mask_for_custom_sdpa  =  True 
1234+ 
12301235    if  args .use_sdpa_with_kv_cache :
12311236        transforms .append (replace_kv_cache_with_custom_kv_cache )
12321237        # todo: do this optionally 
1233-         transforms .append (replace_sdpa_with_custom_op )
1238+         # If an attention mask is requested instead of causal attention,
1239+         # append a partial of the transform that sets use_attention_mask=True.
1240+         if  use_attention_mask_for_custom_sdpa :
1241+             transforms .append (
1242+                 partial (replace_sdpa_with_custom_op , use_attention_mask = True )
1243+             )
1244+         else :
1245+             transforms .append (replace_sdpa_with_custom_op )
12341246
12351247    if  args .quantize_kv_cache :
12361248        assert  args .use_kv_cache , "quantize_kv_cache requires use_kv_cache=True" 
    
 
   
 
     
   
   
          
     
  
    
     
 
    
      
     
 
     
    You can’t perform that action at this time.
  
 
    
  
     
    
      
        
     
 
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments