diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma.py b/src/diffusers/pipelines/chroma/pipeline_chroma.py
index 19ea7729c9d9..5482035b3afb 100644
--- a/src/diffusers/pipelines/chroma/pipeline_chroma.py
+++ b/src/diffusers/pipelines/chroma/pipeline_chroma.py
@@ -238,7 +238,7 @@ def _get_t5_prompt_embeds(
         # Chroma requires the attention mask to include one padding token
         seq_lengths = attention_mask.sum(dim=1)
         mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
-        attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).long()
+        attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).bool()
 
         prompt_embeds = self.text_encoder(
             text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
@@ -246,7 +246,7 @@ def _get_t5_prompt_embeds(
 
         dtype = self.text_encoder.dtype
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-        attention_mask = attention_mask.to(dtype=dtype, device=device)
+        attention_mask = attention_mask.to(device=device)
 
         _, seq_len, _ = prompt_embeds.shape
 
@@ -605,10 +605,9 @@ def _prepare_attention_mask(
 
         # Extend the prompt attention mask to account for image tokens in the final sequence
         attention_mask = torch.cat(
-            [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device)],
+            [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device, dtype=torch.bool)],
             dim=1,
         )
-        attention_mask = attention_mask.to(dtype)
 
         return attention_mask
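
The net effect of the diff is that the attention mask stays `torch.bool` end to end: previously it was built with `.long()`, cast to the text encoder's float dtype, and then concatenated with a default-float `torch.ones`, so downstream consumers received a float mask. Below is a minimal standalone sketch of the same mask construction, with made-up tensors standing in for tokenizer output; one reason the dtype matters is that `torch.nn.functional.scaled_dot_product_attention` interprets a bool mask as keep/drop (True = attend) but a float mask as an additive bias on the attention scores.

```python
import torch

batch_size, max_len = 2, 8
# Hypothetical tokenizer output: 1 where there is a real token, 0 for padding.
attention_mask = torch.tensor([
    [1, 1, 1, 0, 0, 0, 0, 0],
    [1, 1, 1, 1, 1, 0, 0, 0],
])

# Chroma keeps one padding token visible, hence `<=` rather than `<`.
seq_lengths = attention_mask.sum(dim=1)
mask_indices = torch.arange(max_len).unsqueeze(0).expand(batch_size, -1)
mask = mask_indices <= seq_lengths.unsqueeze(1)  # torch.bool, as in the diff

# Extending the mask for image tokens with an explicit dtype keeps the
# concatenated result bool instead of silently promoting it to float.
image_tokens = 4
full_mask = torch.cat(
    [mask, torch.ones(batch_size, image_tokens, dtype=torch.bool)], dim=1
)
print(full_mask.dtype)  # torch.bool
```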