
Commit 0a15111

akshay-babbar and DN6 authored
Fix huggingface#12116: preserve boolean dtype for attention masks in ChromaPipeline (huggingface#12263)
* fix: preserve boolean dtype for attention masks in ChromaPipeline
  - Convert attention masks to bool and prevent dtype corruption
  - Fix both positive and negative mask handling in _get_t5_prompt_embeds
  - Remove float conversion in _prepare_attention_mask method
  Fixes huggingface#12116
* test: add ChromaPipeline attention mask dtype tests
* test: add slow ChromaPipeline attention mask tests
* chore: removed comments
* refactor: removing redundant type conversion
* Remove dedicated dtype tests as per feedback

---------

Co-authored-by: Dhruv Nair <[email protected]>
1 parent 19085ac commit 0a15111
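
A note on why the dtype matters (not part of the commit): PyTorch comparison ops produce torch.bool masks, and bool-only operations such as `~` and masked_fill reject float masks, so casting the mask to the encoder's float dtype corrupts it for downstream consumers. A minimal sketch with illustrative values:

import torch

mask = torch.tensor([True, True, False])

# The old behavior cast the mask to the text encoder's float dtype,
# silently turning True/False into 1.0/0.0 ...
float_mask = mask.to(torch.float16)

# ... after which bool-only mask operations break:
scores = torch.zeros(3)
print(scores.masked_fill(~mask, float("-inf")))    # tensor([0., 0., -inf]) -- bool mask works
# scores.masked_fill(~float_mask, float("-inf"))   # RuntimeError: `~` is not supported on float tensors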

File tree

1 file changed: +3 −4 lines changed

src/diffusers/pipelines/chroma/pipeline_chroma.py

Lines changed: 3 additions & 4 deletions
@@ -238,15 +238,15 @@ def _get_t5_prompt_embeds(
         # Chroma requires the attention mask to include one padding token
         seq_lengths = attention_mask.sum(dim=1)
         mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
-        attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).long()
+        attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).bool()

         prompt_embeds = self.text_encoder(
             text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
         )[0]

         dtype = self.text_encoder.dtype
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-        attention_mask = attention_mask.to(dtype=dtype, device=device)
+        attention_mask = attention_mask.to(device=device)

         _, seq_len, _ = prompt_embeds.shape

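For context, here is a standalone sketch of the hunk above, with made-up sample values. The `<=` comparison (rather than `<`) leaves exactly one padding position unmasked, matching the comment in the code, and it already yields torch.bool, so the trailing .bool() is a defensive no-op:

import torch

# Two prompts padded to length 6 (sample values, not from the pipeline)
attention_mask = torch.tensor([[1, 1, 1, 0, 0, 0],
                               [1, 1, 1, 1, 1, 0]])
batch_size = attention_mask.size(0)

seq_lengths = attention_mask.sum(dim=1)  # tensor([3, 5])
mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).bool()

print(attention_mask.dtype)   # torch.bool
print(attention_mask.long())  # tensor([[1, 1, 1, 1, 0, 0],
                              #         [1, 1, 1, 1, 1, 1]]) -- one pad token kept per row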
@@ -605,10 +605,9 @@ def _prepare_attention_mask(

         # Extend the prompt attention mask to account for image tokens in the final sequence
         attention_mask = torch.cat(
-            [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device)],
+            [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device, dtype=torch.bool)],
             dim=1,
         )
-        attention_mask = attention_mask.to(dtype)

         return attention_mask
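
Likewise for the second hunk, a sketch with illustrative shapes: giving the appended ones tensor dtype=torch.bool keeps the concatenated mask boolean end to end (torch.ones defaults to float32), which is what makes the removed `.to(dtype)` cast unnecessary:

import torch

batch_size, text_len, image_len = 2, 4, 3  # illustrative shapes, not from the pipeline
attention_mask = torch.ones(batch_size, text_len, dtype=torch.bool)

# Extend the prompt mask with all-True positions for the image tokens
extended = torch.cat(
    [attention_mask, torch.ones(batch_size, image_len, device=attention_mask.device, dtype=torch.bool)],
    dim=1,
)
print(extended.dtype, extended.shape)  # torch.bool torch.Size([2, 7])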
