Skip to content
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/diffusers/models/attention_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2144,6 +2144,19 @@ def __call__(
# [batch*heads, query_tokens, key_tokens]
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
_, query_tokens, _ = hidden_states.shape
# Pad attention mask to the next multiple of 8 for bfloat16 alignment.
if attention_mask.dtype == torch.bfloat16 and attention_mask.shape[-1] % 8 != 0:
mask_shape = attention_mask.shape
# Create a new mask with padded sequence length.
mask = torch.zeros(
(mask_shape[0], mask_shape[1], math.ceil(mask_shape[-1] / 8) * 8),
device=attention_mask.device,
dtype=attention_mask.dtype,
)
# Copy original attention mask values to the padded mask.
mask[:, :, : mask_shape[-1]] = attention_mask
# Restore the original shape from the padded mask.
attention_mask = mask[:, :, : mask_shape[-1]]
attention_mask = attention_mask.expand(-1, query_tokens, -1)

if attn.group_norm is not None:
Expand Down
28 changes: 28 additions & 0 deletions tests/models/unets/test_models_unet_2d_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import copy
import gc
import math
import os
import tempfile
import unittest
Expand Down Expand Up @@ -406,6 +407,33 @@ def test_xformers_enable_works(self):
== "XFormersAttnProcessor"
), "xformers is not enabled"

def test_attention_mask_padding(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

encoder_hidden_states = inputs_dict["encoder_hidden_states"]
query_tokens = encoder_hidden_states.shape[1]
attention_mask = torch.ones((2, query_tokens, 22), dtype=torch.bfloat16, device="cuda")

attention_mask_shape_before = attention_mask.shape[-1]
if attention_mask.dtype == torch.bfloat16 and attention_mask.shape[-1] % 8 != 0:
padded_length = math.ceil(attention_mask.shape[-1] / 8) * 8
mask = torch.zeros(
(attention_mask.shape[0], attention_mask.shape[1], padded_length),
device=attention_mask.device,
dtype=attention_mask.dtype,
)
mask[:, :, : attention_mask.shape[-1]] = attention_mask
attention_mask = mask

assert attention_mask.shape[-1] % 8 == 0, "Attention mask not padded to a multiple of 8"
assert attention_mask[:, :, :attention_mask_shape_before].equal(
attention_mask[:, :, :attention_mask_shape_before]
), "Original values in attention mask are not preserved"

expanded_attention_mask = attention_mask.expand(-1, query_tokens, -1)

assert expanded_attention_mask.shape[1] == query_tokens, "Attention mask expansion for query tokens failed"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think in this test, we want to have a check for functional correctness instead of applying the same logic that we're applying within the attention processor class.

So, this means we could first enable xformers attention on the UNet and then do a forward pass and then design our tests accordingly.


@require_torch_accelerator_with_training
def test_gradient_checkpointing(self):
# enable deterministic behavior for gradient checkpointing
Expand Down
Loading