Fix #12116: preserve boolean dtype for attention masks in ChromaPipeline #12263
Conversation
- Convert attention masks to bool and prevent dtype corruption (see the sketch below)
- Fix both positive and negative mask handling in _get_t5_prompt_embeds
- Remove the float conversion in the _prepare_attention_mask method

Fixes huggingface#12116
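A minimal sketch of the kind of change the summary describes, assuming a tokenizer-style `text_inputs.attention_mask` as in the reviewed diff further down; the helper name and the other variable names are illustrative, not the actual ChromaPipeline code.

```python
import torch

def build_text_attention_mask(text_inputs, device):
    # Tokenizers return the attention mask as an integer 0/1 tensor
    # (1 = real token, 0 = padding).
    attention_mask = text_inputs.attention_mask.clone()
    # Keep the mask boolean instead of casting it to the text encoder's float
    # dtype, so downstream attention treats it as keep/ignore rather than as
    # an additive bias.
    return attention_mask.to(device=device, dtype=torch.bool)
```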
Hello @DN6, just checking in to see if you've had a chance to look at the above PR. If you're not the right person or are busy at the moment, would you mind pointing me to someone who could review it? Thanks!
thanks @akshay-babbar
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks @akshay-babbar. Minor changes requested.
```python
)
text_input_ids = text_inputs.input_ids
attention_mask = text_inputs.attention_mask.clone()
attention_mask = attention_mask.bool()
```
Don't think we need the type conversion here. We can just convert the final mask to bool.
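A hedged sketch of what that suggestion could look like: assemble the joint mask in whatever dtype the intermediates use and convert only the final result. The helper name and the way the image-token portion is padded are assumptions for illustration, not the actual _prepare_attention_mask implementation.

```python
import torch

def prepare_attention_mask(text_mask: torch.Tensor, image_seq_len: int) -> torch.Tensor:
    # Image/latent tokens are always attended to, so extend the text mask with ones.
    image_mask = torch.ones(
        text_mask.shape[0], image_seq_len, device=text_mask.device, dtype=text_mask.dtype
    )
    joint_mask = torch.cat([text_mask, image_mask], dim=1)
    # Single conversion at the end: only the final mask becomes boolean.
    return joint_mask.bool()
```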
```python
assert (output_height, output_width) == (expected_height, expected_width)


class ChromaPipelineAttentionMaskTests(unittest.TestCase):
```
Dedicated tests aren't needed here. The existing tests should catch changes in numerical output if they are significant.
Hello @DN6, thanks for the review! I have made the changes; let me know your feedback and next steps. Thanks!
Sorry for not seeing this earlier; I was only just notified by the commit. Are you sure about passing the attention mask as boolean...? It seems to work, I guess, but it's documented to require a FloatTensor (text encoder: floats between 0 and 1).
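For what it's worth, the two forms carry the same information: the documented 0.0/1.0 float mask and a boolean mask round-trip losslessly, which is presumably why the boolean version appears to work. A tiny illustration with made-up values:

```python
import torch

float_mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])  # documented form: floats in [0, 1]
bool_mask = float_mask.bool()                       # what the PR passes instead
assert torch.equal(bool_mask.float(), float_mask)   # round-trip preserves the mask
```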






Problem
Fixes #12116
Short prompts generate corrupted images due to an attention mask dtype conversion bug.
Root Cause
Attention masks were converted from bool to float16/bfloat16, but PyTorch's scaled_dot_product_attention treats floating-point masks as additive biases rather than keep/ignore masks, so a 0/1 float mask leaves padding tokens effectively unmasked.
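A small, self-contained demonstration of this failure mode (not the pipeline code itself): a boolean mask is interpreted as keep/ignore, while a float mask is added to the attention scores.

```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 1, 4, 8)
keep = torch.tensor([True, True, False, False])  # last two positions are padding

bool_mask = keep[None, None, None, :]            # True = attend, False = ignore
float_mask = keep[None, None, None, :].float()   # 1.0 / 0.0, added to the scores

out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask)
out_float = F.scaled_dot_product_attention(q, k, v, attn_mask=float_mask)
print(torch.allclose(out_bool, out_float))       # False: the float mask barely masks anything
```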
Solution
- Preserve the boolean dtype for attention masks in _get_t5_prompt_embeds
- Remove the float conversion in _prepare_attention_mask

Testing
✅ Added @slow unit tests for dtype preservation (a sketch of the idea follows below)
✅ Verified fix with prompts: "man", "cat"
✅ All tests pass locally
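For reference, a minimal sketch of what a dtype-preservation check can look like; it exercises plain tensor logic rather than ChromaPipeline itself, whose test API is not shown here.

```python
import unittest

import torch


class AttentionMaskDtypeSketch(unittest.TestCase):
    def test_final_mask_is_boolean(self):
        # Tokenizer-style integer mask: 1 = real token, 0 = padding.
        text_mask = torch.tensor([[1, 1, 0, 0]])
        image_mask = torch.ones(1, 2, dtype=text_mask.dtype)
        joint_mask = torch.cat([text_mask, image_mask], dim=1).bool()
        self.assertEqual(joint_mask.dtype, torch.bool)
        # Padding positions must stay masked out after concatenation.
        self.assertFalse(joint_mask[0, 2].item())


if __name__ == "__main__":
    unittest.main()
```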
Please review when you have a chance. Thank you for your time and consideration!