103 changes: 73 additions & 30 deletions src/diffusers/models/controlnets/controlnet_sd3.py
@@ -27,6 +27,7 @@
from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..transformers.transformer_sd3 import SD3SingleTransformerBlock
from .controlnet import BaseOutput, zero_module


@@ -58,40 +59,60 @@ def __init__(
extra_conditioning_channels: int = 0,
dual_attention_layers: Tuple[int, ...] = (),
qk_norm: Optional[str] = None,
pos_embed_type: Optional[str] = "sincos",
use_pos_embed: bool = True,
force_zeros_for_pooled_projection: bool = True,
):
super().__init__()
default_out_channels = in_channels
self.out_channels = out_channels if out_channels is not None else default_out_channels
self.inner_dim = num_attention_heads * attention_head_dim

self.pos_embed = PatchEmbed(
height=sample_size,
width=sample_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=self.inner_dim,
pos_embed_max_size=pos_embed_max_size,
)
if use_pos_embed:
self.pos_embed = PatchEmbed(
height=sample_size,
width=sample_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=self.inner_dim,
pos_embed_max_size=pos_embed_max_size,
pos_embed_type=pos_embed_type,
)
else:
self.pos_embed = None
self.time_text_embed = CombinedTimestepTextProjEmbeddings(
embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
)
self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)

# `attention_head_dim` is doubled to account for the mixing.
# It needs to be crafted when we get the actual checkpoints.
self.transformer_blocks = nn.ModuleList(
[
JointTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
context_pre_only=False,
qk_norm=qk_norm,
use_dual_attention=True if i in dual_attention_layers else False,
)
for i in range(num_layers)
]
)
if joint_attention_dim is not None:
[Review comment] I think this is a good enough condition for now, because based on `joint_attention_dim` we initialize both the `context_embedder` and the `transformer_blocks` (which use the `JointTransformerBlock` type). I am okay with it.

self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)

# `attention_head_dim` is doubled to account for the mixing.
# It needs to be crafted when we get the actual checkpoints.
self.transformer_blocks = nn.ModuleList(
[
JointTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
context_pre_only=False,
qk_norm=qk_norm,
use_dual_attention=True if i in dual_attention_layers else False,
)
for i in range(num_layers)
]
)
else:
self.context_embedder = None
self.transformer_blocks = nn.ModuleList(
[
SD3SingleTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
)
for _ in range(num_layers)
]
)

# controlnet_blocks
self.controlnet_blocks = nn.ModuleList([])
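For orientation, here is a minimal sketch of how the two initialization paths above would be selected; the argument values are illustrative assumptions rather than a real checkpoint config:

```python
from diffusers import SD3ControlNetModel

# joint_attention_dim set (the default): a context_embedder is built and the blocks
# are JointTransformerBlock, as in the existing InstantX-style SD3 ControlNets.
controlnet_joint = SD3ControlNetModel(num_layers=2, joint_attention_dim=4096)

# joint_attention_dim=None: no context_embedder, and the blocks are the new
# SD3SingleTransformerBlock. With use_pos_embed=False the model also skips its own
# PatchEmbed and expects already patch-embedded hidden states (SD3.5 8b style).
controlnet_single = SD3ControlNetModel(
    num_layers=2,
    joint_attention_dim=None,
    use_pos_embed=False,
    force_zeros_for_pooled_projection=False,
)
```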
@@ -318,9 +339,27 @@ def forward(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)

hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
if self.pos_embed is not None and hidden_states.ndim != 4:
raise ValueError("hidden_states must be 4D when pos_embed is used")

# SD3.5 8b controlnet does not have a `pos_embed`,
# it uses the `pos_embed` from the transformer to process the input before passing it to the controlnet
elif self.pos_embed is None and hidden_states.ndim != 3:
raise ValueError("hidden_states must be 3D when pos_embed is not used")

if self.context_embedder is not None and encoder_hidden_states is None:
raise ValueError("encoder_hidden_states must be provided when context_embedder is used")
# SD3.5 8b controlnet does not have a `context_embedder`, it does not use `encoder_hidden_states`
elif self.context_embedder is None and encoder_hidden_states is not None:
raise ValueError("encoder_hidden_states should not be provided when context_embedder is not used")
[Review comment on lines +345 to +354] Very useful!


if self.pos_embed is not None:
hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.

temb = self.time_text_embed(timestep, pooled_projections)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)

if self.context_embedder is not None:
encoder_hidden_states = self.context_embedder(encoder_hidden_states)

# add
hidden_states = hidden_states + self.pos_embed_input(controlnet_cond)
@@ -349,9 +388,13 @@ def custom_forward(*inputs):
)

else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
)
if self.context_embedder is not None:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
)
else:
# SD3.5 8b controlnet uses a single transformer block, which does not use `encoder_hidden_states`
hidden_states = block(hidden_states, temb)

block_res_samples = block_res_samples + (hidden_states,)

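In other words, the two variants expect differently shaped inputs. A small illustrative sketch (shapes are assumptions for a 1024x1024 generation, not values from this PR):

```python
import torch

# Standard SD3 ControlNet: pos_embed is internal, so 4D latents go in.
latents = torch.randn(1, 16, 128, 128)   # (batch, channels, height, width) -> ndim == 4

# SD3.5 8b ControlNet: no internal pos_embed; the pipeline first runs the base
# transformer's patch embedding and passes the resulting token sequence instead.
# tokens = transformer.pos_embed(latents)  # (batch, num_patches, inner_dim) -> ndim == 3
```

Passing the wrong rank now fails fast with the ValueErrors added above instead of a shape mismatch deeper inside the blocks.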
79 changes: 76 additions & 3 deletions src/diffusers/models/transformers/transformer_sd3.py
@@ -18,21 +18,94 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...models.attention import JointTransformerBlock
from ...models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
from ...models.attention import FeedForward, JointTransformerBlock
from ...models.attention_processor import (
Attention,
AttentionProcessor,
FusedJointAttnProcessor2_0,
JointAttnProcessor2_0,
)
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous
from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput


logger = logging.get_logger(__name__) # pylint: disable=invalid-name


@maybe_allow_in_graph
class SD3SingleTransformerBlock(nn.Module):
r"""
A Single Transformer block as part of the MMDiT architecture, used in Stable Diffusion 3 ControlNet.
[Review comment] Perhaps we could make this more explicit by adding: Reference: https://arxiv.org/abs/2403.03206

Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
"""

def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
):
super().__init__()

self.norm1 = AdaLayerNormZero(dim)

if hasattr(F, "scaled_dot_product_attention"):
processor = JointAttnProcessor2_0()
else:
raise ValueError(
"The current PyTorch version does not support the `scaled_dot_product_attention` function."
)

self.attn = Attention(
query_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
processor=processor,
eps=1e-6,
)

self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor):
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
# Attention.
attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=None,
)

# Process attention outputs for the `hidden_states`.
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = hidden_states + attn_output

norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

ff_output = self.ff(norm_hidden_states)
ff_output = gate_mlp.unsqueeze(1) * ff_output

hidden_states = hidden_states + ff_output

return hidden_states


class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
"""
The Transformer model introduced in Stable Diffusion 3.
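A quick standalone sketch of the new block's forward contract; the dimensions are illustrative and only need to satisfy dim == num_attention_heads * attention_head_dim:

```python
import torch
from diffusers.models.transformers.transformer_sd3 import SD3SingleTransformerBlock

block = SD3SingleTransformerBlock(dim=1536, num_attention_heads=24, attention_head_dim=64)

hidden_states = torch.randn(2, 4096, 1536)  # (batch, tokens, dim)
temb = torch.randn(2, 1536)                 # combined timestep/text embedding, shape (batch, dim)

out = block(hidden_states, temb)            # AdaLN-Zero -> self-attention -> AdaLN-scaled MLP
assert out.shape == hidden_states.shape
```

Unlike `JointTransformerBlock`, there is no context stream, which is why the ControlNet forward above calls it as `block(hidden_states, temb)`.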
src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
@@ -858,6 +858,12 @@ def __call__(
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor

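# NOTE: when an SD3MultiControlNetModel is passed, the shared settings below are read
# from the first net's config; this assumes all nets in the list were trained with
# compatible configs (see the review comment below).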
controlnet_config = (
self.controlnet.config
if isinstance(self.controlnet, SD3ControlNetModel)
else self.controlnet.nets[0].config
)
[Review comment] I guess this is okay for now, but could there be a case where we have incompatible configs when len(self.controlnet.nets) > 1? I guess we will know when we know.

# align format for control guidance
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
@@ -932,6 +938,11 @@ def __call__(
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

# 3. Prepare control image
if controlnet_config.force_zeros_for_pooled_projection:
# instantx sd3 controlnet does not apply shift factor
vae_shift_factor = 0
else:
vae_shift_factor = self.vae.config.shift_factor
if isinstance(self.controlnet, SD3ControlNetModel):
control_image = self.prepare_image(
image=control_image,
@@ -947,8 +958,7 @@
height, width = control_image.shape[-2:]

control_image = self.vae.encode(control_image).latent_dist.sample()
control_image = control_image * self.vae.config.scaling_factor

control_image = (control_image - vae_shift_factor) * self.vae.config.scaling_factor
elif isinstance(self.controlnet, SD3MultiControlNetModel):
control_images = []

@@ -966,19 +976,14 @@
)

control_image_ = self.vae.encode(control_image_).latent_dist.sample()
control_image_ = control_image_ * self.vae.config.scaling_factor
control_image_ = (control_image_ - vae_shift_factor) * self.vae.config.scaling_factor

control_images.append(control_image_)

control_image = control_images
else:
assert False

if controlnet_pooled_projections is None:
controlnet_pooled_projections = torch.zeros_like(pooled_prompt_embeds)
else:
controlnet_pooled_projections = controlnet_pooled_projections or pooled_prompt_embeds

# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
@@ -1006,6 +1011,18 @@
]
controlnet_keep.append(keeps[0] if isinstance(self.controlnet, SD3ControlNetModel) else keeps)

if controlnet_config.force_zeros_for_pooled_projection:
# instantx sd3 controlnet uses a zero pooled projection
controlnet_pooled_projections = torch.zeros_like(pooled_prompt_embeds)
else:
controlnet_pooled_projections = controlnet_pooled_projections or pooled_prompt_embeds

if controlnet_config.joint_attention_dim is not None:
controlnet_encoder_hidden_states = prompt_embeds
else:
# SD3.5 official 8b controlnet does not use encoder_hidden_states
controlnet_encoder_hidden_states = None

# 7. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
@@ -1025,11 +1042,17 @@
controlnet_cond_scale = controlnet_cond_scale[0]
cond_scale = controlnet_cond_scale * controlnet_keep[i]

if controlnet_config.use_pos_embed is False:
# sd3.5 (official) 8b controlnet
controlnet_model_input = self.transformer.pos_embed(latent_model_input)
else:
controlnet_model_input = latent_model_input

# controlnet(s) inference
control_block_samples = self.controlnet(
hidden_states=latent_model_input,
hidden_states=controlnet_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
encoder_hidden_states=controlnet_encoder_hidden_states,
pooled_projections=controlnet_pooled_projections,
joint_attention_kwargs=self.joint_attention_kwargs,
controlnet_cond=control_image,
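Putting the pieces together, a hedged end-to-end sketch of driving the updated pipeline with an SD3.5-large ControlNet; the checkpoint ids and the control-image URL are placeholders/assumptions, not taken from this PR:

```python
import torch
from diffusers import SD3ControlNetModel, StableDiffusion3ControlNetPipeline
from diffusers.utils import load_image

# Assumed repo ids; substitute the actual SD3.5 ControlNet checkpoints.
controlnet = SD3ControlNetModel.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large-controlnet-canny", torch_dtype=torch.float16
)
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

control_image = load_image("https://example.com/canny.png")  # placeholder URL

image = pipe(
    prompt="a photo of a cat",
    control_image=control_image,
    controlnet_conditioning_scale=1.0,
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]
image.save("controlnet_sd3_5.png")
```

Assuming the checkpoint was exported with use_pos_embed=False, joint_attention_dim=None, and force_zeros_for_pooled_projection=False, the pipeline takes the new branches above: it patch-embeds the latents with the base transformer's pos_embed, passes encoder_hidden_states=None to the ControlNet, and applies the VAE shift factor when encoding the control image.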