7 changes: 3 additions & 4 deletions src/diffusers/pipelines/chroma/pipeline_chroma.py
@@ -238,15 +238,15 @@ def _get_t5_prompt_embeds(
         # Chroma requires the attention mask to include one padding token
         seq_lengths = attention_mask.sum(dim=1)
         mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
-        attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).long()
+        attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).bool()

         prompt_embeds = self.text_encoder(
             text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
         )[0]

         dtype = self.text_encoder.dtype
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-        attention_mask = attention_mask.to(dtype=dtype, device=device)
+        attention_mask = attention_mask.to(device=device)

         _, seq_len, _ = prompt_embeds.shape

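As a side note on the hunk above, here is a minimal standalone sketch of the mask construction (the batch contents and window length are hypothetical, not taken from the pipeline). It shows why the `<=` comparison leaves exactly one padding token unmasked past each sequence end, and that the comparison already yields `torch.bool`, so the old `.long()` round trip was unnecessary:

import torch

# Hypothetical batch: prompts with 3 and 5 real tokens in a length-8 window.
attention_mask = torch.tensor([[1, 1, 1, 0, 0, 0, 0, 0],
                               [1, 1, 1, 1, 1, 0, 0, 0]])
batch_size = attention_mask.size(0)

seq_lengths = attention_mask.sum(dim=1)  # tensor([3, 5])
mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
# `<=` (rather than `<`) keeps one padding position past each sequence
# end unmasked; the comparison already returns a torch.bool tensor.
attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).bool()
print(attention_mask.sum(dim=1))  # tensor([4, 6]): real tokens + 1 pad each

Keeping the mask boolean end to end is also why the later `.to(dtype=dtype, device=device)` cast above could be reduced to a plain device move.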
@@ -605,10 +605,9 @@ def _prepare_attention_mask(

         # Extend the prompt attention mask to account for image tokens in the final sequence
         attention_mask = torch.cat(
-            [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device)],
+            [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device, dtype=torch.bool)],
             dim=1,
         )
-        attention_mask = attention_mask.to(dtype)

         return attention_mask
