Conversation

@akshay-babbar
Contributor

akshay-babbar commented Aug 30, 2025

Problem

Fixes #12116

Short prompts generate corrupted images due to an attention mask dtype conversion bug.

Root Cause

Attention masks are converted from bool to float16/bfloat16, but PyTorch's scaled_dot_product_attention interprets a floating-point mask as an additive bias rather than a boolean keep/drop mask, so the 0/1 values corrupt the attention scores instead of masking padding tokens.
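To make the failure mode concrete, here is a minimal, self-contained sketch (toy tensors, not the pipeline code) of how the mask dtype changes scaled_dot_product_attention's behavior:

```python
import torch
import torch.nn.functional as F

# Toy single-head attention over 4 tokens; the last two are padding.
q = torch.randn(1, 1, 4, 8)
k = torch.randn(1, 1, 4, 8)
v = torch.randn(1, 1, 4, 8)
keep = torch.tensor([True, True, False, False]).view(1, 1, 1, 4)

# Boolean mask: False positions are excluded from attention (the intended behavior).
out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=keep)

# The same mask cast to a float dtype becomes an *additive* bias of 0s and 1s,
# so the padding positions still receive attention weight.
out_float = F.scaled_dot_product_attention(q, k, v, attn_mask=keep.float())

print(torch.allclose(out_bool, out_float))  # typically False: the outputs diverge
```

In the pipeline the cast was to float16/bfloat16 rather than float32, but the effect is the same: the 0/1 values are added to the attention scores instead of masking padding, and short prompts (which are mostly padding) are hit hardest.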

Solution

  • Preserve boolean dtype in _get_t5_prompt_embeds (see the sketch after this list)
  • Remove dtype conversions in _prepare_attention_mask
  • Fix both positive and negative attention masks
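For context, a hedged sketch of the pattern the first bullet describes; the helper name and tokenizer arguments below are illustrative rather than the actual diffusers API, and the point is only that the mask stays torch.bool instead of being cast to the text-encoder dtype:

```python
def tokenize_with_bool_mask(tokenizer, prompt, max_length=512):
    """Illustrative helper: return input ids and a *boolean* attention mask."""
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )
    # Keep the mask boolean; do NOT cast it to float16/bfloat16 here.
    return text_inputs.input_ids, text_inputs.attention_mask.bool()
```

Any Hugging Face tokenizer (e.g. a T5 tokenizer) could be passed as `tokenizer`; the attention mask it returns is an integer 0/1 tensor, and `.bool()` keeps it in the form that scaled_dot_product_attention interprets as keep/drop.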

Testing

✅ Added @slow unit tests for dtype preservation
✅ Verified fix with prompts: "man", "cat"
✅ All tests pass locally

Please review when you have a chance. Thank you for your time and consideration!

- Convert attention masks to bool and prevent dtype corruption
- Fix both positive and negative mask handling in _get_t5_prompt_embeds
- Remove float conversion in _prepare_attention_mask method

Fixes huggingface#12116
akshay-babbar changed the title to "Fix #12116: preserve boolean dtype for attention masks in ChromaPipeline" on Aug 30, 2025
@akshay-babbar
Contributor Author

akshay-babbar commented Sep 1, 2025

Hello @DN6, @yiyixuxu, could you please review this PR and share feedback?
Thanks!

@akshay-babbar
Contributor Author

Hello @DN6, just checking in to see if you’ve had a chance to look at the above PR. If you’re not the right person or are busy, would you mind pointing me to someone who could review it?

Thanks!

yiyixuxu requested a review from DN6 on September 10, 2025
@yiyixuxu
Collaborator

thanks @akshay-babbar
can you show outputs before/after?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@akshay-babbar
Contributor Author

akshay-babbar commented Sep 13, 2025

hello @yiyixuxu @DN6

Thanks for the response! I'm new to diffusers, so I'm still learning best practices from the docs and codebase. Please let me know if there are any issues with my changes.

I used these 3 prompts: [man, king, doctor]

Negative prompt: blurry, low quality, naked, NSFW, nude, deformed

Below are the results!

Please review and let me know your feedback and any next steps.

Thanks!

Before

Man: [image: Before_v2_change_1_man]

Doctor: [image: Before_v2_change_2_doctor]

King: [image: Before_v2_change_4_king]

After Code Changes

Man: [image: After_v2_change_1_man]

Doctor: [image: After_v2_change_2_doctor]

King: [image: After_v2_change_4_king]

@akshay-babbar
Contributor Author

Hi @DN6, I completely understand you must be busy, but I’d really appreciate any thoughts/feedback you might have on the PR and the results I’ve shown above whenever you get the chance.
Thanks!
CC: @yiyixuxu

Collaborator

DN6 left a comment


Thanks @akshay-babbar. Minor changes requested.

)
text_input_ids = text_inputs.input_ids
attention_mask = text_inputs.attention_mask.clone()
attention_mask = attention_mask.bool()
Collaborator


Don't think we need the type conversion here. We can just convert the final mask to bool.
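As a tiny self-contained illustration of the suggestion (a toy tensor, not the pipeline code): keep intermediate masks in the tokenizer's integer dtype and convert to bool only once, on the final mask.

```python
import torch

tokenizer_mask = torch.tensor([[1, 1, 1, 0, 0]])  # what a tokenizer typically returns
intermediate = tokenizer_mask.clone()             # no .bool() needed at this stage
final_attention_mask = intermediate.bool()        # single conversion on the final mask
print(final_attention_mask)                       # tensor([[ True,  True,  True, False, False]])
```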

assert (output_height, output_width) == (expected_height, expected_width)


class ChromaPipelineAttentionMaskTests(unittest.TestCase):
Collaborator


Dedicated tests aren't needed here. The existing tests should catch changes in numerical output if they are significant.

@akshay-babbar
Contributor Author

Hello @DN6, thanks for the review! I have made the changes; please let me know your feedback and next steps.

Thanks!

akshay-babbar requested a review from DN6 on September 26, 2025
DN6 merged commit 0a15111 into huggingface:main on Sep 29, 2025
10 checks passed
@dxqb

dxqb commented Sep 29, 2025

Sorry for not seeing this earlier, I was just notified by the commit.

Are you sure that passing the attention mask as boolean...
https://github.com/akshay-babbar/diffusers/blob/58557d44893802acb2f68b1334bfee9c0e726bea/src/diffusers/pipelines/chroma/pipeline_chroma.py#L249
...to the text encoder is okay?

It seems to work I guess, but it's documented to require a FloatTensor:
https://github.com/huggingface/transformers/blob/cd74917ffc3e8f84e4a886052c5ab32b7ac623cc/src/transformers/models/t5/modeling_t5.py#L1899

Text encoder: float values between 0 and 1.
Attention processor: either bool, or float values between -inf and +inf.
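For clarity, a hedged sketch of the distinction being drawn here, using a stock T5 encoder from transformers; the checkpoint name and toy shapes are only for illustration:

```python
import torch.nn.functional as F
from transformers import AutoTokenizer, T5EncoderModel

tokenizer = AutoTokenizer.from_pretrained("google/t5-v1_1-small")
encoder = T5EncoderModel.from_pretrained("google/t5-v1_1-small")

inputs = tokenizer("a short prompt", padding="max_length", max_length=16, return_tensors="pt")

# Text encoder: transformers documents attention_mask as 0/1 values (a FloatTensor),
# so the tokenizer's integer mask (or a float cast of it) is the documented input.
embeds = encoder(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask).last_hidden_state

# Attention processor: scaled_dot_product_attention takes either a boolean
# keep/drop mask, or a float mask that is *added* to the scores (-inf to drop).
bool_mask = inputs.attention_mask.bool()[:, None, None, :]  # broadcastable keep/drop mask
q = k = v = embeds.unsqueeze(1)                             # toy single-head layout
out = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask)
```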
