Commit 5290c75

Merge branch 'main' into fix-run-pr-workflow

2 parents: 142f259 + e47cc1f
File tree

7 files changed: +484 -53 lines

docs/source/en/api/pipelines/cogvideox.md

Lines changed: 6 additions & 3 deletions
@@ -30,15 +30,17 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m
 This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM).
 
 There are three official CogVideoX checkpoints for text-to-video and video-to-video.
+
 | checkpoints | recommended inference dtype |
-|---|---|
+|:---:|:---:|
 | [`THUDM/CogVideoX-2b`](https://huggingface.co/THUDM/CogVideoX-2b) | torch.float16 |
 | [`THUDM/CogVideoX-5b`](https://huggingface.co/THUDM/CogVideoX-5b) | torch.bfloat16 |
 | [`THUDM/CogVideoX1.5-5b`](https://huggingface.co/THUDM/CogVideoX1.5-5b) | torch.bfloat16 |
 
 There are two official CogVideoX checkpoints available for image-to-video.
+
 | checkpoints | recommended inference dtype |
-|---|---|
+|:---:|:---:|
 | [`THUDM/CogVideoX-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-5b-I2V) | torch.bfloat16 |
 | [`THUDM/CogVideoX-1.5-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-1.5-5b-I2V) | torch.bfloat16 |

@@ -48,8 +50,9 @@ For the CogVideoX 1.5 series:
 - Both T2V and I2V models support generation with 81 and 161 frames and work best at this value. Exporting videos at 16 FPS is recommended.
 
 There are two official CogVideoX checkpoints that support pose controllable generation (by the [Alibaba-PAI](https://huggingface.co/alibaba-pai) team).
+
 | checkpoints | recommended inference dtype |
-|---|---|
+|:---:|:---:|
 | [`alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose) | torch.bfloat16 |
 | [`alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose) | torch.bfloat16 |
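As a point of reference for the checkpoint tables above, loading one of the listed models at its recommended inference dtype follows the usual diffusers pattern. The sketch below assumes the text-to-video `CogVideoXPipeline` with the `THUDM/CogVideoX1.5-5b` checkpoint from the first table; the prompt and output path are illustrative only.

```python
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

# Load at the recommended inference dtype from the table (bfloat16 for the 5b checkpoints).
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX1.5-5b", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()  # optional; trades speed for lower VRAM use

# The CogVideoX 1.5 series works best with 81 or 161 frames; export at the recommended 16 FPS.
prompt = "A panda playing a guitar in a bamboo forest"  # illustrative prompt
video = pipe(prompt=prompt, num_frames=81).frames[0]
export_to_video(video, "output.mp4", fps=16)
```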

examples/community/README.md

Lines changed: 6 additions & 4 deletions
Large diffs are not rendered by default.
scripts/convert_sd3_controlnet_to_diffusers.py

Lines changed: 185 additions & 0 deletions
@@ -0,0 +1,185 @@
"""
A script to convert Stable Diffusion 3.5 ControlNet checkpoints to the Diffusers format.

Example:
    Convert a SD3.5 ControlNet checkpoint to Diffusers format using local file:
    ```bash
    python scripts/convert_sd3_controlnet_to_diffusers.py \
        --checkpoint_path "path/to/local/sd3.5_large_controlnet_canny.safetensors" \
        --output_path "output/sd35-controlnet-canny" \
        --dtype "fp16"  # optional, defaults to fp32
    ```

    Or download and convert from HuggingFace repository:
    ```bash
    python scripts/convert_sd3_controlnet_to_diffusers.py \
        --original_state_dict_repo_id "stabilityai/stable-diffusion-3.5-controlnets" \
        --filename "sd3.5_large_controlnet_canny.safetensors" \
        --output_path "/raid/yiyi/sd35-controlnet-canny-diffusers" \
        --dtype "fp32"  # optional, defaults to fp32
    ```

Note:
    The script supports the following ControlNet types from SD3.5:
    - Canny edge detection
    - Depth estimation
    - Blur detection

    The checkpoint files can be downloaded from:
    https://huggingface.co/stabilityai/stable-diffusion-3.5-controlnets
"""

import argparse

import safetensors.torch
import torch
from huggingface_hub import hf_hub_download

from diffusers import SD3ControlNetModel


parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", type=str, default=None, help="Path to local checkpoint file")
parser.add_argument(
    "--original_state_dict_repo_id", type=str, default=None, help="HuggingFace repo ID containing the checkpoint"
)
parser.add_argument("--filename", type=str, default=None, help="Filename of the checkpoint in the HF repo")
parser.add_argument("--output_path", type=str, required=True, help="Path to save the converted model")
parser.add_argument(
    "--dtype", type=str, default="fp32", help="Data type for the converted model (fp16, bf16, or fp32)"
)

args = parser.parse_args()


def load_original_checkpoint(args):
    if args.original_state_dict_repo_id is not None:
        if args.filename is None:
            raise ValueError("When using `original_state_dict_repo_id`, `filename` must also be specified")
        print(f"Downloading checkpoint from {args.original_state_dict_repo_id}/{args.filename}")
        ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
    elif args.checkpoint_path is not None:
        print(f"Loading checkpoint from local path: {args.checkpoint_path}")
        ckpt_path = args.checkpoint_path
    else:
        raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")

    original_state_dict = safetensors.torch.load_file(ckpt_path)
    return original_state_dict


def convert_sd3_controlnet_checkpoint_to_diffusers(original_state_dict):
    converted_state_dict = {}

    # Direct mappings for controlnet blocks
    for i in range(19):  # 19 controlnet blocks
        converted_state_dict[f"controlnet_blocks.{i}.weight"] = original_state_dict[f"controlnet_blocks.{i}.weight"]
        converted_state_dict[f"controlnet_blocks.{i}.bias"] = original_state_dict[f"controlnet_blocks.{i}.bias"]

    # Positional embeddings
    converted_state_dict["pos_embed_input.proj.weight"] = original_state_dict["pos_embed_input.proj.weight"]
    converted_state_dict["pos_embed_input.proj.bias"] = original_state_dict["pos_embed_input.proj.bias"]

    # Time and text embeddings
    time_text_mappings = {
        "time_text_embed.timestep_embedder.linear_1.weight": "time_text_embed.timestep_embedder.linear_1.weight",
        "time_text_embed.timestep_embedder.linear_1.bias": "time_text_embed.timestep_embedder.linear_1.bias",
        "time_text_embed.timestep_embedder.linear_2.weight": "time_text_embed.timestep_embedder.linear_2.weight",
        "time_text_embed.timestep_embedder.linear_2.bias": "time_text_embed.timestep_embedder.linear_2.bias",
        "time_text_embed.text_embedder.linear_1.weight": "time_text_embed.text_embedder.linear_1.weight",
        "time_text_embed.text_embedder.linear_1.bias": "time_text_embed.text_embedder.linear_1.bias",
        "time_text_embed.text_embedder.linear_2.weight": "time_text_embed.text_embedder.linear_2.weight",
        "time_text_embed.text_embedder.linear_2.bias": "time_text_embed.text_embedder.linear_2.bias",
    }

    for new_key, old_key in time_text_mappings.items():
        if old_key in original_state_dict:
            converted_state_dict[new_key] = original_state_dict[old_key]

    # Transformer blocks
    for i in range(19):
        # Split QKV into separate Q, K, V
        qkv_weight = original_state_dict[f"transformer_blocks.{i}.attn.qkv.weight"]
        qkv_bias = original_state_dict[f"transformer_blocks.{i}.attn.qkv.bias"]
        q, k, v = torch.chunk(qkv_weight, 3, dim=0)
        q_bias, k_bias, v_bias = torch.chunk(qkv_bias, 3, dim=0)

        block_mappings = {
            f"transformer_blocks.{i}.attn.to_q.weight": q,
            f"transformer_blocks.{i}.attn.to_q.bias": q_bias,
            f"transformer_blocks.{i}.attn.to_k.weight": k,
            f"transformer_blocks.{i}.attn.to_k.bias": k_bias,
            f"transformer_blocks.{i}.attn.to_v.weight": v,
            f"transformer_blocks.{i}.attn.to_v.bias": v_bias,
            # Output projections
            f"transformer_blocks.{i}.attn.to_out.0.weight": original_state_dict[
                f"transformer_blocks.{i}.attn.proj.weight"
            ],
            f"transformer_blocks.{i}.attn.to_out.0.bias": original_state_dict[
                f"transformer_blocks.{i}.attn.proj.bias"
            ],
            # Feed forward
            f"transformer_blocks.{i}.ff.net.0.proj.weight": original_state_dict[
                f"transformer_blocks.{i}.mlp.fc1.weight"
            ],
            f"transformer_blocks.{i}.ff.net.0.proj.bias": original_state_dict[f"transformer_blocks.{i}.mlp.fc1.bias"],
            f"transformer_blocks.{i}.ff.net.2.weight": original_state_dict[f"transformer_blocks.{i}.mlp.fc2.weight"],
            f"transformer_blocks.{i}.ff.net.2.bias": original_state_dict[f"transformer_blocks.{i}.mlp.fc2.bias"],
            # Norms
            f"transformer_blocks.{i}.norm1.linear.weight": original_state_dict[
                f"transformer_blocks.{i}.adaLN_modulation.1.weight"
            ],
            f"transformer_blocks.{i}.norm1.linear.bias": original_state_dict[
                f"transformer_blocks.{i}.adaLN_modulation.1.bias"
            ],
        }
        converted_state_dict.update(block_mappings)

    return converted_state_dict


def main(args):
    original_ckpt = load_original_checkpoint(args)
    original_dtype = next(iter(original_ckpt.values())).dtype

    # Initialize dtype with fp32 as default
    if args.dtype == "fp16":
        dtype = torch.float16
    elif args.dtype == "bf16":
        dtype = torch.bfloat16
    elif args.dtype == "fp32":
        dtype = torch.float32
    else:
        raise ValueError(f"Unsupported dtype: {args.dtype}. Must be one of: fp16, bf16, fp32")

    if dtype != original_dtype:
        print(
            f"Converting checkpoint from {original_dtype} to {dtype}. This can lead to unexpected results, proceed with caution."
        )

    converted_controlnet_state_dict = convert_sd3_controlnet_checkpoint_to_diffusers(original_ckpt)

    controlnet = SD3ControlNetModel(
        patch_size=2,
        in_channels=16,
        num_layers=19,
        attention_head_dim=64,
        num_attention_heads=38,
        joint_attention_dim=None,
        caption_projection_dim=2048,
        pooled_projection_dim=2048,
        out_channels=16,
        pos_embed_max_size=None,
        pos_embed_type=None,
        use_pos_embed=False,
        force_zeros_for_pooled_projection=False,
    )

    controlnet.load_state_dict(converted_controlnet_state_dict, strict=True)

    print(f"Saving SD3 ControlNet in Diffusers format in {args.output_path}.")
    controlnet.to(dtype).save_pretrained(args.output_path)


if __name__ == "__main__":
    main(args)
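After the script finishes, the converted weights can be loaded back like any other Diffusers checkpoint. The following is a minimal sketch, not part of this commit: it reuses the local output path from the first CLI example in the docstring and the existing `StableDiffusion3ControlNetPipeline`; the base-model repo id is an assumption.

```python
import torch
from diffusers import SD3ControlNetModel, StableDiffusion3ControlNetPipeline

# Load the ControlNet written by the conversion script (output path from the docstring example).
controlnet = SD3ControlNetModel.from_pretrained("output/sd35-controlnet-canny", torch_dtype=torch.float16)

# Attach it to an SD3.5 base pipeline (repo id assumed for illustration).
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large", controlnet=controlnet, torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()
```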

src/diffusers/models/controlnets/controlnet_sd3.py

Lines changed: 73 additions & 30 deletions
@@ -27,6 +27,7 @@
 from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
+from ..transformers.transformer_sd3 import SD3SingleTransformerBlock
 from .controlnet import BaseOutput, zero_module
 
@@ -58,40 +59,60 @@ def __init__(
         extra_conditioning_channels: int = 0,
         dual_attention_layers: Tuple[int, ...] = (),
         qk_norm: Optional[str] = None,
+        pos_embed_type: Optional[str] = "sincos",
+        use_pos_embed: bool = True,
+        force_zeros_for_pooled_projection: bool = True,
     ):
         super().__init__()
         default_out_channels = in_channels
         self.out_channels = out_channels if out_channels is not None else default_out_channels
         self.inner_dim = num_attention_heads * attention_head_dim
 
-        self.pos_embed = PatchEmbed(
-            height=sample_size,
-            width=sample_size,
-            patch_size=patch_size,
-            in_channels=in_channels,
-            embed_dim=self.inner_dim,
-            pos_embed_max_size=pos_embed_max_size,
-        )
+        if use_pos_embed:
+            self.pos_embed = PatchEmbed(
+                height=sample_size,
+                width=sample_size,
+                patch_size=patch_size,
+                in_channels=in_channels,
+                embed_dim=self.inner_dim,
+                pos_embed_max_size=pos_embed_max_size,
+                pos_embed_type=pos_embed_type,
+            )
+        else:
+            self.pos_embed = None
         self.time_text_embed = CombinedTimestepTextProjEmbeddings(
             embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
         )
-        self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
-
-        # `attention_head_dim` is doubled to account for the mixing.
-        # It needs to crafted when we get the actual checkpoints.
-        self.transformer_blocks = nn.ModuleList(
-            [
-                JointTransformerBlock(
-                    dim=self.inner_dim,
-                    num_attention_heads=num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
-                    context_pre_only=False,
-                    qk_norm=qk_norm,
-                    use_dual_attention=True if i in dual_attention_layers else False,
-                )
-                for i in range(num_layers)
-            ]
-        )
+        if joint_attention_dim is not None:
+            self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
+
+            # `attention_head_dim` is doubled to account for the mixing.
+            # It needs to crafted when we get the actual checkpoints.
+            self.transformer_blocks = nn.ModuleList(
+                [
+                    JointTransformerBlock(
+                        dim=self.inner_dim,
+                        num_attention_heads=num_attention_heads,
+                        attention_head_dim=self.config.attention_head_dim,
+                        context_pre_only=False,
+                        qk_norm=qk_norm,
+                        use_dual_attention=True if i in dual_attention_layers else False,
+                    )
+                    for i in range(num_layers)
+                ]
+            )
+        else:
+            self.context_embedder = None
+            self.transformer_blocks = nn.ModuleList(
+                [
+                    SD3SingleTransformerBlock(
+                        dim=self.inner_dim,
+                        num_attention_heads=num_attention_heads,
+                        attention_head_dim=self.config.attention_head_dim,
+                    )
+                    for _ in range(num_layers)
+                ]
+            )
 
         # controlnet_blocks
         self.controlnet_blocks = nn.ModuleList([])
@@ -318,9 +339,27 @@ def forward(
                 "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
             )
 
-        hidden_states = self.pos_embed(hidden_states)  # takes care of adding positional embeddings too.
+        if self.pos_embed is not None and hidden_states.ndim != 4:
+            raise ValueError("hidden_states must be 4D when pos_embed is used")
+
+        # SD3.5 8b controlnet does not have a `pos_embed`,
+        # it use the `pos_embed` from the transformer to process input before passing to controlnet
+        elif self.pos_embed is None and hidden_states.ndim != 3:
+            raise ValueError("hidden_states must be 3D when pos_embed is not used")
+
+        if self.context_embedder is not None and encoder_hidden_states is None:
+            raise ValueError("encoder_hidden_states must be provided when context_embedder is used")
+        # SD3.5 8b controlnet does not have a `context_embedder`, it does not use `encoder_hidden_states`
+        elif self.context_embedder is None and encoder_hidden_states is not None:
+            raise ValueError("encoder_hidden_states should not be provided when context_embedder is not used")
+
+        if self.pos_embed is not None:
+            hidden_states = self.pos_embed(hidden_states)  # takes care of adding positional embeddings too.
+
         temb = self.time_text_embed(timestep, pooled_projections)
-        encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+        if self.context_embedder is not None:
+            encoder_hidden_states = self.context_embedder(encoder_hidden_states)
 
         # add
         hidden_states = hidden_states + self.pos_embed_input(controlnet_cond)
@@ -349,9 +388,13 @@ def custom_forward(*inputs):
                 )
 
             else:
-                encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
-                )
+                if self.context_embedder is not None:
+                    encoder_hidden_states, hidden_states = block(
+                        hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
+                    )
+                else:
+                    # SD3.5 8b controlnet use single transformer block, which does not use `encoder_hidden_states`
+                    hidden_states = block(hidden_states, temb)
 
             block_res_samples = block_res_samples + (hidden_states,)
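To make the new code path concrete: when `joint_attention_dim` is `None`, the model is built from `SD3SingleTransformerBlock`s and has neither `pos_embed` nor `context_embedder`, which is exactly the configuration the conversion script above instantiates. A minimal sketch of that construction follows (config values copied from the script; illustrative, not a canonical recipe).

```python
from diffusers import SD3ControlNetModel

# SD3.5-large-style ControlNet: no context_embedder (joint_attention_dim=None),
# no internal pos_embed (use_pos_embed=False), single transformer blocks inside.
controlnet = SD3ControlNetModel(
    patch_size=2,
    in_channels=16,
    num_layers=19,
    attention_head_dim=64,
    num_attention_heads=38,
    joint_attention_dim=None,
    caption_projection_dim=2048,
    pooled_projection_dim=2048,
    out_channels=16,
    pos_embed_max_size=None,
    pos_embed_type=None,
    use_pos_embed=False,
    force_zeros_for_pooled_projection=False,
)

# Per the forward() checks added above, this variant expects 3D (already patch-embedded)
# hidden_states and no encoder_hidden_states; the original 4D + encoder_hidden_states
# path is still used when joint_attention_dim is set and JointTransformerBlocks are built.
```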
