diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py
index 41c4cbbf97c7..aef368f91ac0 100644
--- a/src/diffusers/models/transformers/transformer_cogview4.py
+++ b/src/diffusers/models/transformers/transformer_cogview4.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -73,8 +73,9 @@ def __init__(self, embedding_dim: int, dim: int) -> None:
     def forward(
         self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        norm_hidden_states = self.norm(hidden_states)
-        norm_encoder_hidden_states = self.norm_context(encoder_hidden_states)
+        dtype = hidden_states.dtype
+        norm_hidden_states = self.norm(hidden_states).to(dtype=dtype)
+        norm_encoder_hidden_states = self.norm_context(encoder_hidden_states).to(dtype=dtype)
 
         emb = self.linear(temb)
         (
@@ -111,8 +112,11 @@ def forward(
 
 class CogView4AttnProcessor:
     """
-    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
+    Processor for implementing scaled dot-product attention for the CogView4 model. It applies a rotary embedding on
     query and key vectors, but does not include spatial normalization.
+
+    The processor supports passing an attention mask for text tokens. The attention mask should have shape (batch_size,
+    text_seq_length) where 1 indicates a non-padded token and 0 indicates a padded token.
     """
 
     def __init__(self):
@@ -125,8 +129,10 @@ def __call__(
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        dtype = encoder_hidden_states.dtype
+
         batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
         batch_size, image_seq_length, embed_dim = hidden_states.shape
         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
@@ -142,9 +148,9 @@ def __call__(
 
         # 2. QK normalization
         if attn.norm_q is not None:
-            query = attn.norm_q(query)
+            query = attn.norm_q(query).to(dtype=dtype)
         if attn.norm_k is not None:
-            key = attn.norm_k(key)
+            key = attn.norm_k(key).to(dtype=dtype)
 
         # 3. Rotational positional embeddings applied to latent stream
         if image_rotary_emb is not None:
@@ -159,13 +165,14 @@ def __call__(
 
         # 4. Attention
         if attention_mask is not None:
-            text_attention_mask = attention_mask.float().to(query.device)
-            actual_text_seq_length = text_attention_mask.size(1)
-            new_attention_mask = torch.zeros((batch_size, text_seq_length + image_seq_length), device=query.device)
-            new_attention_mask[:, :actual_text_seq_length] = text_attention_mask
-            new_attention_mask = new_attention_mask.unsqueeze(2)
-            attention_mask_matrix = new_attention_mask @ new_attention_mask.transpose(1, 2)
-            attention_mask = (attention_mask_matrix > 0).unsqueeze(1).to(query.dtype)
+            text_attn_mask = attention_mask
+            assert text_attn_mask.dim() == 2, "the shape of text_attn_mask should be (batch_size, text_seq_length)"
+            text_attn_mask = text_attn_mask.float().to(query.device)
+            mix_attn_mask = torch.ones((batch_size, text_seq_length + image_seq_length), device=query.device)
+            mix_attn_mask[:, :text_seq_length] = text_attn_mask
+            mix_attn_mask = mix_attn_mask.unsqueeze(2)
+            attn_mask_matrix = mix_attn_mask @ mix_attn_mask.transpose(1, 2)
+            attention_mask = (attn_mask_matrix > 0).unsqueeze(1).to(query.dtype)
 
         hidden_states = F.scaled_dot_product_attention(
             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
@@ -183,9 +190,276 @@ def __call__(
 
         return hidden_states, encoder_hidden_states
 
 
+class CogView4TrainingAttnProcessor:
+    """
+    Training processor for implementing scaled dot-product attention for the CogView4 model. It applies a rotary
+    embedding on query and key vectors, but does not include spatial normalization.
+
+    This processor differs from CogView4AttnProcessor in two important ways:
+    1. It supports attention masking with variable sequence lengths for multi-resolution training.
+    2. It unpacks and repacks sequences for efficient training with variable sequence lengths when batch_flag is
+       provided.
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "CogView4TrainingAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        latent_attn_mask: Optional[torch.Tensor] = None,
+        text_attn_mask: Optional[torch.Tensor] = None,
+        batch_flag: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[
+            Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
+        ] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+            attn (`Attention`):
+                The attention module.
+            hidden_states (`torch.Tensor`):
+                The input hidden states (image/latent stream).
+            encoder_hidden_states (`torch.Tensor`):
+                The encoder (text) hidden states, concatenated with the image stream for joint attention.
+            latent_attn_mask (`torch.Tensor`, *optional*):
+                Mask for latent tokens where 0 indicates a pad token and 1 indicates a non-pad token. If None, full
+                attention is used for all latent tokens. The shape of latent_attn_mask is (batch_size,
+                num_latent_tokens).
+            text_attn_mask (`torch.Tensor`, *optional*):
+                Mask for text tokens where 0 indicates a pad token and 1 indicates a non-pad token. If None, full
+                attention is used for all text tokens.
+            batch_flag (`torch.Tensor`, *optional*):
+                Values from 0 to n-1 indicating which samples belong to the same packed batch. Samples with the same
+                batch_flag are packed together. Example: [0, 1, 1, 2, 2] means sample 0 forms batch 0, samples 1-2
+                form batch 1, and samples 3-4 form batch 2. If None, no packing is used.
+            image_rotary_emb (`Tuple[torch.Tensor, torch.Tensor]` or `List[Tuple[torch.Tensor, torch.Tensor]]`, *optional*):
+                The rotary embedding for the image part of the input.
+        Returns:
+            `Tuple[torch.Tensor, torch.Tensor]`: The processed hidden states for both image and text streams.
+        """
+
+        # Get dimensions and device info
+        batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
+        batch_size, image_seq_length, embed_dim = hidden_states.shape
+        dtype = encoder_hidden_states.dtype
+        device = encoder_hidden_states.device
+        latent_hidden_states = hidden_states
+        # Combine text and image streams for joint processing
+        mixed_hidden_states = torch.cat([encoder_hidden_states, latent_hidden_states], dim=1)
+
+        # 1. Construct the attention mask and, if needed, pack the input
+        # Create default masks if not provided
+        if text_attn_mask is None:
+            text_attn_mask = torch.ones((batch_size, text_seq_length), dtype=torch.int32, device=device)
+        if latent_attn_mask is None:
+            latent_attn_mask = torch.ones((batch_size, image_seq_length), dtype=torch.int32, device=device)
+
+        # Validate mask shapes and types
+        assert text_attn_mask.dim() == 2, "the shape of text_attn_mask should be (batch_size, text_seq_length)"
+        assert text_attn_mask.dtype == torch.int32, "the dtype of text_attn_mask should be torch.int32"
+        assert latent_attn_mask.dim() == 2, "the shape of latent_attn_mask should be (batch_size, num_latent_tokens)"
+        assert latent_attn_mask.dtype == torch.int32, "the dtype of latent_attn_mask should be torch.int32"
+
+        # Create combined mask for text and image tokens
+        mixed_attn_mask = torch.ones(
+            (batch_size, text_seq_length + image_seq_length), dtype=torch.int32, device=device
+        )
+        mixed_attn_mask[:, :text_seq_length] = text_attn_mask
+        mixed_attn_mask[:, text_seq_length:] = latent_attn_mask
+
+        # Convert mask to attention matrix format (where 1 means attend, 0 means don't attend)
+        mixed_attn_mask_input = mixed_attn_mask.unsqueeze(2).to(dtype=dtype)
+        attn_mask_matrix = mixed_attn_mask_input @ mixed_attn_mask_input.transpose(1, 2)
+
+        # Handle batch packing if enabled
+        if batch_flag is not None:
+            assert batch_flag.dim() == 1
+            # Determine the packed batch size based on batch_flag
+            packing_batch_size = torch.max(batch_flag).item() + 1
+
+            # Calculate actual sequence lengths for each sample based on the masks
+            text_seq_length = torch.sum(text_attn_mask, dim=1)
+            latent_seq_length = torch.sum(latent_attn_mask, dim=1)
+            mixed_seq_length = text_seq_length + latent_seq_length
+
+            # Calculate packed sequence lengths for each packed batch
+            mixed_seq_length_packed = [
+                torch.sum(mixed_attn_mask[batch_flag == batch_idx]).item() for batch_idx in range(packing_batch_size)
+            ]
+
+            assert len(mixed_seq_length_packed) == packing_batch_size
+
+            # Pack sequences by removing padding tokens
+            mixed_attn_mask_flatten = mixed_attn_mask.flatten(0, 1)
+            mixed_hidden_states_flatten = mixed_hidden_states.flatten(0, 1)
+            mixed_hidden_states_unpad = mixed_hidden_states_flatten[mixed_attn_mask_flatten == 1]
+            assert torch.sum(mixed_seq_length) == mixed_hidden_states_unpad.shape[0]
+
+            # Split the unpadded sequence into packed batches
+            mixed_hidden_states_packed = torch.split(mixed_hidden_states_unpad, mixed_seq_length_packed)
+
+            # Re-pad to create packed batches with right-side padding
+            mixed_hidden_states_packed_padded = torch.nn.utils.rnn.pad_sequence(
+                mixed_hidden_states_packed,
+                batch_first=True,
+                padding_value=0.0,
+                padding_side="right",
+            )
+
+            # Create attention mask for packed batches
+            packed_seq_len = mixed_hidden_states_packed_padded.shape[1]
+            attn_mask_matrix = torch.zeros(
+                (packing_batch_size, packed_seq_len, packed_seq_len),
+                dtype=dtype,
+                device=device,
+            )
+
+            # Fill attention mask with block diagonal matrices
+            # This ensures that tokens can only attend to other tokens within the same original sample
+            for idx, mask in enumerate(attn_mask_matrix):
+                seq_lengths = mixed_seq_length[batch_flag == idx]
+                offset = 0
+                for length in seq_lengths:
+                    # Create a block of 1s for each sample in the packed batch
+                    mask[offset : offset + length, offset : offset + length] = 1
+                    offset += length
+
+        attn_mask_matrix = attn_mask_matrix.to(dtype=torch.bool)
+        attn_mask_matrix = attn_mask_matrix.unsqueeze(1)  # Add attention head dim
+        attention_mask = attn_mask_matrix
+
+        # Prepare hidden states for attention computation
+        if batch_flag is None:
+            # If no packing, just combine text and image tokens
+            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+        else:
+            # If packing, use the packed sequence
+            hidden_states = mixed_hidden_states_packed_padded
+
+        # 2. QKV projections - convert hidden states to query, key, value
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        # Reshape for multi-head attention: [batch, seq_len, heads*dim] -> [batch, heads, seq_len, dim]
+        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+        # 3. QK normalization - apply layer norm to queries and keys if configured
+        if attn.norm_q is not None:
+            query = attn.norm_q(query).to(dtype=dtype)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key).to(dtype=dtype)
+
+        # 4. Apply rotary positional embeddings to image tokens only
+        if image_rotary_emb is not None:
+            from ..embeddings import apply_rotary_emb
+
+            if batch_flag is None:
+                # Apply RoPE only to image tokens (after text tokens)
+                query[:, :, text_seq_length:, :] = apply_rotary_emb(
+                    query[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
+                )
+                key[:, :, text_seq_length:, :] = apply_rotary_emb(
+                    key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
+                )
+            else:
+                # For packed batches, RoPE must be applied per segment to the appropriate tokens
+                assert query.shape[0] == packing_batch_size
+                assert key.shape[0] == packing_batch_size
+                assert len(image_rotary_emb) == batch_size
+
+                rope_idx = 0
+                for idx in range(packing_batch_size):
+                    offset = 0
+                    # Get text and image sequence lengths for samples in this packed batch
+                    text_seq_length_bi = text_seq_length[batch_flag == idx]
+                    latent_seq_length_bi = latent_seq_length[batch_flag == idx]
+
+                    # Apply RoPE to each image segment in the packed sequence
+                    for tlen, llen in zip(text_seq_length_bi, latent_seq_length_bi):
+                        mlen = tlen + llen
+                        # Apply RoPE only to image tokens (after text tokens)
+                        query[idx, :, offset + tlen : offset + mlen, :] = apply_rotary_emb(
+                            query[idx, :, offset + tlen : offset + mlen, :],
+                            image_rotary_emb[rope_idx],
+                            use_real_unbind_dim=-2,
+                        )
+                        key[idx, :, offset + tlen : offset + mlen, :] = apply_rotary_emb(
+                            key[idx, :, offset + tlen : offset + mlen, :],
+                            image_rotary_emb[rope_idx],
+                            use_real_unbind_dim=-2,
+                        )
+                        offset += mlen
+                        rope_idx += 1
+
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        # Reshape back: [batch, heads, seq_len, dim] -> [batch, seq_len, heads*dim]
+        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+        hidden_states = hidden_states.type_as(query)
+
+        # 5. Output projection - project attention output to model dimension
+        hidden_states = attn.to_out[0](hidden_states)
+        hidden_states = attn.to_out[1](hidden_states)
+
+        # Split the output back into text and image streams
+        if batch_flag is None:
+            # Simple split for the non-packed case
+            encoder_hidden_states, hidden_states = hidden_states.split(
+                [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+            )
+        else:
+            # For the packed case: unpack, split text/image, then restore to the original shapes
+            # First, unpad the sequence based on the packed sequence lengths
+            hidden_states_unpad = torch.nn.utils.rnn.unpad_sequence(
+                hidden_states,
+                lengths=torch.tensor(mixed_seq_length_packed),
+                batch_first=True,
+            )
+            # Concatenate all unpadded sequences
+            hidden_states_flatten = torch.cat(hidden_states_unpad, dim=0)
+            # Split by original sample sequence lengths
+            hidden_states_unpack = torch.split(hidden_states_flatten, mixed_seq_length.tolist())
+            assert len(hidden_states_unpack) == batch_size
+
+            # Further split each sample's sequence into text and image parts
+            hidden_states_unpack = [
+                torch.split(h, [tlen, llen])
+                for h, tlen, llen in zip(hidden_states_unpack, text_seq_length, latent_seq_length)
+            ]
+            # Separate text and image sequences
+            encoder_hidden_states_unpad = [h[0] for h in hidden_states_unpack]
+            hidden_states_unpad = [h[1] for h in hidden_states_unpack]
+
+            # Update the original tensors with the processed values, respecting the attention masks
+            for idx in range(batch_size):
+                # Place unpacked text tokens back in the encoder_hidden_states tensor
+                encoder_hidden_states[idx][text_attn_mask[idx] == 1] = encoder_hidden_states_unpad[idx]
+                # Place unpacked image tokens back in the latent_hidden_states tensor
+                latent_hidden_states[idx][latent_attn_mask[idx] == 1] = hidden_states_unpad[idx]
+
+            # Update the output hidden states
+            hidden_states = latent_hidden_states
+
+        return hidden_states, encoder_hidden_states
+
+
 class CogView4TransformerBlock(nn.Module):
     def __init__(
-        self, dim: int = 2560, num_attention_heads: int = 64, attention_head_dim: int = 40, time_embed_dim: int = 512
+        self,
+        dim: int = 2560,
+        num_attention_heads: int = 64,
+        attention_head_dim: int = 40,
+        time_embed_dim: int = 512,
     ) -> None:
         super().__init__()
 
@@ -213,9 +487,11 @@ def forward(
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        **kwargs,
+        image_rotary_emb: Optional[
+            Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
+        ] = None,
+        attention_mask: Optional[Dict[str, torch.Tensor]] = None,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
         # 1. Timestep conditioning
         (
@@ -232,12 +508,14 @@ def forward(
         ) = self.norm1(hidden_states, encoder_hidden_states, temb)
 
         # 2. Attention
+        if attention_kwargs is None:
+            attention_kwargs = {}
         attn_hidden_states, attn_encoder_hidden_states = self.attn1(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
             attention_mask=attention_mask,
-            **kwargs,
+            **attention_kwargs,
         )
         hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1)
         encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1)
@@ -402,7 +680,9 @@ def forward(
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
         attention_mask: Optional[torch.Tensor] = None,
-        **kwargs,
+        image_rotary_emb: Optional[
+            Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
+        ] = None,
     ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
@@ -422,7 +702,8 @@ def forward(
         batch_size, num_channels, height, width = hidden_states.shape
 
         # 1. RoPE
-        image_rotary_emb = self.rope(hidden_states)
+        if image_rotary_emb is None:
+            image_rotary_emb = self.rope(hidden_states)
 
         # 2. Patch & Timestep embeddings
         p = self.config.patch_size
@@ -438,11 +719,22 @@ def forward(
         for block in self.transformer_blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
-                    block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs
+                    block,
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    image_rotary_emb,
+                    attention_mask,
+                    attention_kwargs,
                 )
             else:
                 hidden_states, encoder_hidden_states = block(
-                    hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    image_rotary_emb,
+                    attention_mask,
+                    attention_kwargs,
                 )
 
         # 4. Output norm & projection
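
The snippets below accompany the patch but are not part of it; all shapes, names, and values are invented for illustration. This first one is a minimal sketch of how CogView4AttnProcessor turns a (batch_size, text_seq_length) text padding mask into a per-sample (seq, seq) attention matrix via an outer product (kept in boolean form here).

import torch

batch_size, text_seq_length, image_seq_length = 2, 16, 64
text_attn_mask = torch.ones(batch_size, text_seq_length)
text_attn_mask[0, 10:] = 0  # first sample only has 10 real text tokens

mix_attn_mask = torch.ones(batch_size, text_seq_length + image_seq_length)
mix_attn_mask[:, :text_seq_length] = text_attn_mask
mix_attn_mask = mix_attn_mask.unsqueeze(2)                        # (B, S, 1)
attn_mask_matrix = mix_attn_mask @ mix_attn_mask.transpose(1, 2)  # (B, S, S) outer product per sample
attention_mask = (attn_mask_matrix > 0).unsqueeze(1)              # (B, 1, S, S), broadcast over heads

# Padded text tokens can neither attend nor be attended to; image tokens see each other freely.
assert not attention_mask[0, 0, 5, 12]   # real text token -> padded text token
assert not attention_mask[0, 0, 12, 20]  # padded text token -> image token
assert attention_mask[0, 0, 20, 30]      # image token -> image token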
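Next, a hypothetical sketch of the block-diagonal mask that CogView4TrainingAttnProcessor builds when batch_flag packs several variable-length samples into one row; the flags and token counts are made up, but the loop mirrors the mask-filling logic in the processor.

import torch

batch_flag = torch.tensor([0, 1, 1])           # sample 0 alone; samples 1 and 2 packed together
mixed_seq_length = torch.tensor([80, 48, 32])  # per-sample text+image token counts after unpadding

packing_batch_size = int(batch_flag.max().item()) + 1
packed_seq_len = max(int(mixed_seq_length[batch_flag == idx].sum()) for idx in range(packing_batch_size))

attn_mask = torch.zeros(packing_batch_size, packed_seq_len, packed_seq_len, dtype=torch.bool)
for idx in range(packing_batch_size):
    offset = 0
    for length in mixed_seq_length[batch_flag == idx].tolist():
        # Tokens of one original sample may only attend within that sample.
        attn_mask[idx, offset : offset + length, offset : offset + length] = True
        offset += length

# Row 1 packs samples of length 48 and 32: two diagonal blocks, no cross-sample attention.
assert attn_mask[1, 10, 20] and not attn_mask[1, 10, 60]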
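A rough sketch of the pack/unpack round trip used alongside batch_flag: drop padding tokens, concatenate samples that share a flag, re-pad each packed row, and later reverse the packing. Toy tensors only; pad_sequence pads on the right by default, which is what the processor requests explicitly.

import torch
from torch.nn.utils.rnn import pad_sequence

embed_dim = 4
batch_flag = torch.tensor([0, 1, 1])
mixed_attn_mask = torch.tensor(
    [
        [1, 1, 1, 1, 1, 0],  # sample 0: 5 real tokens
        [1, 1, 1, 0, 0, 0],  # sample 1: 3 real tokens
        [1, 1, 0, 0, 0, 0],  # sample 2: 2 real tokens
    ],
    dtype=torch.int32,
)
mixed_hidden_states = torch.randn(3, 6, embed_dim)

# 1. Drop padding tokens across the whole (flattened) batch.
unpad = mixed_hidden_states.flatten(0, 1)[mixed_attn_mask.flatten(0, 1) == 1]

# 2. Split the unpadded stream into one chunk per packed row and re-pad on the right.
packing_batch_size = int(batch_flag.max()) + 1
packed_lengths = [int(mixed_attn_mask[batch_flag == i].sum()) for i in range(packing_batch_size)]
packed = pad_sequence(torch.split(unpad, packed_lengths), batch_first=True)  # (2, 5, 4)

# 3. After attention, reverse the packing: strip per-row padding, then split per original sample.
per_sample_lengths = mixed_attn_mask.sum(dim=1).tolist()
restored = torch.split(
    torch.cat([row[:length] for row, length in zip(packed, packed_lengths)]), per_sample_lengths
)
assert [r.shape[0] for r in restored] == per_sample_lengths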
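Finally, a hypothetical usage sketch. The patch does not show how the training processor gets attached, so this assumes a CogView4Transformer2DModel instance named transformer and relies on the standard diffusers Attention.set_processor hook; the attention_kwargs keys mirror the processor's signature.

import torch
from diffusers.models.transformers.transformer_cogview4 import (
    CogView4AttnProcessor,
    CogView4TrainingAttnProcessor,
)

# Switch every block to the training processor for packed / variable-length training.
for block in transformer.transformer_blocks:
    block.attn1.set_processor(CogView4TrainingAttnProcessor())

attention_kwargs = {
    "text_attn_mask": text_attn_mask.to(torch.int32),      # (B, text_seq_len), 1 = real token
    "latent_attn_mask": latent_attn_mask.to(torch.int32),  # (B, num_latent_tokens)
    "batch_flag": torch.tensor([0, 1, 1, 2, 2]),            # samples sharing a value get packed
}
# ... run the training forward pass with attention_kwargs=attention_kwargs ...

# Restore the default processor for inference.
for block in transformer.transformer_blocks:
    block.attn1.set_processor(CogView4AttnProcessor())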