Merged
Changes from 8 commits
8 changes: 6 additions & 2 deletions src/diffusers/models/attention_processor.py
@@ -2543,7 +2543,9 @@ def __call__(
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)

hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)

hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
@@ -2776,7 +2778,9 @@ def __call__(
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)

hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
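For context, a minimal standalone sketch of how F.scaled_dot_product_attention consumes an attn_mask (the tensor shapes and the key-side-only mask below are illustrative assumptions, not the processor's actual sizes): a boolean mask keeps True positions and drops False ones, and any shape broadcastable to [batch, heads, query_len, key_len] is accepted.

```python
import torch
import torch.nn.functional as F

# Standalone sketch (not the diffusers processor); all shapes are illustrative assumptions.
batch_size, heads, seq_len, head_dim = 2, 4, 8, 16
query = torch.randn(batch_size, heads, seq_len, head_dim)
key = torch.randn(batch_size, heads, seq_len, head_dim)
value = torch.randn(batch_size, heads, seq_len, head_dim)

# Per-token keep mask: True for real tokens, False for padding.
keep = torch.ones(batch_size, seq_len, dtype=torch.bool)
keep[:, -2:] = False  # pretend the last two tokens are padding

# Mask along the key dimension only: padded tokens cannot be attended to.
# Shape [batch, 1, 1, key_len] broadcasts over heads and query positions.
# A boolean attn_mask selects positions; a float mask would instead be added to the logits.
attn_mask = keep[:, None, None, :]

hidden_states = F.scaled_dot_product_attention(
    query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
)
print(hidden_states.shape)  # torch.Size([2, 4, 8, 16])
```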

20 changes: 15 additions & 5 deletions src/diffusers/models/transformers/transformer_chroma.py
@@ -250,15 +250,21 @@ def forward(
hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
residual = hidden_states
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
joint_attention_kwargs = joint_attention_kwargs or {}

if attention_mask is not None:
attention_mask = attention_mask[:, None, None, :] * attention_mask[:, None, :, None]

attn_output = self.attn(
hidden_states=norm_hidden_states,
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask,
**joint_attention_kwargs,
)
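To illustrate the broadcasted product above: a per-token padding mask of shape [batch, seq_len] becomes a pairwise [batch, 1, seq_len, seq_len] mask that is nonzero only where both the query and the key token are unpadded. A toy sketch (the values are assumptions):

```python
import torch

# Toy padding mask: 1 for real tokens, 0 for padding (values are assumptions for illustration).
attention_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])  # [batch=1, seq_len=4]

pairwise_mask = attention_mask[:, None, None, :] * attention_mask[:, None, :, None]
print(pairwise_mask.shape)  # torch.Size([1, 1, 4, 4])
print(pairwise_mask[0, 0])
# tensor([[1., 1., 1., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 0.],
#         [0., 0., 0., 0.]])
```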

@@ -312,6 +318,7 @@ def forward(
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
temb_img, temb_txt = temb[:, :6], temb[:, 6:]
@@ -321,11 +328,15 @@
encoder_hidden_states, emb=temb_txt
)
joint_attention_kwargs = joint_attention_kwargs or {}
if attention_mask is not None:
attention_mask = attention_mask[:, None, None, :] * attention_mask[:, None, :, None]

# Attention.
attention_outputs = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask,
**joint_attention_kwargs,
)

@@ -570,6 +581,7 @@ def forward(
timestep: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
attention_mask: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_block_samples=None,
controlnet_single_block_samples=None,
@@ -659,11 +671,7 @@ def forward(
)
if torch.is_grad_enabled() and self.gradient_checkpointing:
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask
)

else:
@@ -672,6 +680,7 @@ def forward(
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask,
joint_attention_kwargs=joint_attention_kwargs,
)

@@ -704,6 +713,7 @@ def forward(
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask,
joint_attention_kwargs=joint_attention_kwargs,
)

65 changes: 61 additions & 4 deletions src/diffusers/pipelines/chroma/pipeline_chroma.py
@@ -235,14 +235,18 @@ def _get_t5_prompt_embeds(

dtype = self.text_encoder.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
attention_mask = attention_mask.to(dtype=dtype, device=device)

_, seq_len, _ = prompt_embeds.shape

# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

return prompt_embeds
attention_mask = attention_mask.repeat(1, num_images_per_prompt)
attention_mask = attention_mask.view(batch_size * num_images_per_prompt, seq_len)

return prompt_embeds, attention_mask
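A toy sketch of the repeat/view duplication above, assuming batch_size=2, seq_len=3, and num_images_per_prompt=2:

```python
import torch

# Toy mask; values and sizes are assumptions for illustration only.
attention_mask = torch.tensor([[1, 1, 0], [1, 0, 0]])  # [batch_size=2, seq_len=3]
num_images_per_prompt = 2
batch_size, seq_len = attention_mask.shape

# Same mps-friendly duplication as above: each prompt's mask is repeated per generated image.
attention_mask = attention_mask.repeat(1, num_images_per_prompt)
attention_mask = attention_mask.view(batch_size * num_images_per_prompt, seq_len)
print(attention_mask)
# tensor([[1, 1, 0],
#         [1, 1, 0],
#         [1, 0, 0],
#         [1, 0, 0]])
```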

def encode_prompt(
self,
@@ -252,6 +256,8 @@ def encode_prompt(
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
do_classifier_free_guidance: bool = True,
max_sequence_length: int = 512,
lora_scale: Optional[float] = None,
Contributor: We should include negative_prompt_embeds (and prompt_attention_mask, negative_prompt_attention_mask) in the docs. I missed it in the pipeline PR.
@@ -293,7 +299,7 @@ def encode_prompt(
batch_size = prompt_embeds.shape[0]

if prompt_embeds is None:
prompt_embeds = self._get_t5_prompt_embeds(
prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
prompt=prompt,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
@@ -323,20 +329,28 @@ def encode_prompt(
" the batch size of `prompt`."
)

negative_prompt_embeds = self._get_t5_prompt_embeds(
negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
)

negative_text_ids = torch.zeros(negative_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

if self.text_encoder is not None:
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)

return prompt_embeds, text_ids, negative_prompt_embeds, negative_text_ids
return (
prompt_embeds,
text_ids,
prompt_attention_mask,
negative_prompt_embeds,
negative_text_ids,
negative_prompt_attention_mask,
)

# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
def encode_image(self, image, device, num_images_per_prompt):
@@ -534,6 +548,31 @@ def prepare_latents(

return latents, latent_image_ids

def _prepare_attention_mask(
self, batch_size, sequence_length, dtype, prompt_attention_mask=None, negative_prompt_attention_mask=None
):
attention_mask = None
if prompt_attention_mask is not None:
# Extend the prompt attention mask to account for image tokens in the final sequence
attention_mask = torch.cat(
[prompt_attention_mask, torch.ones(batch_size, sequence_length, device=prompt_attention_mask.device)],
dim=1,
)
attention_mask = attention_mask.to(dtype)

negative_attention_mask = None
if negative_prompt_attention_mask is not None:
negative_attention_mask = torch.cat(
[
negative_prompt_attention_mask,
torch.ones(batch_size, sequence_length, device=negative_prompt_attention_mask.device),
],
dim=1,
)
negative_attention_mask = negative_attention_mask.to(dtype)

return attention_mask, negative_attention_mask
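A standalone sketch of what this helper returns: the text padding mask is extended with ones so that every image token stays attendable (the sizes are assumptions):

```python
import torch

# Toy sizes; all values are assumptions for illustration.
batch_size, text_len, image_seq_len = 1, 4, 6
prompt_attention_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])  # last text token is padding

# Concatenate the text mask with ones for the image tokens, as in the helper above.
attention_mask = torch.cat(
    [prompt_attention_mask, torch.ones(batch_size, image_seq_len)], dim=1
)
print(attention_mask.shape)  # torch.Size([1, 10]) -> one entry per text + image token
print(attention_mask[0])
# tensor([1., 1., 1., 0., 1., 1., 1., 1., 1., 1.])
```

Downstream, the transformer blocks expand this combined 1-D mask into the pairwise mask shown earlier before handing it to the attention processors.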

@property
def guidance_scale(self):
return self._guidance_scale
@@ -578,6 +617,8 @@ def __call__(
negative_ip_adapter_image: Optional[PipelineImageInput] = None,
negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -704,13 +745,17 @@ def __call__(
(
prompt_embeds,
text_ids,
prompt_attention_mask,
negative_prompt_embeds,
negative_text_ids,
negative_prompt_attention_mask,
) = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
do_classifier_free_guidance=self.do_classifier_free_guidance,
device=device,
num_images_per_prompt=num_images_per_prompt,
@@ -730,6 +775,7 @@ def __call__(
generator,
latents,
)

# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
image_seq_len = latents.shape[1]
@@ -740,6 +786,15 @@
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.15),
)

attention_mask, negative_attention_mask = self._prepare_attention_mask(
batch_size=latents.shape[0],
sequence_length=image_seq_len,
dtype=latents.dtype,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
)

timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
@@ -801,6 +856,7 @@ def __call__(
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
attention_mask=attention_mask,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
@@ -814,6 +870,7 @@ def __call__(
encoder_hidden_states=negative_prompt_embeds,
txt_ids=negative_text_ids,
img_ids=latent_image_ids,
attention_mask=negative_attention_mask,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]