268 changes: 237 additions & 31 deletions src/diffusers/models/transformers/transformer_cogview4.py
@@ -73,8 +73,9 @@ def __init__(self, embedding_dim: int, dim: int) -> None:
def forward(
self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
norm_hidden_states = self.norm(hidden_states)
norm_encoder_hidden_states = self.norm_context(encoder_hidden_states)
dtype = hidden_states.dtype
norm_hidden_states = self.norm(hidden_states).to(dtype=dtype)
norm_encoder_hidden_states = self.norm_context(encoder_hidden_states).to(dtype=dtype)

emb = self.linear(temb)
(
@@ -124,62 +125,263 @@ def __call__(
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
attention_mask: Optional[Dict[str, torch.Tensor]] = None,
image_rotary_emb: Optional[Union[tuple[torch.Tensor, torch.Tensor], list[tuple[torch.Tensor, torch.Tensor]]]] = None,
) -> torch.Tensor:
"""
Args:
attn (`Attention`):
The attention module.
hidden_states (`torch.Tensor`):
The input hidden states.
encoder_hidden_states (`torch.Tensor`):
The encoder hidden states for cross-attention.
attention_mask (`Dict[str, torch.Tensor]`, *optional*):
Dictionary containing mask configurations:
- `batch_flag` (`torch.Tensor`, *optional*):
Values from 0 to n-1 indicating which samples belong to the same packed batch.
Samples sharing a batch_flag value are packed together.
Example: [0, 1, 1, 2, 2] means sample 0 forms packed batch 0, samples 1-2 form packed batch 1, and samples 3-4 form packed batch 2.
If None, no packing is used.
- `text_embedding_attn_mask` (`torch.Tensor`, *optional*):
Mask of shape (batch_size, text_seq_length) for text tokens, where 0 marks a pad token and 1 a real token.
If None, full attention is used over all text tokens.
- `latent_embedding_attn_mask` (`torch.Tensor`, *optional*):
Mask of shape (batch_size, num_latent_tokens) for latent tokens, where 0 marks a pad token and 1 a real token.
If None, full attention is used over all latent tokens.
image_rotary_emb (`tuple` of `torch.Tensor` or `list` of such tuples, *optional*):
The rotary embeddings (cos, sin) for the image tokens. When batch packing is used, pass a list with
one entry per original sample.

Returns:
`Tuple[torch.Tensor, torch.Tensor]`: The processed image-stream and text-stream hidden states, in that order.
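
Example (an illustrative sketch only; the `processor`/`attn` names and all shapes below are
made up for demonstration):
    # Pack prompts 1 and 2 together while prompt 0 stays on its own.
    attention_mask = {
        "batch_flag": torch.tensor([0, 1, 1]),
        "text_embedding_attn_mask": torch.ones((3, 16), dtype=torch.int32),
        "latent_embedding_attn_mask": torch.ones((3, 1024), dtype=torch.int32),
    }
    hidden_states, encoder_hidden_states = processor(
        attn, hidden_states, encoder_hidden_states, attention_mask=attention_mask
    )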
"""

# Get dimensions and device info
batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
batch_size, image_seq_length, embed_dim = hidden_states.shape
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
dtype = encoder_hidden_states.dtype
device = encoder_hidden_states.device
latent_hidden_states = hidden_states
# Combine text and image streams for joint processing
mixed_hidden_states = torch.cat([encoder_hidden_states, latent_hidden_states], dim=1)

# Initialize mask variables
text_attention_mask, latent_attention_mask, batch_flag = None, None, None

# 1. Construct the attention mask and, if requested, pack the input sequences
if attention_mask is not None:
# Extract mask components from the dictionary
text_attention_mask = attention_mask.get("text_embedding_attn_mask", None)
latent_attention_mask = attention_mask.get("latent_embedding_attn_mask", None)
batch_flag = attention_mask.get("batch_flag", None)

# Create default masks if not provided
if text_attention_mask is None:
text_attention_mask = torch.ones((batch_size, text_seq_length), dtype=torch.int32, device=device)
if latent_attention_mask is None:
latent_attention_mask = torch.ones((batch_size, image_seq_length), dtype=torch.int32, device=device)

# Validate mask shapes and types
assert text_attention_mask.dim() == 2, "text_attention_mask must have shape (batch_size, text_seq_length)"
assert text_attention_mask.dtype == torch.int32, "text_attention_mask must have dtype torch.int32"
assert latent_attention_mask.dim() == 2, "latent_attention_mask must have shape (batch_size, num_latent_tokens)"
assert latent_attention_mask.dtype == torch.int32, "latent_attention_mask must have dtype torch.int32"

# Create combined mask for text and image tokens
mixed_attention_mask = torch.ones(
(batch_size, text_seq_length + image_seq_length), dtype=torch.int32, device=device
)
mixed_attention_mask[:, :text_seq_length] = text_attention_mask
mixed_attention_mask[:, text_seq_length:] = latent_attention_mask

# Convert mask to attention matrix format (where 1 means attend, 0 means don't attend)
mixed_attention_mask_input = mixed_attention_mask.unsqueeze(2).to(dtype=dtype)
attention_mask_matrix = mixed_attention_mask_input @ mixed_attention_mask_input.transpose(1, 2)
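# An illustrative sketch (values made up): a combined per-token mask [1, 1, 0] turns into
#   [[1, 1, 0],
#    [1, 1, 0],
#    [0, 0, 0]]
# so padded positions neither attend nor are attended to.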

# Handle batch packing if enabled
if batch_flag is not None:
assert batch_flag.dim() == 1
# Determine packed batch size based on batch_flag
packing_batch_size = torch.max(batch_flag).item() + 1

# Calculate actual sequence lengths for each sample based on masks
text_seq_length = torch.sum(text_attention_mask, dim=1)
latent_seq_length = torch.sum(latent_attention_mask, dim=1)
mixed_seq_length = text_seq_length + latent_seq_length

# Calculate packed sequence lengths for each packed batch
text_seq_length_packed = [
torch.sum(text_attention_mask[batch_flag == batch_idx]).item()
for batch_idx in range(packing_batch_size)
]
latent_seq_length_packed = [
torch.sum(latent_attention_mask[batch_flag == batch_idx]).item()
for batch_idx in range(packing_batch_size)
]
mixed_seq_length_packed = [
torch.sum(mixed_attention_mask[batch_flag == batch_idx]).item()
for batch_idx in range(packing_batch_size)
]

assert len(mixed_seq_length_packed) == packing_batch_size
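# Illustrative sketch (lengths made up): with batch_flag [0, 1, 1] and per-sample mixed
# lengths [6, 8, 5], mixed_seq_length_packed becomes [6, 13].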

# Pack sequences by removing padding tokens
mixed_attention_mask_flatten = mixed_attention_mask.flatten(0, 1)
mixed_hidden_states_flatten = mixed_hidden_states.flatten(0, 1)
mixed_hidden_states_unpad = mixed_hidden_states_flatten[mixed_attention_mask_flatten == 1]
assert torch.sum(mixed_seq_length) == mixed_hidden_states_unpad.shape[0]

# Split the unpadded sequence into packed batches
mixed_hidden_states_packed = torch.split(mixed_hidden_states_unpad, mixed_seq_length_packed)

# Re-pad to create packed batches with right-side padding
mixed_hidden_states_packed_padded = torch.nn.utils.rnn.pad_sequence(
mixed_hidden_states_packed,
batch_first=True,
padding_value=0.0,
padding_side="right",
)
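# Illustrative sketch (lengths made up): packed lengths [5, 3] yield a tensor of shape
# (2, 5, embed_dim), with the second packed batch right-padded by two rows of zeros.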

# 1. QKV projections
# Create attention mask for packed batches
packed_seq_len = mixed_hidden_states_packed_padded.shape[1]
attention_mask_matrix = torch.zeros(
(packing_batch_size, packed_seq_len, packed_seq_len),
dtype=dtype,
device=device,
)

# Fill attention mask with block diagonal matrices
# This ensures that tokens can only attend to other tokens within the same original sample
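# Illustrative sketch (lengths made up): packing two samples of mixed lengths 3 and 2
# produces the block-diagonal mask
#   [[1, 1, 1, 0, 0],
#    [1, 1, 1, 0, 0],
#    [1, 1, 1, 0, 0],
#    [0, 0, 0, 1, 1],
#    [0, 0, 0, 1, 1]]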
for idx, mask in enumerate(attention_mask_matrix):
seq_lengths = mixed_seq_length[batch_flag == idx]
offset = 0
for length in seq_lengths:
# Create a block of 1s for each sample in the packed batch
mask[offset : offset + length, offset : offset + length] = 1
offset += length

attention_mask_matrix = attention_mask_matrix.to(dtype=torch.bool)
attention_mask_matrix = attention_mask_matrix.unsqueeze(1) # Add attention head dim
attention_mask = attention_mask_matrix

# Prepare hidden states for attention computation
if batch_flag is None:
# If no packing, just combine text and image tokens
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
else:
# If packing, use the packed sequence
hidden_states = mixed_hidden_states_packed_padded

# 2. QKV projections - convert hidden states to query, key, value
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)

# Reshape for multi-head attention: [batch, seq_len, heads*dim] -> [batch, heads, seq_len, dim]
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)

# 2. QK normalization
# 3. QK normalization - apply layer norm to queries and keys if configured
if attn.norm_q is not None:
query = attn.norm_q(query)
query = attn.norm_q(query).to(dtype=dtype)
if attn.norm_k is not None:
key = attn.norm_k(key)
key = attn.norm_k(key).to(dtype=dtype)

# 3. Rotational positional embeddings applied to latent stream
# 4. Apply rotary positional embeddings to image tokens only
if image_rotary_emb is not None:
from ..embeddings import apply_rotary_emb

query[:, :, text_seq_length:, :] = apply_rotary_emb(
query[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
)
key[:, :, text_seq_length:, :] = apply_rotary_emb(
key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
)

# 4. Attention
if attention_mask is not None:
text_attention_mask = attention_mask.float().to(query.device)
actual_text_seq_length = text_attention_mask.size(1)
new_attention_mask = torch.zeros((batch_size, text_seq_length + image_seq_length), device=query.device)
new_attention_mask[:, :actual_text_seq_length] = text_attention_mask
new_attention_mask = new_attention_mask.unsqueeze(2)
attention_mask_matrix = new_attention_mask @ new_attention_mask.transpose(1, 2)
attention_mask = (attention_mask_matrix > 0).unsqueeze(1).to(query.dtype)
if batch_flag is None:
# Apply RoPE only to image tokens (after text tokens)
query[:, :, text_seq_length:, :] = apply_rotary_emb(
query[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
)
key[:, :, text_seq_length:, :] = apply_rotary_emb(
key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
)
else:
# For packed batches, need to carefully apply RoPE to appropriate tokens
assert query.shape[0] == packing_batch_size
assert key.shape[0] == packing_batch_size
assert len(image_rotary_emb) == batch_size

rope_idx = 0
for idx in range(packing_batch_size):
offset = 0
# Get text and image sequence lengths for samples in this packed batch
text_seq_length_bi = text_seq_length[batch_flag == idx]
latent_seq_length_bi = latent_seq_length[batch_flag == idx]

# Apply RoPE to each image segment in the packed sequence
for tlen, llen in zip(text_seq_length_bi, latent_seq_length_bi):
mlen = tlen + llen
# Apply RoPE only to image tokens (after text tokens)
query[idx, :, offset + tlen : offset + mlen, :] = apply_rotary_emb(
query[idx, :, offset + tlen : offset + mlen, :],
image_rotary_emb[rope_idx],
use_real_unbind_dim=-2,
)
key[idx, :, offset + tlen : offset + mlen, :] = apply_rotary_emb(
key[idx, :, offset + tlen : offset + mlen, :],
image_rotary_emb[rope_idx],
use_real_unbind_dim=-2,
)
offset += mlen
rope_idx += 1
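# Illustrative sketch (lengths made up): packing (text, latent) lengths (2, 4) and (3, 5)
# applies RoPE to offsets [2:6] of the first segment and [9:14] of the second.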

hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)

# Reshape back: [batch, heads, seq_len, dim] -> [batch, seq_len, heads*dim]
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
hidden_states = hidden_states.type_as(query)

# 5. Output projection
# 5. Output projection - project attention output to model dimension
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)

encoder_hidden_states, hidden_states = hidden_states.split(
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
)
# Split the output back into text and image streams
if batch_flag is None:
# Simple split for non-packed case
encoder_hidden_states, hidden_states = hidden_states.split(
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
)
else:
# For packed case: need to unpack, split text/image, then restore to original shapes
# First, unpad the sequence based on the packed sequence lengths
hidden_states_unpad = torch.nn.utils.rnn.unpad_sequence(
hidden_states,
lengths=torch.tensor(mixed_seq_length_packed),
batch_first=True,
)
# Concatenate all unpadded sequences
hidden_states_flatten = torch.cat(hidden_states_unpad, dim=0)
# Split by original sample sequence lengths
hidden_states_unpack = torch.split(hidden_states_flatten, mixed_seq_length.tolist())
assert len(hidden_states_unpack) == batch_size

# Further split each sample's sequence into text and image parts
hidden_states_unpack = [
torch.split(h, [tlen, llen])
for h, tlen, llen in zip(hidden_states_unpack, text_seq_length, latent_seq_length)
]
# Separate text and image sequences
encoder_hidden_states_unpad = [h[0] for h in hidden_states_unpack]
hidden_states_unpad = [h[1] for h in hidden_states_unpack]

# Update the original tensors with the processed values, respecting the attention masks
for idx in range(batch_size):
# Place unpacked text tokens back in the encoder_hidden_states tensor
encoder_hidden_states[idx][text_attention_mask[idx] == 1] = encoder_hidden_states_unpad[idx]
# Place unpacked image tokens back in the latent_hidden_states tensor
latent_hidden_states[idx][latent_attention_mask[idx] == 1] = hidden_states_unpad[idx]
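# Illustrative sketch (lengths made up): with text lengths [2, 3] and latent lengths [4, 5],
# mixed_seq_length is [6, 8]; the flattened output of length 14 is split into chunks of 6 and 8,
# each chunk into its (text, latent) parts, and the parts scattered back into the padded tensors.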

# Update the output hidden states
hidden_states = latent_hidden_states

return hidden_states, encoder_hidden_states


@@ -402,6 +604,7 @@ def forward(
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[Union[tuple[torch.Tensor, torch.Tensor], list[tuple[torch.Tensor, torch.Tensor]]]] = None,
**kwargs,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
if attention_kwargs is not None:
@@ -422,7 +625,10 @@
batch_size, num_channels, height, width = hidden_states.shape

# 1. RoPE
image_rotary_emb = self.rope(hidden_states)
if image_rotary_emb is None:
image_rotary_emb = self.rope(hidden_states)
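# A hedged usage sketch (call pattern assumed, not confirmed by this file): a caller that packs
# samples could precompute one rotary embedding per original sample, e.g.
#   image_rotary_emb = [self.rope(latent_i) for latent_i in per_sample_latents]
# and pass that list so the attention processor can apply RoPE per packed segment.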

# 2. Patch & Timestep embeddings
p = self.config.patch_size