diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 8bba5a82bc2f..51d86fbf15b2 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -1408,7 +1408,7 @@ class JointAttnProcessor2_0:
 
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+            raise ImportError("JointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
 
     def __call__(
         self,
diff --git a/src/diffusers/models/controlnets/controlnet_sd3.py b/src/diffusers/models/controlnets/controlnet_sd3.py
index 1b0b4bae6410..91ce76fe75a9 100644
--- a/src/diffusers/models/controlnets/controlnet_sd3.py
+++ b/src/diffusers/models/controlnets/controlnet_sd3.py
@@ -40,6 +40,48 @@ class SD3ControlNetOutput(BaseOutput):
 
 
 class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+    r"""
+    ControlNet model for [Stable Diffusion 3](https://huggingface.co/papers/2403.03206).
+
+    Parameters:
+        sample_size (`int`, defaults to `128`):
+            The width/height of the latents. This is fixed during training since it is used to learn a number of
+            position embeddings.
+        patch_size (`int`, defaults to `2`):
+            Patch size to turn the input data into small patches.
+        in_channels (`int`, defaults to `16`):
+            The number of latent channels in the input.
+        num_layers (`int`, defaults to `18`):
+            The number of layers of transformer blocks to use.
+        attention_head_dim (`int`, defaults to `64`):
+            The number of channels in each head.
+        num_attention_heads (`int`, defaults to `18`):
+            The number of heads to use for multi-head attention.
+        joint_attention_dim (`int`, defaults to `4096`):
+            The embedding dimension to use for joint text-image attention.
+        caption_projection_dim (`int`, defaults to `1152`):
+            The embedding dimension of caption embeddings.
+        pooled_projection_dim (`int`, defaults to `2048`):
+            The embedding dimension of pooled text projections.
+        out_channels (`int`, defaults to `16`):
+            The number of latent channels in the output.
+        pos_embed_max_size (`int`, defaults to `96`):
+            The maximum latent height/width of positional embeddings.
+        extra_conditioning_channels (`int`, defaults to `0`):
+            The number of extra channels to use for conditioning for patch embedding.
+        dual_attention_layers (`Tuple[int, ...]`, defaults to `()`):
+            The number of dual-stream transformer blocks to use.
+        qk_norm (`str`, *optional*, defaults to `None`):
+            The normalization to use for query and key in the attention layer. If `None`, no normalization is used.
+        pos_embed_type (`str`, defaults to `"sincos"`):
+            The type of positional embedding to use. Choose between `"sincos"` and `None`.
+        use_pos_embed (`bool`, defaults to `True`):
+            Whether to use positional embeddings.
+        force_zeros_for_pooled_projection (`bool`, defaults to `True`):
+            Whether to force zeros for pooled projection embeddings. This is handled in the pipelines by reading the
+            config value of the ControlNet model.
+ """ + _supports_gradient_checkpointing = True @register_to_config @@ -93,7 +135,7 @@ def __init__( JointTransformerBlock( dim=self.inner_dim, num_attention_heads=num_attention_heads, - attention_head_dim=self.config.attention_head_dim, + attention_head_dim=attention_head_dim, context_pre_only=False, qk_norm=qk_norm, use_dual_attention=True if i in dual_attention_layers else False, @@ -108,7 +150,7 @@ def __init__( SD3SingleTransformerBlock( dim=self.inner_dim, num_attention_heads=num_attention_heads, - attention_head_dim=self.config.attention_head_dim, + attention_head_dim=attention_head_dim, ) for _ in range(num_layers) ] @@ -297,28 +339,28 @@ def from_transformer( def forward( self, - hidden_states: torch.FloatTensor, + hidden_states: torch.Tensor, controlnet_cond: torch.Tensor, conditioning_scale: float = 1.0, - encoder_hidden_states: torch.FloatTensor = None, - pooled_projections: torch.FloatTensor = None, + encoder_hidden_states: torch.Tensor = None, + pooled_projections: torch.Tensor = None, timestep: torch.LongTensor = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, - ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: + ) -> Union[torch.Tensor, Transformer2DModelOutput]: """ The [`SD3Transformer2DModel`] forward method. Args: - hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): + hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`): Input `hidden_states`. controlnet_cond (`torch.Tensor`): The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. conditioning_scale (`float`, defaults to `1.0`): The scale factor for ControlNet outputs. - encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): + encoder_hidden_states (`torch.Tensor` of shape `(batch size, sequence_len, embed_dims)`): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. - pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected + pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected from the embeddings of input conditions. timestep ( `torch.LongTensor`): Used to indicate denoising step. @@ -437,11 +479,11 @@ def __init__(self, controlnets): def forward( self, - hidden_states: torch.FloatTensor, + hidden_states: torch.Tensor, controlnet_cond: List[torch.tensor], conditioning_scale: List[float], - pooled_projections: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, + pooled_projections: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, timestep: torch.LongTensor = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index e24a28fc3d7b..e41fad220de6 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -15,7 +15,6 @@ import torch import torch.nn as nn -import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin @@ -39,17 +38,6 @@ @maybe_allow_in_graph class SD3SingleTransformerBlock(nn.Module): - r""" - A Single Transformer block as part of the MMDiT architecture, used in Stable Diffusion 3 ControlNet. 
-
-    Reference: https://arxiv.org/abs/2403.03206
-
-    Parameters:
-        dim (`int`): The number of channels in the input and output.
-        num_attention_heads (`int`): The number of heads to use for multi-head attention.
-        attention_head_dim (`int`): The number of channels in each head.
-    """
-
     def __init__(
         self,
         dim: int,
@@ -59,21 +47,13 @@ def __init__(
         super().__init__()
 
         self.norm1 = AdaLayerNormZero(dim)
-
-        if hasattr(F, "scaled_dot_product_attention"):
-            processor = JointAttnProcessor2_0()
-        else:
-            raise ValueError(
-                "The current PyTorch version does not support the `scaled_dot_product_attention` function."
-            )
-
         self.attn = Attention(
             query_dim=dim,
             dim_head=attention_head_dim,
             heads=num_attention_heads,
             out_dim=dim,
             bias=True,
-            processor=processor,
+            processor=JointAttnProcessor2_0(),
             eps=1e-6,
         )
 
@@ -81,23 +61,17 @@ def __init__(
         self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
 
     def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor):
+        # 1. Attention
         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
-        # Attention.
-        attn_output = self.attn(
-            hidden_states=norm_hidden_states,
-            encoder_hidden_states=None,
-        )
-
-        # Process attention outputs for the `hidden_states`.
+        attn_output = self.attn(hidden_states=norm_hidden_states, encoder_hidden_states=None)
         attn_output = gate_msa.unsqueeze(1) * attn_output
         hidden_states = hidden_states + attn_output
 
+        # 2. Feed Forward
         norm_hidden_states = self.norm2(hidden_states)
-        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
-
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
         ff_output = self.ff(norm_hidden_states)
         ff_output = gate_mlp.unsqueeze(1) * ff_output
-
         hidden_states = hidden_states + ff_output
         return hidden_states
 
@@ -107,26 +81,40 @@ class SD3Transformer2DModel(
     ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin
 ):
     """
-    The Transformer model introduced in Stable Diffusion 3.
-
-    Reference: https://arxiv.org/abs/2403.03206
+    The Transformer model introduced in [Stable Diffusion 3](https://huggingface.co/papers/2403.03206).
 
     Parameters:
-        sample_size (`int`): The width of the latent images. This is fixed during training since
-            it is used to learn a number of position embeddings.
-        patch_size (`int`): Patch size to turn the input data into small patches.
-        in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
-        num_layers (`int`, *optional*, defaults to 18): The number of layers of Transformer blocks to use.
-        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
-        num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
-        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
-        caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
-        pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
-        out_channels (`int`, defaults to 16): Number of output channels.
-
+        sample_size (`int`, defaults to `128`):
+            The width/height of the latents. This is fixed during training since it is used to learn a number of
+            position embeddings.
+        patch_size (`int`, defaults to `2`):
+            Patch size to turn the input data into small patches.
+        in_channels (`int`, defaults to `16`):
+            The number of latent channels in the input.
+        num_layers (`int`, defaults to `18`):
+            The number of layers of transformer blocks to use.
+        attention_head_dim (`int`, defaults to `64`):
+            The number of channels in each head.
+        num_attention_heads (`int`, defaults to `18`):
+            The number of heads to use for multi-head attention.
+        joint_attention_dim (`int`, defaults to `4096`):
+            The embedding dimension to use for joint text-image attention.
+        caption_projection_dim (`int`, defaults to `1152`):
+            The embedding dimension of caption embeddings.
+        pooled_projection_dim (`int`, defaults to `2048`):
+            The embedding dimension of pooled text projections.
+        out_channels (`int`, defaults to `16`):
+            The number of latent channels in the output.
+        pos_embed_max_size (`int`, defaults to `96`):
+            The maximum latent height/width of positional embeddings.
+        dual_attention_layers (`Tuple[int, ...]`, defaults to `()`):
+            The number of dual-stream transformer blocks to use.
+        qk_norm (`str`, *optional*, defaults to `None`):
+            The normalization to use for query and key in the attention layer. If `None`, no normalization is used.
     """
 
     _supports_gradient_checkpointing = True
+    _no_split_modules = ["JointTransformerBlock"]
     _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
 
     @register_to_config
@@ -149,36 +137,33 @@ def __init__(
         qk_norm: Optional[str] = None,
     ):
         super().__init__()
-        default_out_channels = in_channels
-        self.out_channels = out_channels if out_channels is not None else default_out_channels
-        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+        self.out_channels = out_channels if out_channels is not None else in_channels
+        self.inner_dim = num_attention_heads * attention_head_dim
 
         self.pos_embed = PatchEmbed(
-            height=self.config.sample_size,
-            width=self.config.sample_size,
-            patch_size=self.config.patch_size,
-            in_channels=self.config.in_channels,
+            height=sample_size,
+            width=sample_size,
+            patch_size=patch_size,
+            in_channels=in_channels,
             embed_dim=self.inner_dim,
             pos_embed_max_size=pos_embed_max_size,  # hard-code for now.
         )
         self.time_text_embed = CombinedTimestepTextProjEmbeddings(
-            embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
+            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
        )
-        self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim)
+        self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
 
-        # `attention_head_dim` is doubled to account for the mixing.
-        # It needs to crafted when we get the actual checkpoints.
         self.transformer_blocks = nn.ModuleList(
             [
                 JointTransformerBlock(
                     dim=self.inner_dim,
-                    num_attention_heads=self.config.num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
                     context_pre_only=i == num_layers - 1,
                     qk_norm=qk_norm,
                     use_dual_attention=True if i in dual_attention_layers else False,
                 )
-                for i in range(self.config.num_layers)
+                for i in range(num_layers)
             ]
         )
 
@@ -331,24 +316,24 @@ def unfuse_qkv_projections(self):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor = None,
-        pooled_projections: torch.FloatTensor = None,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor = None,
+        pooled_projections: torch.Tensor = None,
         timestep: torch.LongTensor = None,
         block_controlnet_hidden_states: List = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
         skip_layers: Optional[List[int]] = None,
-    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         """
         The [`SD3Transformer2DModel`] forward method.
 
         Args:
-            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+            hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`):
                 Input `hidden_states`.
-            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+            encoder_hidden_states (`torch.Tensor` of shape `(batch size, sequence_len, embed_dims)`):
                 Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
-            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`):
+            pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`):
                 Embeddings projected from the embeddings of input conditions.
             timestep (`torch.LongTensor`):
                 Used to indicate denoising step.