@@ -1881,12 +1881,15 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
18811881 return out
18821882
18831883
1884- def _prepare_attn_mask_native (
1884+ def _prepare_additive_attn_mask (
18851885 attn_mask : torch .Tensor , target_dtype : torch .dtype , reshape_4d : bool = True
18861886) -> torch .Tensor :
18871887 """
18881888 Convert a 2D boolean attention mask to an additive mask, optionally reshaping to 4D for SDPA.
18891889
1890+ This helper is used by both native SDPA and xformers backends to convert boolean masks to the additive format they
1891+ require.
1892+
18901893 Args:
18911894 attn_mask: 2D boolean tensor [batch_size, seq_len_k] where True means attend, False means mask out
18921895 target_dtype: The dtype to convert the mask to (usually query.dtype)
@@ -1939,7 +1942,7 @@ def _native_attention(
19391942 # attn_mask is [batch_size, seq_len_k] boolean: True means attend, False means mask out
19401943 # SDPA expects [batch_size, 1, 1, seq_len_k] additive mask: 0.0 for attend, -inf for mask out
19411944 # Use helper to convert boolean to additive mask and reshape to 4D
1942- attn_mask = _prepare_attn_mask_native (attn_mask , target_dtype = query .dtype )
1945+ attn_mask = _prepare_additive_attn_mask (attn_mask , target_dtype = query .dtype )
19431946
19441947 if _parallel_config is None :
19451948 query , key , value = (x .permute (0 , 2 , 1 , 3 ) for x in (query , key , value ))
@@ -2480,7 +2483,9 @@ def _xformers_attention(
24802483 )
24812484 # Fill in the actual mask values (converting boolean to additive)
24822485 # Use helper to convert 2D boolean -> 4D additive mask
2483- mask_additive = _prepare_attn_mask_native (attn_mask , target_dtype = query .dtype ) # [batch, 1, 1, seq_len_k]
2486+ mask_additive = _prepare_additive_attn_mask (
2487+ attn_mask , target_dtype = query .dtype
2488+ ) # [batch, 1, 1, seq_len_k]
24842489 # Broadcast to [batch, heads, seq_q, seq_len_k]
24852490 aligned_mask [:, :, :, :original_seq_len ] = mask_additive
24862491 # Mask out the padding (already -inf from zeros -> where with default)
0 commit comments