Sd35 controlnet #10020
Merged
Changes from all commits (11 commits):
- e1e14f9 add model (yiyixuxu)
- 6f6f0d7 add pipeline (yiyixuxu)
- cbe5a42 add shift (yiyixuxu)
- c87e4a3 Merge branch 'main' into sd35-control (yiyixuxu)
- 77dadd3 Merge branch 'sd35-control' of github.com:huggingface/diffusers into … (yiyixuxu)
- c5150de fix (yiyixuxu)
- f9103b1 fix so backward compatible (yiyixuxu)
- f93efef Update src/diffusers/models/controlnets/controlnet_sd3.py (yiyixuxu)
- 2502a0c Merge branch 'main' into sd35-control (sayakpaul)
- 54fb3bc add conversion script (yiyixuxu)
- 6a6456b Merge branch 'sd35-control' of github.com:huggingface/diffusers into … (yiyixuxu)
`scripts/convert_sd3_controlnet_to_diffusers.py` (new file, 185 lines):

````python
"""
A script to convert Stable Diffusion 3.5 ControlNet checkpoints to the Diffusers format.

Example:
    Convert a SD3.5 ControlNet checkpoint to Diffusers format using local file:
    ```bash
    python scripts/convert_sd3_controlnet_to_diffusers.py \
        --checkpoint_path "path/to/local/sd3.5_large_controlnet_canny.safetensors" \
        --output_path "output/sd35-controlnet-canny" \
        --dtype "fp16" # optional, defaults to fp32
    ```

    Or download and convert from HuggingFace repository:
    ```bash
    python scripts/convert_sd3_controlnet_to_diffusers.py \
        --original_state_dict_repo_id "stabilityai/stable-diffusion-3.5-controlnets" \
        --filename "sd3.5_large_controlnet_canny.safetensors" \
        --output_path "/raid/yiyi/sd35-controlnet-canny-diffusers" \
        --dtype "fp32" # optional, defaults to fp32
    ```

Note:
    The script supports the following ControlNet types from SD3.5:
    - Canny edge detection
    - Depth estimation
    - Blur detection

    The checkpoint files can be downloaded from:
    https://huggingface.co/stabilityai/stable-diffusion-3.5-controlnets
"""

import argparse

import safetensors.torch
import torch
from huggingface_hub import hf_hub_download

from diffusers import SD3ControlNetModel


parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", type=str, default=None, help="Path to local checkpoint file")
parser.add_argument(
    "--original_state_dict_repo_id", type=str, default=None, help="HuggingFace repo ID containing the checkpoint"
)
parser.add_argument("--filename", type=str, default=None, help="Filename of the checkpoint in the HF repo")
parser.add_argument("--output_path", type=str, required=True, help="Path to save the converted model")
parser.add_argument(
    "--dtype", type=str, default="fp32", help="Data type for the converted model (fp16, bf16, or fp32)"
)

args = parser.parse_args()


def load_original_checkpoint(args):
    if args.original_state_dict_repo_id is not None:
        if args.filename is None:
            raise ValueError("When using `original_state_dict_repo_id`, `filename` must also be specified")
        print(f"Downloading checkpoint from {args.original_state_dict_repo_id}/{args.filename}")
        ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
    elif args.checkpoint_path is not None:
        print(f"Loading checkpoint from local path: {args.checkpoint_path}")
        ckpt_path = args.checkpoint_path
    else:
        raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")

    original_state_dict = safetensors.torch.load_file(ckpt_path)
    return original_state_dict


def convert_sd3_controlnet_checkpoint_to_diffusers(original_state_dict):
    converted_state_dict = {}

    # Direct mappings for controlnet blocks
    for i in range(19):  # 19 controlnet blocks
        converted_state_dict[f"controlnet_blocks.{i}.weight"] = original_state_dict[f"controlnet_blocks.{i}.weight"]
        converted_state_dict[f"controlnet_blocks.{i}.bias"] = original_state_dict[f"controlnet_blocks.{i}.bias"]

    # Positional embeddings
    converted_state_dict["pos_embed_input.proj.weight"] = original_state_dict["pos_embed_input.proj.weight"]
    converted_state_dict["pos_embed_input.proj.bias"] = original_state_dict["pos_embed_input.proj.bias"]

    # Time and text embeddings
    time_text_mappings = {
        "time_text_embed.timestep_embedder.linear_1.weight": "time_text_embed.timestep_embedder.linear_1.weight",
        "time_text_embed.timestep_embedder.linear_1.bias": "time_text_embed.timestep_embedder.linear_1.bias",
        "time_text_embed.timestep_embedder.linear_2.weight": "time_text_embed.timestep_embedder.linear_2.weight",
        "time_text_embed.timestep_embedder.linear_2.bias": "time_text_embed.timestep_embedder.linear_2.bias",
        "time_text_embed.text_embedder.linear_1.weight": "time_text_embed.text_embedder.linear_1.weight",
        "time_text_embed.text_embedder.linear_1.bias": "time_text_embed.text_embedder.linear_1.bias",
        "time_text_embed.text_embedder.linear_2.weight": "time_text_embed.text_embedder.linear_2.weight",
        "time_text_embed.text_embedder.linear_2.bias": "time_text_embed.text_embedder.linear_2.bias",
    }

    for new_key, old_key in time_text_mappings.items():
        if old_key in original_state_dict:
            converted_state_dict[new_key] = original_state_dict[old_key]

    # Transformer blocks
    for i in range(19):
        # Split QKV into separate Q, K, V
        qkv_weight = original_state_dict[f"transformer_blocks.{i}.attn.qkv.weight"]
        qkv_bias = original_state_dict[f"transformer_blocks.{i}.attn.qkv.bias"]
        q, k, v = torch.chunk(qkv_weight, 3, dim=0)
        q_bias, k_bias, v_bias = torch.chunk(qkv_bias, 3, dim=0)

        block_mappings = {
            f"transformer_blocks.{i}.attn.to_q.weight": q,
            f"transformer_blocks.{i}.attn.to_q.bias": q_bias,
            f"transformer_blocks.{i}.attn.to_k.weight": k,
            f"transformer_blocks.{i}.attn.to_k.bias": k_bias,
            f"transformer_blocks.{i}.attn.to_v.weight": v,
            f"transformer_blocks.{i}.attn.to_v.bias": v_bias,
            # Output projections
            f"transformer_blocks.{i}.attn.to_out.0.weight": original_state_dict[
                f"transformer_blocks.{i}.attn.proj.weight"
            ],
            f"transformer_blocks.{i}.attn.to_out.0.bias": original_state_dict[
                f"transformer_blocks.{i}.attn.proj.bias"
            ],
            # Feed forward
            f"transformer_blocks.{i}.ff.net.0.proj.weight": original_state_dict[
                f"transformer_blocks.{i}.mlp.fc1.weight"
            ],
            f"transformer_blocks.{i}.ff.net.0.proj.bias": original_state_dict[f"transformer_blocks.{i}.mlp.fc1.bias"],
            f"transformer_blocks.{i}.ff.net.2.weight": original_state_dict[f"transformer_blocks.{i}.mlp.fc2.weight"],
            f"transformer_blocks.{i}.ff.net.2.bias": original_state_dict[f"transformer_blocks.{i}.mlp.fc2.bias"],
            # Norms
            f"transformer_blocks.{i}.norm1.linear.weight": original_state_dict[
                f"transformer_blocks.{i}.adaLN_modulation.1.weight"
            ],
            f"transformer_blocks.{i}.norm1.linear.bias": original_state_dict[
                f"transformer_blocks.{i}.adaLN_modulation.1.bias"
            ],
        }
        converted_state_dict.update(block_mappings)

    return converted_state_dict


def main(args):
    original_ckpt = load_original_checkpoint(args)
    original_dtype = next(iter(original_ckpt.values())).dtype

    # Initialize dtype with fp32 as default
    if args.dtype == "fp16":
        dtype = torch.float16
    elif args.dtype == "bf16":
        dtype = torch.bfloat16
    elif args.dtype == "fp32":
        dtype = torch.float32
    else:
        raise ValueError(f"Unsupported dtype: {args.dtype}. Must be one of: fp16, bf16, fp32")

    if dtype != original_dtype:
        print(
            f"Converting checkpoint from {original_dtype} to {dtype}. This can lead to unexpected results, proceed with caution."
        )

    converted_controlnet_state_dict = convert_sd3_controlnet_checkpoint_to_diffusers(original_ckpt)

    controlnet = SD3ControlNetModel(
        patch_size=2,
        in_channels=16,
        num_layers=19,
        attention_head_dim=64,
        num_attention_heads=38,
        joint_attention_dim=None,
        caption_projection_dim=2048,
        pooled_projection_dim=2048,
        out_channels=16,
        pos_embed_max_size=None,
        pos_embed_type=None,
        use_pos_embed=False,
        force_zeros_for_pooled_projection=False,
    )

    controlnet.load_state_dict(converted_controlnet_state_dict, strict=True)

    print(f"Saving SD3 ControlNet in Diffusers format in {args.output_path}.")
    controlnet.to(dtype).save_pretrained(args.output_path)


if __name__ == "__main__":
    main(args)
````
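Once converted, the saved directory can be loaded like any other Diffusers ControlNet. The following is a minimal usage sketch, not part of the PR: it assumes the converted Canny ControlNet was written to `output/sd35-controlnet-canny` (the example output path above), that you have access to the `stabilityai/stable-diffusion-3.5-large` base model, and that `canny_edges.png` is a hypothetical pre-computed Canny edge map.

```python
import torch
from diffusers import SD3ControlNetModel, StableDiffusion3ControlNetPipeline
from diffusers.utils import load_image

# Load the ControlNet produced by the conversion script above.
controlnet = SD3ControlNetModel.from_pretrained("output/sd35-controlnet-canny", torch_dtype=torch.float16)

# Plug it into the SD3 ControlNet pipeline with the SD3.5-large base model.
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large",
    controlnet=controlnet,
    torch_dtype=torch.float16,
).to("cuda")

control_image = load_image("canny_edges.png")  # hypothetical pre-computed Canny edge map

image = pipe(
    prompt="a photo of a futuristic city at night",
    control_image=control_image,
    controlnet_conditioning_scale=1.0,
    num_inference_steps=28,
).images[0]
image.save("sd35_controlnet_canny_result.png")
```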
`src/diffusers/models/controlnets/controlnet_sd3.py`:
```diff
@@ -27,6 +27,7 @@
 from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
+from ..transformers.transformer_sd3 import SD3SingleTransformerBlock
 from .controlnet import BaseOutput, zero_module
```
```diff
@@ -58,40 +59,60 @@ def __init__(
         extra_conditioning_channels: int = 0,
         dual_attention_layers: Tuple[int, ...] = (),
         qk_norm: Optional[str] = None,
+        pos_embed_type: Optional[str] = "sincos",
+        use_pos_embed: bool = True,
+        force_zeros_for_pooled_projection: bool = True,
     ):
         super().__init__()
         default_out_channels = in_channels
         self.out_channels = out_channels if out_channels is not None else default_out_channels
         self.inner_dim = num_attention_heads * attention_head_dim

-        self.pos_embed = PatchEmbed(
-            height=sample_size,
-            width=sample_size,
-            patch_size=patch_size,
-            in_channels=in_channels,
-            embed_dim=self.inner_dim,
-            pos_embed_max_size=pos_embed_max_size,
-        )
+        if use_pos_embed:
+            self.pos_embed = PatchEmbed(
+                height=sample_size,
+                width=sample_size,
+                patch_size=patch_size,
+                in_channels=in_channels,
+                embed_dim=self.inner_dim,
+                pos_embed_max_size=pos_embed_max_size,
+                pos_embed_type=pos_embed_type,
+            )
+        else:
+            self.pos_embed = None
         self.time_text_embed = CombinedTimestepTextProjEmbeddings(
             embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
         )
-        self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
-
-        # `attention_head_dim` is doubled to account for the mixing.
-        # It needs to crafted when we get the actual checkpoints.
-        self.transformer_blocks = nn.ModuleList(
-            [
-                JointTransformerBlock(
-                    dim=self.inner_dim,
-                    num_attention_heads=num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
-                    context_pre_only=False,
-                    qk_norm=qk_norm,
-                    use_dual_attention=True if i in dual_attention_layers else False,
-                )
-                for i in range(num_layers)
-            ]
-        )
+        if joint_attention_dim is not None:
+            self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
+
+            # `attention_head_dim` is doubled to account for the mixing.
+            # It needs to crafted when we get the actual checkpoints.
+            self.transformer_blocks = nn.ModuleList(
+                [
+                    JointTransformerBlock(
+                        dim=self.inner_dim,
+                        num_attention_heads=num_attention_heads,
+                        attention_head_dim=self.config.attention_head_dim,
+                        context_pre_only=False,
+                        qk_norm=qk_norm,
+                        use_dual_attention=True if i in dual_attention_layers else False,
+                    )
+                    for i in range(num_layers)
+                ]
+            )
+        else:
+            self.context_embedder = None
+            self.transformer_blocks = nn.ModuleList(
+                [
+                    SD3SingleTransformerBlock(
+                        dim=self.inner_dim,
+                        num_attention_heads=num_attention_heads,
+                        attention_head_dim=self.config.attention_head_dim,
+                    )
+                    for _ in range(num_layers)
+                ]
+            )

         # controlnet_blocks
         self.controlnet_blocks = nn.ModuleList([])
```
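To make the effect of the new constructor arguments concrete, here is a small, self-contained sketch (not from the PR; the tiny dimensions are purely illustrative, and it assumes a diffusers build that includes this change). Passing `joint_attention_dim=None` and `use_pos_embed=False`, as the conversion script does, yields the SD3.5-large (8b) variant with no `context_embedder`, no `pos_embed`, and `SD3SingleTransformerBlock`s; the defaults keep the previous joint-attention architecture, so existing SD3 ControlNets are unaffected.

```python
from diffusers import SD3ControlNetModel

# Illustrative, down-scaled config; only the flags relevant to this PR matter here.
single_stream = SD3ControlNetModel(
    sample_size=32,
    patch_size=2,
    in_channels=16,
    num_layers=2,
    attention_head_dim=8,
    num_attention_heads=4,
    joint_attention_dim=None,  # -> no context_embedder, SD3SingleTransformerBlock
    caption_projection_dim=32,
    pooled_projection_dim=32,
    out_channels=16,
    pos_embed_max_size=None,
    pos_embed_type=None,
    use_pos_embed=False,  # -> pos_embed is None, expects pre-embedded 3D input
    force_zeros_for_pooled_projection=False,
)
assert single_stream.pos_embed is None
assert single_stream.context_embedder is None
print(type(single_stream.transformer_blocks[0]).__name__)  # SD3SingleTransformerBlock

# Default flags reproduce the original joint-attention layout (backward compatible).
joint_stream = SD3ControlNetModel(
    sample_size=32,
    patch_size=2,
    in_channels=16,
    num_layers=2,
    attention_head_dim=8,
    num_attention_heads=4,
    joint_attention_dim=64,
    caption_projection_dim=32,
    pooled_projection_dim=32,
)
print(type(joint_stream.transformer_blocks[0]).__name__)  # JointTransformerBlock
```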
```diff
@@ -318,9 +339,27 @@ def forward(
                 "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
             )

-        hidden_states = self.pos_embed(hidden_states)  # takes care of adding positional embeddings too.
+        if self.pos_embed is not None and hidden_states.ndim != 4:
+            raise ValueError("hidden_states must be 4D when pos_embed is used")
+
+        # SD3.5 8b controlnet does not have a `pos_embed`,
+        # it use the `pos_embed` from the transformer to process input before passing to controlnet
+        elif self.pos_embed is None and hidden_states.ndim != 3:
+            raise ValueError("hidden_states must be 3D when pos_embed is not used")
+
+        if self.context_embedder is not None and encoder_hidden_states is None:
+            raise ValueError("encoder_hidden_states must be provided when context_embedder is used")
+        # SD3.5 8b controlnet does not have a `context_embedder`, it does not use `encoder_hidden_states`
+        elif self.context_embedder is None and encoder_hidden_states is not None:
+            raise ValueError("encoder_hidden_states should not be provided when context_embedder is not used")
+
+        if self.pos_embed is not None:
+            hidden_states = self.pos_embed(hidden_states)  # takes care of adding positional embeddings too.
+
         temb = self.time_text_embed(timestep, pooled_projections)
-        encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+        if self.context_embedder is not None:
+            encoder_hidden_states = self.context_embedder(encoder_hidden_states)

         # add
         hidden_states = hidden_states + self.pos_embed_input(controlnet_cond)
```

Review comment on lines +345 to +354: Very useful!
```diff
@@ -349,9 +388,13 @@ def custom_forward(*inputs):
                 )

             else:
-                encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
-                )
+                if self.context_embedder is not None:
+                    encoder_hidden_states, hidden_states = block(
+                        hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
+                    )
+                else:
+                    # SD3.5 8b controlnet use single transformer block, which does not use `encoder_hidden_states`
+                    hidden_states = block(hidden_states, temb)

             block_res_samples = block_res_samples + (hidden_states,)
```
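The validation added in `forward` encodes an input contract: a ControlNet built with a `pos_embed` consumes 4D latents, while the SD3.5-large (8b) variant without one expects the already patch-embedded 3D sequence, normally produced by the transformer's own `pos_embed`. Below is a small, self-contained sketch of that shape difference using the same `PatchEmbed` module; the sizes are illustrative and not taken from the PR.

```python
import torch
from diffusers.models.embeddings import PatchEmbed

inner_dim = 64  # illustrative; the converted SD3.5-large ControlNet uses 38 * 64 = 2432
patch_embed = PatchEmbed(
    height=32, width=32, patch_size=2, in_channels=16, embed_dim=inner_dim, pos_embed_type=None
)

latents = torch.randn(1, 16, 32, 32)  # 4D latents: valid input when pos_embed is present
sequence = patch_embed(latents)       # 3D sequence: what the pos_embed-less variant expects

print(latents.ndim, tuple(sequence.shape))  # 4 (1, 256, 64)
```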
Review comment: I think this is a good enough condition for now, because based on `joint_attention_dim` we initialize both the `context_embedder` and the `transformer_blocks` (which use the `JointTransformerBlock` type). I am okay with it.