Skip to content
185 changes: 185 additions & 0 deletions scripts/convert_sd3_controlnet_to_diffusers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
"""
A script to convert Stable Diffusion 3.5 ControlNet checkpoints to the Diffusers format.

Example:
Convert a SD3.5 ControlNet checkpoint to Diffusers format using local file:
```bash
python scripts/convert_sd3_controlnet_to_diffusers.py \
--checkpoint_path "path/to/local/sd3.5_large_controlnet_canny.safetensors" \
--output_path "output/sd35-controlnet-canny" \
--dtype "fp16" # optional, defaults to fp32
```

Or download and convert from HuggingFace repository:
```bash
python scripts/convert_sd3_controlnet_to_diffusers.py \
--original_state_dict_repo_id "stabilityai/stable-diffusion-3.5-controlnets" \
--filename "sd3.5_large_controlnet_canny.safetensors" \
--output_path "/raid/yiyi/sd35-controlnet-canny-diffusers" \
--dtype "fp32" # optional, defaults to fp32
```

Note:
The script supports the following ControlNet types from SD3.5:
- Canny edge detection
- Depth estimation
- Blur detection

The checkpoint files can be downloaded from:
https://huggingface.co/stabilityai/stable-diffusion-3.5-controlnets
"""

import argparse

import safetensors.torch
import torch
from huggingface_hub import hf_hub_download

from diffusers import SD3ControlNetModel


parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", type=str, default=None, help="Path to local checkpoint file")
parser.add_argument(
"--original_state_dict_repo_id", type=str, default=None, help="HuggingFace repo ID containing the checkpoint"
)
parser.add_argument("--filename", type=str, default=None, help="Filename of the checkpoint in the HF repo")
parser.add_argument("--output_path", type=str, required=True, help="Path to save the converted model")
parser.add_argument(
"--dtype", type=str, default="fp32", help="Data type for the converted model (fp16, bf16, or fp32)"
)

args = parser.parse_args()


def load_original_checkpoint(args):
if args.original_state_dict_repo_id is not None:
if args.filename is None:
raise ValueError("When using `original_state_dict_repo_id`, `filename` must also be specified")
print(f"Downloading checkpoint from {args.original_state_dict_repo_id}/{args.filename}")
ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
elif args.checkpoint_path is not None:
print(f"Loading checkpoint from local path: {args.checkpoint_path}")
ckpt_path = args.checkpoint_path
else:
raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")

original_state_dict = safetensors.torch.load_file(ckpt_path)
return original_state_dict


def convert_sd3_controlnet_checkpoint_to_diffusers(original_state_dict):
converted_state_dict = {}

# Direct mappings for controlnet blocks
for i in range(19): # 19 controlnet blocks
converted_state_dict[f"controlnet_blocks.{i}.weight"] = original_state_dict[f"controlnet_blocks.{i}.weight"]
converted_state_dict[f"controlnet_blocks.{i}.bias"] = original_state_dict[f"controlnet_blocks.{i}.bias"]

# Positional embeddings
converted_state_dict["pos_embed_input.proj.weight"] = original_state_dict["pos_embed_input.proj.weight"]
converted_state_dict["pos_embed_input.proj.bias"] = original_state_dict["pos_embed_input.proj.bias"]

# Time and text embeddings
time_text_mappings = {
"time_text_embed.timestep_embedder.linear_1.weight": "time_text_embed.timestep_embedder.linear_1.weight",
"time_text_embed.timestep_embedder.linear_1.bias": "time_text_embed.timestep_embedder.linear_1.bias",
"time_text_embed.timestep_embedder.linear_2.weight": "time_text_embed.timestep_embedder.linear_2.weight",
"time_text_embed.timestep_embedder.linear_2.bias": "time_text_embed.timestep_embedder.linear_2.bias",
"time_text_embed.text_embedder.linear_1.weight": "time_text_embed.text_embedder.linear_1.weight",
"time_text_embed.text_embedder.linear_1.bias": "time_text_embed.text_embedder.linear_1.bias",
"time_text_embed.text_embedder.linear_2.weight": "time_text_embed.text_embedder.linear_2.weight",
"time_text_embed.text_embedder.linear_2.bias": "time_text_embed.text_embedder.linear_2.bias",
}

for new_key, old_key in time_text_mappings.items():
if old_key in original_state_dict:
converted_state_dict[new_key] = original_state_dict[old_key]

# Transformer blocks
for i in range(19):
# Split QKV into separate Q, K, V
qkv_weight = original_state_dict[f"transformer_blocks.{i}.attn.qkv.weight"]
qkv_bias = original_state_dict[f"transformer_blocks.{i}.attn.qkv.bias"]
q, k, v = torch.chunk(qkv_weight, 3, dim=0)
q_bias, k_bias, v_bias = torch.chunk(qkv_bias, 3, dim=0)

block_mappings = {
f"transformer_blocks.{i}.attn.to_q.weight": q,
f"transformer_blocks.{i}.attn.to_q.bias": q_bias,
f"transformer_blocks.{i}.attn.to_k.weight": k,
f"transformer_blocks.{i}.attn.to_k.bias": k_bias,
f"transformer_blocks.{i}.attn.to_v.weight": v,
f"transformer_blocks.{i}.attn.to_v.bias": v_bias,
# Output projections
f"transformer_blocks.{i}.attn.to_out.0.weight": original_state_dict[
f"transformer_blocks.{i}.attn.proj.weight"
],
f"transformer_blocks.{i}.attn.to_out.0.bias": original_state_dict[
f"transformer_blocks.{i}.attn.proj.bias"
],
# Feed forward
f"transformer_blocks.{i}.ff.net.0.proj.weight": original_state_dict[
f"transformer_blocks.{i}.mlp.fc1.weight"
],
f"transformer_blocks.{i}.ff.net.0.proj.bias": original_state_dict[f"transformer_blocks.{i}.mlp.fc1.bias"],
f"transformer_blocks.{i}.ff.net.2.weight": original_state_dict[f"transformer_blocks.{i}.mlp.fc2.weight"],
f"transformer_blocks.{i}.ff.net.2.bias": original_state_dict[f"transformer_blocks.{i}.mlp.fc2.bias"],
# Norms
f"transformer_blocks.{i}.norm1.linear.weight": original_state_dict[
f"transformer_blocks.{i}.adaLN_modulation.1.weight"
],
f"transformer_blocks.{i}.norm1.linear.bias": original_state_dict[
f"transformer_blocks.{i}.adaLN_modulation.1.bias"
],
}
converted_state_dict.update(block_mappings)

return converted_state_dict


def main(args):
original_ckpt = load_original_checkpoint(args)
original_dtype = next(iter(original_ckpt.values())).dtype

# Initialize dtype with fp32 as default
if args.dtype == "fp16":
dtype = torch.float16
elif args.dtype == "bf16":
dtype = torch.bfloat16
elif args.dtype == "fp32":
dtype = torch.float32
else:
raise ValueError(f"Unsupported dtype: {args.dtype}. Must be one of: fp16, bf16, fp32")

if dtype != original_dtype:
print(
f"Converting checkpoint from {original_dtype} to {dtype}. This can lead to unexpected results, proceed with caution."
)

converted_controlnet_state_dict = convert_sd3_controlnet_checkpoint_to_diffusers(original_ckpt)

controlnet = SD3ControlNetModel(
patch_size=2,
in_channels=16,
num_layers=19,
attention_head_dim=64,
num_attention_heads=38,
joint_attention_dim=None,
caption_projection_dim=2048,
pooled_projection_dim=2048,
out_channels=16,
pos_embed_max_size=None,
pos_embed_type=None,
use_pos_embed=False,
force_zeros_for_pooled_projection=False,
)

controlnet.load_state_dict(converted_controlnet_state_dict, strict=True)

print(f"Saving SD3 ControlNet in Diffusers format in {args.output_path}.")
controlnet.to(dtype).save_pretrained(args.output_path)


if __name__ == "__main__":
main(args)
103 changes: 73 additions & 30 deletions src/diffusers/models/controlnets/controlnet_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..transformers.transformer_sd3 import SD3SingleTransformerBlock
from .controlnet import BaseOutput, zero_module


Expand Down Expand Up @@ -58,40 +59,60 @@ def __init__(
extra_conditioning_channels: int = 0,
dual_attention_layers: Tuple[int, ...] = (),
qk_norm: Optional[str] = None,
pos_embed_type: Optional[str] = "sincos",
use_pos_embed: bool = True,
force_zeros_for_pooled_projection: bool = True,
):
super().__init__()
default_out_channels = in_channels
self.out_channels = out_channels if out_channels is not None else default_out_channels
self.inner_dim = num_attention_heads * attention_head_dim

self.pos_embed = PatchEmbed(
height=sample_size,
width=sample_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=self.inner_dim,
pos_embed_max_size=pos_embed_max_size,
)
if use_pos_embed:
self.pos_embed = PatchEmbed(
height=sample_size,
width=sample_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=self.inner_dim,
pos_embed_max_size=pos_embed_max_size,
pos_embed_type=pos_embed_type,
)
else:
self.pos_embed = None
self.time_text_embed = CombinedTimestepTextProjEmbeddings(
embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
)
self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)

# `attention_head_dim` is doubled to account for the mixing.
# It needs to crafted when we get the actual checkpoints.
self.transformer_blocks = nn.ModuleList(
[
JointTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
context_pre_only=False,
qk_norm=qk_norm,
use_dual_attention=True if i in dual_attention_layers else False,
)
for i in range(num_layers)
]
)
if joint_attention_dim is not None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is a good enough condition for now. Because based joint_attention_dim we initialize both the context_embedded and transformer_blocks (that have the JointTransformerBlock type). I am okay with it.

self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)

# `attention_head_dim` is doubled to account for the mixing.
# It needs to crafted when we get the actual checkpoints.
self.transformer_blocks = nn.ModuleList(
[
JointTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
context_pre_only=False,
qk_norm=qk_norm,
use_dual_attention=True if i in dual_attention_layers else False,
)
for i in range(num_layers)
]
)
else:
self.context_embedder = None
self.transformer_blocks = nn.ModuleList(
[
SD3SingleTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
)
for _ in range(num_layers)
]
)

# controlnet_blocks
self.controlnet_blocks = nn.ModuleList([])
Expand Down Expand Up @@ -318,9 +339,27 @@ def forward(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)

hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
if self.pos_embed is not None and hidden_states.ndim != 4:
raise ValueError("hidden_states must be 4D when pos_embed is used")

# SD3.5 8b controlnet does not have a `pos_embed`,
# it use the `pos_embed` from the transformer to process input before passing to controlnet
elif self.pos_embed is None and hidden_states.ndim != 3:
raise ValueError("hidden_states must be 3D when pos_embed is not used")

if self.context_embedder is not None and encoder_hidden_states is None:
raise ValueError("encoder_hidden_states must be provided when context_embedder is used")
# SD3.5 8b controlnet does not have a `context_embedder`, it does not use `encoder_hidden_states`
elif self.context_embedder is None and encoder_hidden_states is not None:
raise ValueError("encoder_hidden_states should not be provided when context_embedder is not used")
Comment on lines +345 to +354
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very useful!


if self.pos_embed is not None:
hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.

temb = self.time_text_embed(timestep, pooled_projections)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)

if self.context_embedder is not None:
encoder_hidden_states = self.context_embedder(encoder_hidden_states)

# add
hidden_states = hidden_states + self.pos_embed_input(controlnet_cond)
Expand Down Expand Up @@ -349,9 +388,13 @@ def custom_forward(*inputs):
)

else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
)
if self.context_embedder is not None:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
)
else:
# SD3.5 8b controlnet use single transformer block, which does not use `encoder_hidden_states`
hidden_states = block(hidden_states, temb)

block_res_samples = block_res_samples + (hidden_states,)

Expand Down
Loading
Loading