
Commit 7bb6e4a

fix: preserve boolean dtype for attention masks in ChromaPipeline
- Convert attention masks to bool and prevent dtype corruption
- Fix both positive and negative mask handling in _get_t5_prompt_embeds
- Remove float conversion in _prepare_attention_mask method

Fixes #12116
1 parent 9b721db commit 7bb6e4a
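For context, here is a minimal sketch (not from the commit, tensor values are illustrative) of the dtype problem being fixed: the tokenizer returns an integer attention mask, and casting it to the text encoder's floating dtype silently breaks downstream code that expects a boolean mask.

import torch

mask = torch.tensor([[1, 1, 1, 0, 0]])      # typical tokenizer attention mask (int64)
corrupted = mask.to(torch.bfloat16)          # old behavior: cast to the encoder dtype
preserved = mask.bool()                      # new behavior: keep the mask boolean

tokens = torch.arange(5).unsqueeze(0)        # shape (1, 5), stands in for token ids
print(tokens[preserved])                     # tensor([0, 1, 2]) -- boolean indexing works
# tokens[corrupted] raises IndexError: a float tensor cannot be used as a mask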

1 file changed: +5 −4 lines


src/diffusers/pipelines/chroma/pipeline_chroma.py

@@ -233,19 +233,20 @@ def _get_t5_prompt_embeds(
         )
         text_input_ids = text_inputs.input_ids
         attention_mask = text_inputs.attention_mask.clone()
+        attention_mask = attention_mask.bool()  # keep the mask boolean instead of int64

         # Chroma requires the attention mask to include one padding token
         seq_lengths = attention_mask.sum(dim=1)
         mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
-        attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).long()
+        attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).bool()

         prompt_embeds = self.text_encoder(
             text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
         )[0]

         dtype = self.text_encoder.dtype
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-        attention_mask = attention_mask.to(dtype=dtype, device=device)
+        attention_mask = attention_mask.to(device=device)

         _, seq_len, _ = prompt_embeds.shape
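The `<=` comparison in this hunk is what implements the "one padding token" requirement noted in the code comment: indices 0 through seq_len inclusive stay True, so the first pad position is kept alongside the real tokens. A small sketch with made-up values:

import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0]]).bool()   # 3 real tokens, 2 pads
batch_size = attention_mask.size(0)

seq_lengths = attention_mask.sum(dim=1)                    # tensor([3])
mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)

# `<=` (not `<`) keeps indices 0..3: the 3 real tokens plus the first pad token
extended = (mask_indices <= seq_lengths.unsqueeze(1)).bool()
print(extended)                                            # tensor([[True, True, True, True, False]])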
@@ -580,10 +581,10 @@ def _prepare_attention_mask(

         # Extend the prompt attention mask to account for image tokens in the final sequence
         attention_mask = torch.cat(
-            [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device)],
+            [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device, dtype=torch.bool)],
             dim=1,
         )
-        attention_mask = attention_mask.to(dtype)
+        # attention_mask = attention_mask.to(dtype)

         return attention_mask
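Passing dtype=torch.bool to torch.ones keeps the concatenated mask boolean, which is why the trailing .to(dtype) cast can be dropped. A quick sketch with illustrative shapes (batch_size and sequence_length values are assumed):

import torch

batch_size, sequence_length = 1, 4                         # 4 image tokens, illustrative
attention_mask = torch.tensor([[True, True, True, True, False]])

attention_mask = torch.cat(
    [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device, dtype=torch.bool)],
    dim=1,
)
print(attention_mask.dtype, attention_mask.shape)          # torch.bool torch.Size([1, 9])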
