
Conversation

@SahilCarterr (Contributor) commented on Oct 15, 2024

What does this PR do?

Fixes #9637 by resolving the attention mask padding issue for compatibility with xFormers.

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed.
@sayakpaul @yiyixuxu


Code from Issue

from diffusers.models.attention_processor import Attention, XFormersAttnProcessor
import torch

# Initialize the attention processor
attn_processor = XFormersAttnProcessor()

# Create the Attention module
attn = Attention(
    query_dim=256,
    heads=8,
    dim_head=64,
    processor=attn_processor,
).to(device="cuda", dtype=torch.bfloat16)

# Create dummy inputs; the key/value sequence length (700) is deliberately not a
# multiple of 8, which is what triggers the xFormers failure with a bfloat16 mask
q = torch.zeros((2, 350, 256), device="cuda", dtype=torch.bfloat16)
kv = torch.zeros((2, 700, 256), device="cuda", dtype=torch.bfloat16)
attn_mask = torch.zeros((2, 1, 700), device="cuda", dtype=torch.bfloat16)

# Perform the attention operation
out = attn(q, kv, attn_mask)

# Print the output shape
print(out.shape)

Output

torch.Size([2, 350, 256])

Hardware Information

  • GPU: NVIDIA A100
  • Environment: Google Colab
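
For context, the failure in #9637 occurs when the attention mask is bfloat16 and its last (key) dimension is not a multiple of 8, e.g. 700 above. Below is a minimal sketch of the padding idea this PR pursues: zero-pad the key dimension up to the next multiple of 8 and copy the original values into the new mask. The helper name is illustrative, not code from the diff.

import math
import torch


def pad_attention_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    """Pad the last (key) dimension of a bfloat16 attention mask to a multiple of 8.

    Illustrative helper mirroring the padding idea in this PR, not the exact diff.
    """
    if attention_mask.dtype != torch.bfloat16 or attention_mask.shape[-1] % 8 == 0:
        return attention_mask

    padded_length = math.ceil(attention_mask.shape[-1] / 8) * 8
    padded = torch.zeros(
        (*attention_mask.shape[:-1], padded_length),
        device=attention_mask.device,
        dtype=attention_mask.dtype,
    )
    # Copy the original mask values; the extra columns stay zero.
    padded[..., : attention_mask.shape[-1]] = attention_mask
    return padded


# Example: the (2, 1, 700) mask from the reproduction above becomes (2, 1, 704),
# since 704 is the next multiple of 8, and its first 700 columns are preserved.
mask = torch.zeros((2, 1, 700), dtype=torch.bfloat16)
print(pad_attention_mask(mask).shape)  # torch.Size([2, 1, 704])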

@HuggingFaceDocBuilderDev commented

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul (Member) left a comment


Thanks for the PR! I left some comments; let me know if they make sense.

Comment on lines 417 to 435
attention_mask_shape_before = attention_mask.shape[-1]
original_attention_mask = attention_mask
if attention_mask.dtype == torch.bfloat16 and attention_mask.shape[-1] % 8 != 0:
    padded_length = math.ceil(attention_mask.shape[-1] / 8) * 8
    mask = torch.zeros(
        (attention_mask.shape[0], attention_mask.shape[1], padded_length),
        device=attention_mask.device,
        dtype=attention_mask.dtype,
    )
    mask[:, :, : attention_mask.shape[-1]] = attention_mask
    attention_mask = mask

assert attention_mask.shape[-1] % 8 == 0, "Attention mask not padded to a multiple of 8"
assert attention_mask[:, :, :attention_mask_shape_before].equal(
    original_attention_mask[:, :, :attention_mask_shape_before]
), "Original values in attention mask are not preserved"

expanded_attention_mask = attention_mask.expand(-1, query_tokens, -1)

assert expanded_attention_mask.shape[1] == query_tokens, "Attention mask expansion for query tokens failed"

I think in this test we want a check for functional correctness instead of re-applying the same logic that we apply inside the attention processor class.

So we could first enable xFormers attention on the UNet, do a forward pass, and then design our tests accordingly.
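
Along those lines, here is a rough sketch of what such a functional test could look like. The tiny UNet config, sequence length, and tolerance below are illustrative assumptions rather than the test in this PR, and it requires a CUDA GPU with xformers installed.

import torch
from diffusers import UNet2DConditionModel

# Illustrative tiny UNet config (values chosen for speed, not taken from the PR).
unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    block_out_channels=(32, 64),
    layers_per_block=1,
    cross_attention_dim=32,
    attention_head_dim=8,
).to("cuda", dtype=torch.bfloat16)
unet.eval()

sample = torch.randn(2, 4, 32, 32, device="cuda", dtype=torch.bfloat16)
timestep = torch.tensor([1], device="cuda")
# 77 text tokens: the key sequence length is deliberately not a multiple of 8.
encoder_hidden_states = torch.randn(2, 77, 32, device="cuda", dtype=torch.bfloat16)
encoder_attention_mask = torch.ones(2, 77, device="cuda", dtype=torch.bool)

# Reference output with the default attention processor.
with torch.no_grad():
    ref = unet(
        sample, timestep, encoder_hidden_states,
        encoder_attention_mask=encoder_attention_mask,
    ).sample

# Same forward pass with xFormers enabled; before the padding fix this is where
# the bfloat16 / non-multiple-of-8 mask error surfaces.
unet.enable_xformers_memory_efficient_attention()
with torch.no_grad():
    out = unet(
        sample, timestep, encoder_hidden_states,
        encoder_attention_mask=encoder_attention_mask,
    ).sample

# Functional check: both processors should agree within a loose bf16 tolerance.
assert torch.allclose(ref.float(), out.float(), atol=1e-2)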

@SahilCarterr (Contributor, Author) commented

I have updated the test. Can you help me fix this error when I run the test script? @sayakpaul

RuntimeError: expand(CUDABFloat16Type{[16, 1, 1, 278]}, size=[16, 1, 278]): the number of sizes provided (3) must be greater or equal to the number of dimensions in the tensor (4)
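
As a side note, that RuntimeError is plain torch.Tensor.expand semantics: the mask at that point is 4-D ([16, 1, 1, 278] per the message) while expand was given only three sizes, and expand needs at least one size per existing dimension. The snippet below only illustrates the tensor mechanics with an assumed query_tokens value; it is not necessarily the right fix for the processor.

import torch

query_tokens = 350  # illustrative value
mask = torch.zeros(16, 1, 1, 278)

# This reproduces the error: 3 sizes given for a 4-D tensor.
# mask.expand(16, 1, 278)  # RuntimeError

# Either provide one size per existing dimension...
expanded_4d = mask.expand(-1, -1, query_tokens, -1)          # -> [16, 1, 350, 278]

# ...or drop the extra singleton dimension first.
expanded_3d = mask.squeeze(1).expand(-1, query_tokens, -1)   # -> [16, 350, 278]

print(expanded_4d.shape, expanded_3d.shape)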

@SahilCarterr closed this by deleting the head repository on Oct 23, 2024.


Development

Successfully merging this pull request may close these issues.

XFormer fails when passing attention mask while using bfloat and key's sequence length not being a multiple of 8
