Skip to content

Commit 1433783

Browse files
committed
rename helper
1 parent c88bc06 commit 1433783

File tree

1 file changed: +8 additions, -3 deletions

src/diffusers/models/attention_dispatch.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1881,12 +1881,15 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
18811881
return out
18821882

18831883

1884-
def _prepare_attn_mask_native(
1884+
def _prepare_additive_attn_mask(
18851885
attn_mask: torch.Tensor, target_dtype: torch.dtype, reshape_4d: bool = True
18861886
) -> torch.Tensor:
18871887
"""
18881888
Convert a 2D boolean attention mask to an additive mask, optionally reshaping to 4D for SDPA.
18891889
1890+
This helper is used by both native SDPA and xformers backends to convert boolean masks to the additive format they
1891+
require.
1892+
18901893
Args:
18911894
attn_mask: 2D boolean tensor [batch_size, seq_len_k] where True means attend, False means mask out
18921895
target_dtype: The dtype to convert the mask to (usually query.dtype)
@@ -1939,7 +1942,7 @@ def _native_attention(
19391942
# attn_mask is [batch_size, seq_len_k] boolean: True means attend, False means mask out
19401943
# SDPA expects [batch_size, 1, 1, seq_len_k] additive mask: 0.0 for attend, -inf for mask out
19411944
# Use helper to convert boolean to additive mask and reshape to 4D
1942-
attn_mask = _prepare_attn_mask_native(attn_mask, target_dtype=query.dtype)
1945+
attn_mask = _prepare_additive_attn_mask(attn_mask, target_dtype=query.dtype)
19431946

19441947
if _parallel_config is None:
19451948
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
@@ -2480,7 +2483,9 @@ def _xformers_attention(
24802483
)
24812484
# Fill in the actual mask values (converting boolean to additive)
24822485
# Use helper to convert 2D boolean -> 4D additive mask
2483-
mask_additive = _prepare_attn_mask_native(attn_mask, target_dtype=query.dtype) # [batch, 1, 1, seq_len_k]
2486+
mask_additive = _prepare_additive_attn_mask(
2487+
attn_mask, target_dtype=query.dtype
2488+
) # [batch, 1, 1, seq_len_k]
24842489
# Broadcast to [batch, heads, seq_q, seq_len_k]
24852490
aligned_mask[:, :, :, :original_seq_len] = mask_additive
24862491
# Mask out the padding (already -inf from zeros -> where with default)

0 commit comments

Comments (0)