8 changes: 4 additions & 4 deletions src/diffusers/pipelines/chroma/pipeline_chroma.py
@@ -233,19 +233,20 @@ def _get_t5_prompt_embeds(
)
text_input_ids = text_inputs.input_ids
attention_mask = text_inputs.attention_mask.clone()
+ attention_mask = attention_mask.bool()
Collaborator: Don't think we need the type conversion here. We can just convert the final mask to bool.
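For context, a minimal runnable sketch of what that simplification could look like, with toy tensor values standing in for the tokenizer output (hypothetical, not the committed code):

import torch

# Toy tokenizer output: batch of 2, max length 6; 1 = real token, 0 = padding.
attention_mask = torch.tensor([[1, 1, 1, 0, 0, 0],
                               [1, 1, 1, 1, 1, 0]])
batch_size = attention_mask.size(0)

# sum() works on the integer mask directly, so no early .bool() cast is needed.
seq_lengths = attention_mask.sum(dim=1)  # tensor([3, 5])
mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)

# "<=" (not "<") keeps exactly one padding position per row, matching the
# requirement noted below; the comparison itself already yields torch.bool,
# so the trailing .bool() in the diff is a no-op kept for explicitness.
attention_mask = mask_indices <= seq_lengths.unsqueeze(1)
print(attention_mask)
# tensor([[ True,  True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True,  True]])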


# Chroma requires the attention mask to include one padding token
seq_lengths = attention_mask.sum(dim=1)
mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
- attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).long()
+ attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).bool()

prompt_embeds = self.text_encoder(
text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
)[0]

dtype = self.text_encoder.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
- attention_mask = attention_mask.to(dtype=dtype, device=device)
+ attention_mask = attention_mask.to(device=device)
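The .to() change above is the crux of the fix: calling .to() with a dtype silently converts a bool mask to the encoder's float dtype, whereas passing only a device moves the tensor and preserves its dtype. A toy illustration (values are made up):

import torch

mask = torch.tensor([True, True, False])
print(mask.to(dtype=torch.float16).dtype)  # torch.float16 -- old behavior, mask is no longer bool
print(mask.to(device="cpu").dtype)         # torch.bool   -- new behavior, dtype preserved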

_, seq_len, _ = prompt_embeds.shape

@@ -580,10 +581,9 @@ def _prepare_attention_mask(

# Extend the prompt attention mask to account for image tokens in the final sequence
attention_mask = torch.cat(
- [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device)],
+ [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device, dtype=torch.bool)],
dim=1,
)
- attention_mask = attention_mask.to(dtype)

return attention_mask
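To see the concatenation in isolation: because the image-token padding is created directly as torch.bool, the combined mask keeps a single dtype and the old trailing .to(dtype) cast becomes unnecessary. A small sketch with made-up shapes:

import torch

batch_size, sequence_length = 2, 4  # hypothetical image-token count
prompt_mask = torch.tensor([[True, True, False],
                            [True, True, True]])
# All-True entries mark the image tokens as attendable; dtype stays bool end to end.
attention_mask = torch.cat(
    [prompt_mask, torch.ones(batch_size, sequence_length, device=prompt_mask.device, dtype=torch.bool)],
    dim=1,
)
print(attention_mask.shape, attention_mask.dtype)  # torch.Size([2, 7]) torch.bool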

24 changes: 24 additions & 0 deletions tests/pipelines/chroma/test_pipeline_chroma.py
@@ -5,6 +5,7 @@
from transformers import AutoTokenizer, T5EncoderModel

from diffusers import AutoencoderKL, ChromaPipeline, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler
+ from diffusers.utils.testing_utils import slow

from ...testing_utils import torch_device
from ..test_pipelines_common import FluxIPAdapterTesterMixin, PipelineTesterMixin, check_qkv_fused_layers_exist
@@ -158,3 +159,26 @@ def test_chroma_image_output_shape(self):
image = pipe(**inputs).images[0]
output_height, output_width, _ = image.shape
assert (output_height, output_width) == (expected_height, expected_width)


class ChromaPipelineAttentionMaskTests(unittest.TestCase):
Collaborator: Dedicated tests aren't needed here. The existing tests should catch changes in numerical output if they are significant.

def setUp(self):
self.pipe = ChromaPipeline.from_pretrained(
"lodestones/Chroma1-Base",
torch_dtype=torch.float16,
)

@slow
def test_attention_mask_dtype_is_bool_short_prompt(self):
prompt_embeds, attn_mask = self.pipe._get_t5_prompt_embeds("man")
self.assertEqual(attn_mask.dtype, torch.bool, f"Expected bool, got {attn_mask.dtype}")
self.assertGreater(prompt_embeds.shape[0], 0)
self.assertGreater(prompt_embeds.shape[1], 0)

@slow
def test_attention_mask_dtype_is_bool_long_prompt(self):
long_prompt = "a detailed portrait of a man standing in a garden with flowers and trees"
prompt_embeds, attn_mask = self.pipe._get_t5_prompt_embeds(long_prompt)
self.assertEqual(attn_mask.dtype, torch.bool, f"Expected bool, got {attn_mask.dtype}")
self.assertGreater(prompt_embeds.shape[0], 0)
self.assertGreater(prompt_embeds.shape[1], 0)
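

If these dedicated slow tests are dropped as the review comment suggests, a lightweight dtype check could instead be folded into the existing fast test class. A sketch under that assumption -- get_dummy_components and pipeline_class are assumed from the PipelineTesterMixin-based class imported above, and the tuple return of _get_t5_prompt_embeds mirrors the calls in this diff:

def test_t5_attention_mask_is_bool(self):
    # Hypothetical fast variant: builds the pipeline from the dummy components
    # used elsewhere in this file instead of downloading pretrained weights.
    components = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    pipe.to(torch_device)
    prompt_embeds, attn_mask = pipe._get_t5_prompt_embeds("man")
    self.assertEqual(attn_mask.dtype, torch.bool)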