Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/diffusers/models/attention_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2159,6 +2159,19 @@ def __call__(
# [batch*heads, query_tokens, key_tokens]
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
_, query_tokens, _ = hidden_states.shape
# Pad attention mask to the next multiple of 8 for bfloat16 alignment.
if attention_mask.dtype == torch.bfloat16 and attention_mask.shape[-1] % 8 != 0:
mask_shape = attention_mask.shape
# Create a new mask with padded sequence length.
mask = torch.zeros(
(mask_shape[0], mask_shape[1], math.ceil(mask_shape[-1] / 8) * 8),
device=attention_mask.device,
dtype=attention_mask.dtype,
)
# Copy original attention mask values to the padded mask.
mask[:, :, : mask_shape[-1]] = attention_mask
# Restore the original shape from the padded mask.
attention_mask = mask[:, :, : mask_shape[-1]]
attention_mask = attention_mask.expand(-1, query_tokens, -1)

if attn.group_norm is not None:
Expand Down
21 changes: 21 additions & 0 deletions tests/models/unets/test_models_unet_2d_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,27 @@ def test_xformers_enable_works(self):
== "XFormersAttnProcessor"
), "xformers is not enabled"

@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_attention_mask_padding(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(init_dict)
model.to(torch_device, dtype=torch.bfloat16)
encoder_hidden_states = inputs_dict["encoder_hidden_states"].to(dtype=torch.bfloat16, device=torch_device)
attention_mask = torch.ones((2, 1, 22), dtype=torch.bfloat16, device=torch_device)
model.enable_xformers_memory_efficient_attention()
time_step = inputs_dict["timestep"].to(torch_device, dtype=torch.bfloat16)
noise = inputs_dict["sample"].to(torch_device, dtype=torch.bfloat16)
output_hidden_states = model(
attention_mask=attention_mask,
timestep=time_step,
encoder_hidden_states=encoder_hidden_states,
sample=noise,
)
assert output_hidden_states.sample.shape == encoder_hidden_states.shape, "Output hidden states shape mismatch"

@require_torch_accelerator_with_training
def test_gradient_checkpointing(self):
# enable deterministic behavior for gradient checkpointing
Expand Down