Sd35 controlnet #10020
Changes from 9 commits
@@ -27,6 +27,7 @@
 from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
+from ..transformers.transformer_sd3 import SD3SingleTransformerBlock
 from .controlnet import BaseOutput, zero_module
@@ -58,40 +59,60 @@ def __init__(
         extra_conditioning_channels: int = 0,
         dual_attention_layers: Tuple[int, ...] = (),
         qk_norm: Optional[str] = None,
+        pos_embed_type: Optional[str] = "sincos",
+        use_pos_embed: bool = True,
+        force_zeros_for_pooled_projection: bool = True,
     ):
         super().__init__()
         default_out_channels = in_channels
         self.out_channels = out_channels if out_channels is not None else default_out_channels
         self.inner_dim = num_attention_heads * attention_head_dim

-        self.pos_embed = PatchEmbed(
-            height=sample_size,
-            width=sample_size,
-            patch_size=patch_size,
-            in_channels=in_channels,
-            embed_dim=self.inner_dim,
-            pos_embed_max_size=pos_embed_max_size,
-        )
+        if use_pos_embed:
+            self.pos_embed = PatchEmbed(
+                height=sample_size,
+                width=sample_size,
+                patch_size=patch_size,
+                in_channels=in_channels,
+                embed_dim=self.inner_dim,
+                pos_embed_max_size=pos_embed_max_size,
+                pos_embed_type=pos_embed_type,
+            )
+        else:
+            self.pos_embed = None
         self.time_text_embed = CombinedTimestepTextProjEmbeddings(
             embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
         )
-        self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
-
-        # `attention_head_dim` is doubled to account for the mixing.
-        # It needs to crafted when we get the actual checkpoints.
-        self.transformer_blocks = nn.ModuleList(
-            [
-                JointTransformerBlock(
-                    dim=self.inner_dim,
-                    num_attention_heads=num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
-                    context_pre_only=False,
-                    qk_norm=qk_norm,
-                    use_dual_attention=True if i in dual_attention_layers else False,
-                )
-                for i in range(num_layers)
-            ]
-        )
+        if joint_attention_dim is not None:
+            self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
+
+            # `attention_head_dim` is doubled to account for the mixing.
+            # It needs to crafted when we get the actual checkpoints.
+            self.transformer_blocks = nn.ModuleList(
+                [
+                    JointTransformerBlock(
+                        dim=self.inner_dim,
+                        num_attention_heads=num_attention_heads,
+                        attention_head_dim=self.config.attention_head_dim,
+                        context_pre_only=False,
+                        qk_norm=qk_norm,
+                        use_dual_attention=True if i in dual_attention_layers else False,
+                    )
+                    for i in range(num_layers)
+                ]
+            )
+        else:
+            self.context_embedder = None
+            self.transformer_blocks = nn.ModuleList(
+                [
+                    SD3SingleTransformerBlock(
+                        dim=self.inner_dim,
+                        num_attention_heads=num_attention_heads,
+                        attention_head_dim=self.config.attention_head_dim,
+                    )
+                    for _ in range(num_layers)
+                ]
+            )

         # controlnet_blocks
         self.controlnet_blocks = nn.ModuleList([])
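For orientation, here is a minimal construction sketch contrasting the two configurations this `__init__` now supports. The hyperparameters are invented for illustration (tiny, not from any released checkpoint); only the two switches touched by this PR matter here.

```python
from diffusers import SD3ControlNetModel

# InstantX-style SD3 controlnet: keeps its own PatchEmbed and the joint (MMDiT) blocks.
instantx_style = SD3ControlNetModel(
    num_layers=2,
    num_attention_heads=4,
    attention_head_dim=8,
    joint_attention_dim=32,
    caption_projection_dim=32,
    pooled_projection_dim=64,
    # use_pos_embed defaults to True, so pos_embed and context_embedder are created.
)

# SD3.5-large (8b) style controlnet: no pos_embed, no context_embedder,
# and the transformer blocks fall back to SD3SingleTransformerBlock.
sd35_8b_style = SD3ControlNetModel(
    num_layers=2,
    num_attention_heads=4,
    attention_head_dim=8,
    joint_attention_dim=None,
    pooled_projection_dim=64,
    use_pos_embed=False,
)

print(type(instantx_style.transformer_blocks[0]).__name__)  # JointTransformerBlock
print(type(sd35_8b_style.transformer_blocks[0]).__name__)   # SD3SingleTransformerBlock
print(instantx_style.pos_embed is None, sd35_8b_style.pos_embed is None)  # False True
```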
@@ -318,9 +339,27 @@ def forward(
                 "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
             )

-        hidden_states = self.pos_embed(hidden_states)  # takes care of adding positional embeddings too.
+        if self.pos_embed is not None and hidden_states.ndim != 4:
+            raise ValueError("hidden_states must be 4D when pos_embed is used")
+
+        # SD3.5 8b controlnet does not have a `pos_embed`,
+        # it use the `pos_embed` from the transformer to process input before passing to controlnet
+        elif self.pos_embed is None and hidden_states.ndim != 3:
+            raise ValueError("hidden_states must be 3D when pos_embed is not used")
+
+        if self.context_embedder is not None and encoder_hidden_states is None:
+            raise ValueError("encoder_hidden_states must be provided when context_embedder is used")
+        # SD3.5 8b controlnet does not have a `context_embedder`, it does not use `encoder_hidden_states`
+        elif self.context_embedder is None and encoder_hidden_states is not None:
+            raise ValueError("encoder_hidden_states should not be provided when context_embedder is not used")
Review comment (on lines +345 to +354): Very useful!
+
+        if self.pos_embed is not None:
+            hidden_states = self.pos_embed(hidden_states)  # takes care of adding positional embeddings too.
+
         temb = self.time_text_embed(timestep, pooled_projections)
-        encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+        if self.context_embedder is not None:
+            encoder_hidden_states = self.context_embedder(encoder_hidden_states)

         # add
         hidden_states = hidden_states + self.pos_embed_input(controlnet_cond)
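A quick shape sketch of what the two validation branches above expect; the dimensions below are made up for illustration, not taken from any checkpoint.

```python
import torch

batch, channels, height, width, patch_size = 2, 16, 64, 64, 2
inner_dim = 1536  # num_attention_heads * attention_head_dim in this hypothetical config

# InstantX-style controlnet (pos_embed is not None): pass the raw 4D latent;
# the controlnet's own PatchEmbed patchifies it into a (batch, seq_len, inner_dim) sequence.
latents_4d = torch.randn(batch, channels, height, width)

# SD3.5 8b controlnet (pos_embed is None): the caller patchifies first with the base
# transformer's pos_embed, so the controlnet already receives a 3D token sequence.
seq_len = (height // patch_size) * (width // patch_size)
latents_3d = torch.randn(batch, seq_len, inner_dim)

assert latents_4d.ndim == 4 and latents_3d.ndim == 3
```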
@@ -349,9 +388,13 @@ def custom_forward(*inputs):
                 )

             else:
-                encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
-                )
+                if self.context_embedder is not None:
+                    encoder_hidden_states, hidden_states = block(
+                        hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
+                    )
+                else:
+                    # SD3.5 8b controlnet use single transformer block, which does not use `encoder_hidden_states`
+                    hidden_states = block(hidden_states, temb)

             block_res_samples = block_res_samples + (hidden_states,)
@@ -18,21 +18,94 @@
 import numpy as np
 import torch
 import torch.nn as nn
+import torch.nn.functional as F

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...models.attention import JointTransformerBlock
-from ...models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
+from ...models.attention import FeedForward, JointTransformerBlock
+from ...models.attention_processor import (
+    Attention,
+    AttentionProcessor,
+    FusedJointAttnProcessor2_0,
+    JointAttnProcessor2_0,
+)
 from ...models.modeling_utils import ModelMixin
-from ...models.normalization import AdaLayerNormContinuous
+from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero
 from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


+@maybe_allow_in_graph
+class SD3SingleTransformerBlock(nn.Module):
+    r"""
+    A Single Transformer block as part of the MMDiT architecture, used in Stable Diffusion 3 ControlNet.
Review comment: Perhaps we could make this more explicit by: […]
+
+    Reference: https://arxiv.org/abs/2403.03206
+    Parameters:
+        dim (`int`): The number of channels in the input and output.
+        num_attention_heads (`int`): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`): The number of channels in each head.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+    ):
+        super().__init__()
+
+        self.norm1 = AdaLayerNormZero(dim)
+
+        if hasattr(F, "scaled_dot_product_attention"):
+            processor = JointAttnProcessor2_0()
+        else:
+            raise ValueError(
+                "The current PyTorch version does not support the `scaled_dot_product_attention` function."
+            )
+
+        self.attn = Attention(
+            query_dim=dim,
+            dim_head=attention_head_dim,
+            heads=num_attention_heads,
+            out_dim=dim,
+            bias=True,
+            processor=processor,
+            eps=1e-6,
+        )
+
+        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+    def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor):
+        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+        # Attention.
+        attn_output = self.attn(
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=None,
+        )
+
+        # Process attention outputs for the `hidden_states`.
+        attn_output = gate_msa.unsqueeze(1) * attn_output
+        hidden_states = hidden_states + attn_output
+
+        norm_hidden_states = self.norm2(hidden_states)
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+        ff_output = self.ff(norm_hidden_states)
+        ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+        hidden_states = hidden_states + ff_output
+
+        return hidden_states
+
+
 class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
     """
     The Transformer model introduced in Stable Diffusion 3.
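To sanity-check the new `SD3SingleTransformerBlock` in isolation, a small usage sketch. The sizes are placeholders, and the import path simply follows the module layout visible in the diff above.

```python
import torch

from diffusers.models.transformers.transformer_sd3 import SD3SingleTransformerBlock

# Hypothetical size: 24 heads x 64 channels per head = inner_dim of 1536.
block = SD3SingleTransformerBlock(dim=1536, num_attention_heads=24, attention_head_dim=64)

hidden_states = torch.randn(2, 1024, 1536)  # (batch, seq_len, inner_dim), already patchified
temb = torch.randn(2, 1536)                 # combined timestep / pooled-text embedding

out = block(hidden_states, temb)
print(out.shape)  # torch.Size([2, 1024, 1536])
```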
@@ -858,6 +858,12 @@
         height = height or self.default_sample_size * self.vae_scale_factor
         width = width or self.default_sample_size * self.vae_scale_factor

+        controlnet_config = (
+            self.controlnet.config
+            if isinstance(self.controlnet, SD3ControlNetModel)
+            else self.controlnet.nets[0].config
Review comment: I guess this is okay for now, but could there be a case where we may have incompatible configs when `SD3MultiControlNetModel` is used?
+        )
+
         # align format for control guidance
         if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
             control_guidance_start = len(control_guidance_end) * [control_guidance_start]
@@ -932,6 +938,11 @@ def __call__(
             pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

         # 3. Prepare control image
+        if controlnet_config.force_zeros_for_pooled_projection:
+            # instantx sd3 controlnet does not apply shift factor
+            vae_shift_factor = 0
+        else:
+            vae_shift_factor = self.vae.config.shift_factor
         if isinstance(self.controlnet, SD3ControlNetModel):
             control_image = self.prepare_image(
                 image=control_image,
@@ -947,8 +958,7 @@ def __call__(
             height, width = control_image.shape[-2:]

             control_image = self.vae.encode(control_image).latent_dist.sample()
-            control_image = control_image * self.vae.config.scaling_factor
-
+            control_image = (control_image - vae_shift_factor) * self.vae.config.scaling_factor
         elif isinstance(self.controlnet, SD3MultiControlNetModel):
             control_images = []

@@ -966,19 +976,14 @@ def __call__(
                 )

                 control_image_ = self.vae.encode(control_image_).latent_dist.sample()
-                control_image_ = control_image_ * self.vae.config.scaling_factor
+                control_image_ = (control_image_ - vae_shift_factor) * self.vae.config.scaling_factor

                 control_images.append(control_image_)

             control_image = control_images
         else:
             assert False

-        if controlnet_pooled_projections is None:
-            controlnet_pooled_projections = torch.zeros_like(pooled_prompt_embeds)
-        else:
-            controlnet_pooled_projections = controlnet_pooled_projections or pooled_prompt_embeds
-
         # 4. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
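The only functional change to control-image preparation is how the VAE latents are normalized. A compact sketch of that normalization follows; the VAE constants are assumptions (read them from `vae.config` in real code), not values pinned down by this diff.

```python
import torch

# Assumed SD3-family VAE constants; treat as placeholders.
SCALING_FACTOR = 1.5305
SHIFT_FACTOR = 0.0609

def normalize_control_latents(latents: torch.Tensor, force_zeros_for_pooled_projection: bool) -> torch.Tensor:
    # InstantX-style SD3 controlnets were trained without the shift term, so the shift
    # collapses to 0 for them; the SD3.5 controlnets use the full shift + scale.
    shift = 0.0 if force_zeros_for_pooled_projection else SHIFT_FACTOR
    return (latents - shift) * SCALING_FACTOR
```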
@@ -1006,6 +1011,18 @@ def __call__(
             ]
             controlnet_keep.append(keeps[0] if isinstance(self.controlnet, SD3ControlNetModel) else keeps)

+        if controlnet_config.force_zeros_for_pooled_projection:
+            # instantx sd3 controlnet used zero pooled projection
+            controlnet_pooled_projections = torch.zeros_like(pooled_prompt_embeds)
+        else:
+            controlnet_pooled_projections = controlnet_pooled_projections or pooled_prompt_embeds
+
+        if controlnet_config.joint_attention_dim is not None:
+            controlnet_encoder_hidden_states = prompt_embeds
+        else:
+            # SD35 official 8b controlnet does not use encoder_hidden_states
+            controlnet_encoder_hidden_states = None
+
         # 7. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
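The two config switches added in the hunk above decide what the controlnet is fed during denoising. The helper below just restates that selection logic in one place; the function name is mine, not part of the PR.

```python
import torch

def select_controlnet_inputs(controlnet_config, prompt_embeds, pooled_prompt_embeds, controlnet_pooled_projections=None):
    # InstantX-style controlnets were trained with an all-zero pooled projection.
    if controlnet_config.force_zeros_for_pooled_projection:
        pooled = torch.zeros_like(pooled_prompt_embeds)
    else:
        pooled = controlnet_pooled_projections if controlnet_pooled_projections is not None else pooled_prompt_embeds

    # The SD3.5 8b controlnet has no context_embedder (joint_attention_dim=None),
    # so it never receives the text token embeddings.
    encoder_hidden_states = prompt_embeds if controlnet_config.joint_attention_dim is not None else None
    return encoder_hidden_states, pooled
```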
@@ -1025,11 +1042,17 @@ def __call__(
                         controlnet_cond_scale = controlnet_cond_scale[0]
                     cond_scale = controlnet_cond_scale * controlnet_keep[i]

+                if controlnet_config.use_pos_embed is False:
+                    # sd35 (offical) 8b controlnet
+                    controlnet_model_input = self.transformer.pos_embed(latent_model_input)
+                else:
+                    controlnet_model_input = latent_model_input
+
                 # controlnet(s) inference
                 control_block_samples = self.controlnet(
-                    hidden_states=latent_model_input,
+                    hidden_states=controlnet_model_input,
                     timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
+                    encoder_hidden_states=controlnet_encoder_hidden_states,
                     pooled_projections=controlnet_pooled_projections,
                     joint_attention_kwargs=self.joint_attention_kwargs,
                     controlnet_cond=control_image,
Review comment: I think this is a good enough condition for now. Based on `joint_attention_dim` we initialize both the `context_embedder` and the `transformer_blocks` (which have the `JointTransformerBlock` type). I am okay with it.
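For reference, an end-to-end usage sketch of the code path added in this PR. The checkpoint ids and the conditioning-image URL are placeholders (assumptions), not something this diff pins down.

```python
import torch

from diffusers import SD3ControlNetModel, StableDiffusion3ControlNetPipeline
from diffusers.utils import load_image

# Placeholder repo ids: substitute the SD3.5 controlnet and base weights you actually use.
controlnet = SD3ControlNetModel.from_pretrained("your-org/sd3.5-large-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

control_image = load_image("https://example.com/canny_edges.png")  # placeholder conditioning image

image = pipe(
    prompt="a photo of a cat sitting on a windowsill",
    control_image=control_image,
    controlnet_conditioning_scale=1.0,
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]
image.save("sd35_controlnet_out.png")
```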