2 changes: 1 addition & 1 deletion src/diffusers/models/attention_processor.py
@@ -1408,7 +1408,7 @@ class JointAttnProcessor2_0:

def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
raise ImportError("JointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

def __call__(
self,
66 changes: 54 additions & 12 deletions src/diffusers/models/controlnets/controlnet_sd3.py
@@ -40,6 +40,48 @@ class SD3ControlNetOutput(BaseOutput):


class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
r"""
ControlNet model for [Stable Diffusion 3](https://huggingface.co/papers/2403.03206).

Parameters:
sample_size (`int`, defaults to `128`):
The width/height of the latents. This is fixed during training since it is used to learn a number of
position embeddings.
patch_size (`int`, defaults to `2`):
Patch size to turn the input data into small patches.
in_channels (`int`, defaults to `16`):
The number of latent channels in the input.
num_layers (`int`, defaults to `18`):
The number of layers of transformer blocks to use.
attention_head_dim (`int`, defaults to `64`):
The number of channels in each head.
num_attention_heads (`int`, defaults to `18`):
The number of heads to use for multi-head attention.
joint_attention_dim (`int`, defaults to `4096`):
The embedding dimension to use for joint text-image attention.
caption_projection_dim (`int`, defaults to `1152`):
The embedding dimension of caption embeddings.
pooled_projection_dim (`int`, defaults to `2048`):
The embedding dimension of pooled text projections.
out_channels (`int`, defaults to `16`):
The number of latent channels in the output.
pos_embed_max_size (`int`, defaults to `96`):
The maximum latent height/width of positional embeddings.
extra_conditioning_channels (`int`, defaults to `0`):
The number of extra channels to use for conditioning for patch embedding.
dual_attention_layers (`Tuple[int, ...]`, defaults to `()`):
The number of dual-stream transformer blocks to use.
qk_norm (`str`, *optional*, defaults to `None`):
The normalization to use for query and key in the attention layer. If `None`, no normalization is used.
pos_embed_type (`str`, defaults to `"sincos"`):
The type of positional embedding to use. Choose between `"sincos"` and `None`.
use_pos_embed (`bool`, defaults to `True`):
Whether to use positional embeddings.
force_zeros_for_pooled_projection (`bool`, defaults to `True`):
Whether to force zeros for pooled projection embeddings. This is handled in the pipelines by reading the
config value of the ControlNet model.
"""

_supports_gradient_checkpointing = True

@register_to_config
@@ -93,7 +135,7 @@ def __init__(
JointTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
attention_head_dim=attention_head_dim,
context_pre_only=False,
qk_norm=qk_norm,
use_dual_attention=True if i in dual_attention_layers else False,
@@ -108,7 +150,7 @@ def __init__(
SD3SingleTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
attention_head_dim=attention_head_dim,
)
for _ in range(num_layers)
]
@@ -297,28 +339,28 @@ def from_transformer(

def forward(
self,
hidden_states: torch.FloatTensor,
hidden_states: torch.Tensor,
controlnet_cond: torch.Tensor,
conditioning_scale: float = 1.0,
encoder_hidden_states: torch.FloatTensor = None,
pooled_projections: torch.FloatTensor = None,
encoder_hidden_states: torch.Tensor = None,
pooled_projections: torch.Tensor = None,
timestep: torch.LongTensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
) -> Union[torch.Tensor, Transformer2DModelOutput]:
"""
The [`SD3Transformer2DModel`] forward method.

Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`):
Input `hidden_states`.
controlnet_cond (`torch.Tensor`):
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
conditioning_scale (`float`, defaults to `1.0`):
The scale factor for ControlNet outputs.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
encoder_hidden_states (`torch.Tensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
@@ -437,11 +479,11 @@ def __init__(self, controlnets):

def forward(
self,
hidden_states: torch.FloatTensor,
hidden_states: torch.Tensor,
controlnet_cond: List[torch.tensor],
conditioning_scale: List[float],
pooled_projections: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
pooled_projections: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
timestep: torch.LongTensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
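For context on how the ControlNet documented above is consumed, here is a hedged usage sketch (not part of this diff): it loads an SD3 ControlNet checkpoint and plugs it into the SD3 ControlNet pipeline, where `controlnet_conditioning_scale` maps onto the `conditioning_scale` argument of the model's `forward` documented above. The checkpoint names and the control-image path are examples, not taken from this PR.

```python
import torch
from diffusers import SD3ControlNetModel, StableDiffusion3ControlNetPipeline
from diffusers.utils import load_image

# Checkpoint names below are examples, not taken from this PR.
controlnet = SD3ControlNetModel.from_pretrained(
    "InstantX/SD3-Controlnet-Canny", torch_dtype=torch.float16
)
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    controlnet=controlnet,
    torch_dtype=torch.float16,
).to("cuda")

# Supply your own preprocessed control image (e.g. a canny edge map).
control_image = load_image("canny_edge.png")

image = pipe(
    prompt="a futuristic city at night, neon lights",
    control_image=control_image,
    controlnet_conditioning_scale=0.7,  # forwarded as `conditioning_scale` to the ControlNet
    num_inference_steps=28,
).images[0]
image.save("sd3_controlnet_out.png")
```
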
119 changes: 52 additions & 67 deletions src/diffusers/models/transformers/transformer_sd3.py
@@ -15,7 +15,6 @@

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
@@ -39,17 +38,6 @@

@maybe_allow_in_graph
class SD3SingleTransformerBlock(nn.Module):
r"""
Collaborator: is this on purpose?

Contributor Author: Indeed, yes! All the latest integrations don't document the internal blocks, for two reasons:

  • The parameters are already documented on the main transformer, so copying them for each internal block is redundant and has almost always introduced mismatched explanations (while doing this refactoring, I sometimes found parameters documented that don't exist, and explanations that are simply wrong).
  • The internal blocks are not user-facing API and are not linked in the main diffusers documentation, so they would not show up there anyway.

A Single Transformer block as part of the MMDiT architecture, used in Stable Diffusion 3 ControlNet.

Reference: https://arxiv.org/abs/2403.03206

Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
"""

def __init__(
self,
dim: int,
@@ -59,45 +47,31 @@ def __init__(
super().__init__()

self.norm1 = AdaLayerNormZero(dim)

if hasattr(F, "scaled_dot_product_attention"):
processor = JointAttnProcessor2_0()
else:
raise ValueError(
"The current PyTorch version does not support the `scaled_dot_product_attention` function."
)

self.attn = Attention(
query_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
processor=processor,
processor=JointAttnProcessor2_0(),
eps=1e-6,
)

self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor):
# 1. Attention
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
# Attention.
attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=None,
)

# Process attention outputs for the `hidden_states`.
attn_output = self.attn(hidden_states=norm_hidden_states, encoder_hidden_states=None)
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = hidden_states + attn_output

# 2. Feed Forward
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

norm_hidden_states = norm_hidden_states * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
ff_output = self.ff(norm_hidden_states)
ff_output = gate_mlp.unsqueeze(1) * ff_output

hidden_states = hidden_states + ff_output

return hidden_states
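To illustrate the refactored block above, here is a minimal sketch (not from this PR) that runs `SD3SingleTransformerBlock` on random tensors. The import path follows the file shown in this diff; the shapes assume `hidden_states` of `(batch, tokens, dim)` and a conditioning embedding `temb` of `(batch, dim)`, which is how `AdaLayerNormZero(dim)` consumes it. Since the block is internal, non-user-facing API, this is only a shape check.

```python
import torch
from diffusers.models.transformers.transformer_sd3 import SD3SingleTransformerBlock

dim, num_heads, head_dim = 64, 4, 16      # dim == num_heads * head_dim
block = SD3SingleTransformerBlock(
    dim=dim, num_attention_heads=num_heads, attention_head_dim=head_dim
)

hidden_states = torch.randn(2, 128, dim)  # (batch, tokens, dim) patchified latents
temb = torch.randn(2, dim)                # conditioning embedding consumed by AdaLayerNormZero

out = block(hidden_states, temb)
print(out.shape)                          # torch.Size([2, 128, 64]) -- shape is preserved
```
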
Expand All @@ -107,26 +81,40 @@ class SD3Transformer2DModel(
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin
):
"""
The Transformer model introduced in Stable Diffusion 3.

Reference: https://arxiv.org/abs/2403.03206
The Transformer model introduced in [Stable Diffusion 3](https://huggingface.co/papers/2403.03206).

Parameters:
sample_size (`int`): The width of the latent images. This is fixed during training since
it is used to learn a number of position embeddings.
patch_size (`int`): Patch size to turn the input data into small patches.
in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
num_layers (`int`, *optional*, defaults to 18): The number of layers of Transformer blocks to use.
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
out_channels (`int`, defaults to 16): Number of output channels.

sample_size (`int`, defaults to `128`):
The width/height of the latents. This is fixed during training since it is used to learn a number of
position embeddings.
patch_size (`int`, defaults to `2`):
Patch size to turn the input data into small patches.
in_channels (`int`, defaults to `16`):
The number of latent channels in the input.
num_layers (`int`, defaults to `18`):
The number of layers of transformer blocks to use.
attention_head_dim (`int`, defaults to `64`):
The number of channels in each head.
num_attention_heads (`int`, defaults to `18`):
The number of heads to use for multi-head attention.
joint_attention_dim (`int`, defaults to `4096`):
The embedding dimension to use for joint text-image attention.
caption_projection_dim (`int`, defaults to `1152`):
The embedding dimension of caption embeddings.
pooled_projection_dim (`int`, defaults to `2048`):
The embedding dimension of pooled text projections.
out_channels (`int`, defaults to `16`):
The number of latent channels in the output.
pos_embed_max_size (`int`, defaults to `96`):
The maximum latent height/width of positional embeddings.
dual_attention_layers (`Tuple[int, ...]`, defaults to `()`):
The number of dual-stream transformer blocks to use.
qk_norm (`str`, *optional*, defaults to `None`):
The normalization to use for query and key in the attention layer. If `None`, no normalization is used.
"""

_supports_gradient_checkpointing = True
_no_split_modules = ["JointTransformerBlock"]
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]

@register_to_config
Expand All @@ -149,36 +137,33 @@ def __init__(
qk_norm: Optional[str] = None,
):
super().__init__()
default_out_channels = in_channels
self.out_channels = out_channels if out_channels is not None else default_out_channels
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
self.out_channels = out_channels if out_channels is not None else in_channels
self.inner_dim = num_attention_heads * attention_head_dim

self.pos_embed = PatchEmbed(
height=self.config.sample_size,
width=self.config.sample_size,
patch_size=self.config.patch_size,
in_channels=self.config.in_channels,
height=sample_size,
width=sample_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=self.inner_dim,
pos_embed_max_size=pos_embed_max_size, # hard-code for now.
)
self.time_text_embed = CombinedTimestepTextProjEmbeddings(
embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
)
self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim)
self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)

# `attention_head_dim` is doubled to account for the mixing.
# It needs to crafted when we get the actual checkpoints.
self.transformer_blocks = nn.ModuleList(
[
JointTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
context_pre_only=i == num_layers - 1,
qk_norm=qk_norm,
use_dual_attention=True if i in dual_attention_layers else False,
)
for i in range(self.config.num_layers)
for i in range(num_layers)
]
)

@@ -331,24 +316,24 @@ def unfuse_qkv_projections(self):

def forward(
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
pooled_projections: torch.FloatTensor = None,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
pooled_projections: torch.Tensor = None,
timestep: torch.LongTensor = None,
block_controlnet_hidden_states: List = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
skip_layers: Optional[List[int]] = None,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
) -> Union[torch.Tensor, Transformer2DModelOutput]:
"""
The [`SD3Transformer2DModel`] forward method.

Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`):
Input `hidden_states`.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
encoder_hidden_states (`torch.Tensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`):
pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`):
Embeddings projected from the embeddings of input conditions.
timestep (`torch.LongTensor`):
Used to indicate denoising step.
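As a companion to the docstrings above, here is a hedged smoke-test sketch (not part of the diff) that instantiates a tiny `SD3Transformer2DModel` and runs the documented `forward` signature on random tensors. All sizes are shrunk for a quick CPU check; the structural assumption is that `caption_projection_dim` equals `num_attention_heads * attention_head_dim`, which matches the documented defaults (18 × 64 = 1152).

```python
import torch
from diffusers import SD3Transformer2DModel

model = SD3Transformer2DModel(
    sample_size=32,
    patch_size=2,
    in_channels=16,
    num_layers=2,
    attention_head_dim=8,
    num_attention_heads=4,       # inner_dim = 4 * 8 = 32
    joint_attention_dim=64,
    caption_projection_dim=32,   # assumed to match inner_dim, as in the real configs
    pooled_projection_dim=48,
    out_channels=16,
    pos_embed_max_size=32,
)

hidden_states = torch.randn(1, 16, 32, 32)      # (B, C, H, W) latents
encoder_hidden_states = torch.randn(1, 77, 64)  # (B, seq_len, joint_attention_dim)
pooled_projections = torch.randn(1, 48)         # (B, pooled_projection_dim)
timestep = torch.tensor([10], dtype=torch.long)

out = model(
    hidden_states=hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    pooled_projections=pooled_projections,
    timestep=timestep,
)
print(out.sample.shape)  # (1, 16, 32, 32): patch tokens are unpatchified back to latent space
```
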