
Commit 8f405ed

Merge branch 'main' into cache-non-lora-outputs
2 parents 1569fca + 0a15111

2 files changed: +3 -9 lines changed

src/diffusers/pipelines/chroma/pipeline_chroma.py

Lines changed: 3 additions & 4 deletions

@@ -238,15 +238,15 @@ def _get_t5_prompt_embeds(
         # Chroma requires the attention mask to include one padding token
         seq_lengths = attention_mask.sum(dim=1)
         mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
-        attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).long()
+        attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).bool()

         prompt_embeds = self.text_encoder(
             text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
         )[0]

         dtype = self.text_encoder.dtype
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-        attention_mask = attention_mask.to(dtype=dtype, device=device)
+        attention_mask = attention_mask.to(device=device)

         _, seq_len, _ = prompt_embeds.shape

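This hunk keeps the attention mask boolean instead of casting it to integer and, later, to the text encoder's dtype: the `<=` comparison already yields a bool tensor. A minimal standalone sketch of the comparison logic follows; the prompt lengths and sequence size are invented for illustration and are not taken from the commit:

```python
import torch

# Hypothetical toy inputs: two prompts with 3 and 5 real tokens,
# padded to a sequence length of 8.
attention_mask = torch.tensor(
    [[1, 1, 1, 0, 0, 0, 0, 0],
     [1, 1, 1, 1, 1, 0, 0, 0]]
)
batch_size = attention_mask.size(0)

seq_lengths = attention_mask.sum(dim=1)
mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
# "<=" (rather than "<") deliberately leaves exactly one padding
# position unmasked per prompt; the result is now kept as bool.
attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).bool()

print(attention_mask)
# tensor([[ True,  True,  True,  True, False, False, False, False],
#         [ True,  True,  True,  True,  True,  True, False, False]])
```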
@@ -605,10 +605,9 @@ def _prepare_attention_mask(

         # Extend the prompt attention mask to account for image tokens in the final sequence
         attention_mask = torch.cat(
-            [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device)],
+            [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device, dtype=torch.bool)],
             dim=1,
         )
-        attention_mask = attention_mask.to(dtype)

         return attention_mask

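Here the appended image-token mask is created as boolean (`dtype=torch.bool`), which makes the trailing `attention_mask.to(dtype)` cast redundant, so it is removed. A rough sketch of the concatenation with made-up shapes (assumptions for illustration, not values from the commit):

```python
import torch

# Hypothetical shapes: 2 prompts, 8 text tokens, 16 image tokens.
batch_size, text_len, sequence_length = 2, 8, 16
attention_mask = torch.zeros(batch_size, text_len, dtype=torch.bool)
attention_mask[:, :4] = True  # pretend the first 4 text tokens are real

# Image tokens are always attended to, so an all-True block is appended;
# creating it with dtype=torch.bool keeps the whole mask boolean.
attention_mask = torch.cat(
    [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device, dtype=torch.bool)],
    dim=1,
)

print(attention_mask.shape, attention_mask.dtype)
# torch.Size([2, 24]) torch.bool
```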
tests/models/test_modeling_common.py

Lines changed: 0 additions & 5 deletions

@@ -1793,11 +1793,6 @@ def test_group_offloading_with_disk(self, offload_type, record_stream, atol=1e-5
         if not self.model_class._supports_group_offloading:
             pytest.skip("Model does not support group offloading.")

-        if self.model_class.__name__ == "QwenImageTransformer2DModel":
-            pytest.skip(
-                "QwenImageTransformer2DModel doesn't support group offloading with disk. Needs to be investigated."
-            )
-
         def _has_generator_arg(model):
             sig = inspect.signature(model.forward)
             params = sig.parameters
