From 8e659b97894849bfddf3857213bde84f1c7dcd06 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 29 Sep 2024 23:22:28 +0200 Subject: [PATCH 1/7] add changes for sd3.4 --- scripts/convert_sd3_to_diffusers.py | 96 ++++++++++++++++++- src/diffusers/models/attention.py | 41 +++++++- src/diffusers/models/attention_processor.py | 83 +++++++++------- src/diffusers/models/normalization.py | 34 +++++++ .../models/transformers/transformer_sd3.py | 6 +- 5 files changed, 216 insertions(+), 44 deletions(-) diff --git a/scripts/convert_sd3_to_diffusers.py b/scripts/convert_sd3_to_diffusers.py index 4f32745dae75..2abefde372e3 100644 --- a/scripts/convert_sd3_to_diffusers.py +++ b/scripts/convert_sd3_to_diffusers.py @@ -16,10 +16,19 @@ parser = argparse.ArgumentParser() parser.add_argument("--checkpoint_path", type=str) parser.add_argument("--output_path", type=str) -parser.add_argument("--dtype", type=str, default="fp16") +parser.add_argument("--dtype", type=str) args = parser.parse_args() -dtype = torch.float16 if args.dtype == "fp16" else torch.float32 + +# if dtype is not specified, use the dtype of the original checkpoint(recommended) +if args.dtype == "fp16": + dtype = torch.float16 +elif args.dtype == "bf16": + dtype = torch.bfloat16 +elif args.dtype == "fp32": + dtype = torch.float32 +else: + dtype = None def load_original_checkpoint(ckpt_path): @@ -40,7 +49,9 @@ def swap_scale_shift(weight, dim): return new_weight -def convert_sd3_transformer_checkpoint_to_diffusers(original_state_dict, num_layers, caption_projection_dim): +def convert_sd3_transformer_checkpoint_to_diffusers( + original_state_dict, num_layers, caption_projection_dim, add_attn2_layers, has_qk_norm +): converted_state_dict = {} # Positional and patch embeddings. @@ -110,6 +121,21 @@ def convert_sd3_transformer_checkpoint_to_diffusers(original_state_dict, num_lay converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v]) converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias]) + # qk norm + if has_qk_norm: + converted_state_dict[f"transformer_blocks.{i}.attn.norm_q.weight"] = original_state_dict.pop( + f"joint_blocks.{i}.x_block.attn.ln_q.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn.norm_k.weight"] = original_state_dict.pop( + f"joint_blocks.{i}.x_block.attn.ln_k.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_q.weight"] = original_state_dict.pop( + f"joint_blocks.{i}.context_block.attn.ln_q.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_k.weight"] = original_state_dict.pop( + f"joint_blocks.{i}.context_block.attn.ln_k.weight" + ) + # output projections. converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = original_state_dict.pop( f"joint_blocks.{i}.x_block.attn.proj.weight" @@ -125,6 +151,39 @@ def convert_sd3_transformer_checkpoint_to_diffusers(original_state_dict, num_lay f"joint_blocks.{i}.context_block.attn.proj.bias" ) + # attn2 + if i in add_attn2_layers: + # Q, K, V + sample_q2, sample_k2, sample_v2 = torch.chunk( + original_state_dict.pop(f"joint_blocks.{i}.x_block.attn2.qkv.weight"), 3, dim=0 + ) + sample_q2_bias, sample_k2_bias, sample_v2_bias = torch.chunk( + original_state_dict.pop(f"joint_blocks.{i}.x_block.attn2.qkv.bias"), 3, dim=0 + ) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = torch.cat([sample_q2]) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = torch.cat([sample_q2_bias]) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = torch.cat([sample_k2]) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = torch.cat([sample_k2_bias]) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = torch.cat([sample_v2]) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = torch.cat([sample_v2_bias]) + + # qk norm + if has_qk_norm: + converted_state_dict[f"transformer_blocks.{i}.attn2.norm_q.weight"] = original_state_dict.pop( + f"joint_blocks.{i}.x_block.attn2.ln_q.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn2.norm_k.weight"] = original_state_dict.pop( + f"joint_blocks.{i}.x_block.attn2.ln_k.weight" + ) + + # output projections. + converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = original_state_dict.pop( + f"joint_blocks.{i}.x_block.attn2.proj.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = original_state_dict.pop( + f"joint_blocks.{i}.x_block.attn2.proj.bias" + ) + # norms. converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = original_state_dict.pop( f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight" @@ -186,6 +245,9 @@ def convert_sd3_transformer_checkpoint_to_diffusers(original_state_dict, num_lay original_state_dict.pop("final_layer.adaLN_modulation.1.bias"), dim=caption_projection_dim ) + if len(original_state_dict) > 0: + raise ValueError(f"{len(original_state_dict)} keys are not converted: {original_state_dict.keys()}") + return converted_state_dict @@ -195,13 +257,35 @@ def is_vae_in_checkpoint(original_state_dict): ) +def get_add_attn2_layers(state_dict): + add_attn2_layers = [] + for key in state_dict.keys(): + if "attn2.to_q.weight" in key: + # Extract the layer number from the key + layer_num = int(key.split(".")[1]) + add_attn2_layers.append(layer_num) + return tuple(sorted(add_attn2_layers)) + + def main(args): original_ckpt = load_original_checkpoint(args.checkpoint_path) + original_dtype = next(iter(original_ckpt.values())).dtype + if dtype is None: + dtype = original_dtype + elif dtype != original_dtype: + print(f"Checkpoint dtype {original_dtype} does not match requested dtype {dtype}") + num_layers = list(set(int(k.split(".", 2)[1]) for k in original_ckpt if "joint_blocks" in k))[-1] + 1 # noqa: C401 caption_projection_dim = 1536 + # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5 + add_attn2_layers = get_add_attn2_layers(original_ckpt) + # sd3.5 use qk norm("rms_norm") + has_qk_norm = any("ln_q" in key for key in original_ckpt.keys()) + # sd3.5 use pox_embed_max_size=384 and sd3.0 use 192 + pos_embed_max_size = 384 if has_qk_norm else 192 converted_transformer_state_dict = convert_sd3_transformer_checkpoint_to_diffusers( - original_ckpt, num_layers, caption_projection_dim + original_ckpt, num_layers, caption_projection_dim, add_attn2_layers, has_qk_norm ) with CTX(): @@ -213,7 +297,9 @@ def main(args): num_layers=num_layers, caption_projection_dim=caption_projection_dim, num_attention_heads=24, - pos_embed_max_size=192, + pos_embed_max_size=pos_embed_max_size, + qk_norm="rms_norm" if has_qk_norm else None, + add_attn2_layers=add_attn2_layers, ) if is_accelerate_available(): load_model_dict_into_meta(transformer, converted_transformer_state_dict) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 84db0d061768..78fb5b9d1842 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -22,7 +22,7 @@ from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU from .attention_processor import Attention, JointAttnProcessor2_0 from .embeddings import SinusoidalPositionalEmbedding -from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm +from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX logger = logging.get_logger(__name__) @@ -100,13 +100,19 @@ class JointTransformerBlock(nn.Module): processing of `context` conditions. """ - def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_only=False): + def __init__( + self, dim, num_attention_heads, attention_head_dim, context_pre_only=False, qk_norm=None, add_attn2=False + ): super().__init__() + self.add_attn2 = add_attn2 self.context_pre_only = context_pre_only context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero" - self.norm1 = AdaLayerNormZero(dim) + if add_attn2: + self.norm1 = SD35AdaLayerNormZeroX(dim) + else: + self.norm1 = AdaLayerNormZero(dim) if context_norm_type == "ada_norm_continous": self.norm1_context = AdaLayerNormContinuous( @@ -134,8 +140,25 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_onl context_pre_only=context_pre_only, bias=True, processor=processor, + qk_norm=qk_norm, + eps=1e-6, ) + if add_attn2: + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=True, + processor=processor, + qk_norm=qk_norm, + eps=1e-6, + ) + else: + self.attn2 = None + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") @@ -159,7 +182,12 @@ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): def forward( self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor ): - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + if self.add_attn2: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1( + hidden_states, emb=temb + ) + else: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) if self.context_pre_only: norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb) @@ -177,6 +205,11 @@ def forward( attn_output = gate_msa.unsqueeze(1) * attn_output hidden_states = hidden_states + attn_output + if self.add_attn2: + attn_output2 = self.attn2(hidden_states=norm_hidden_states2) + attn_output2 = gate_msa2.unsqueeze(1) * attn_output2 + hidden_states = hidden_states + attn_output2 + norm_hidden_states = self.norm2(hidden_states) norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] if self._chunk_size is not None: diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index d333590982e3..b90efeb46ddc 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -193,7 +193,7 @@ def __init__( self.norm_q = RMSNorm(dim_head, eps=eps) self.norm_k = RMSNorm(dim_head, eps=eps) else: - raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'") + raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None,'layer_norm','fp32_layer_norm','rms_norm'") if cross_attention_norm is None: self.norm_cross = None @@ -250,6 +250,10 @@ def __init__( elif qk_norm == "rms_norm": self.norm_added_q = RMSNorm(dim_head, eps=eps) self.norm_added_k = RMSNorm(dim_head, eps=eps) + else: + raise ValueError( + f"unknown qk_norm: {qk_norm}. Should be None,'layer_norm','fp32_layer_norm','rms_norm'" + ) else: self.norm_added_q = None self.norm_added_k = None @@ -1050,61 +1054,72 @@ def __call__( ) -> torch.FloatTensor: residual = hidden_states - input_ndim = hidden_states.ndim - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - context_input_ndim = encoder_hidden_states.ndim - if context_input_ndim == 4: - batch_size, channel, height, width = encoder_hidden_states.shape - encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size = encoder_hidden_states.shape[0] + batch_size = hidden_states.shape[0] # `sample` projections. query = attn.to_q(hidden_states) key = attn.to_k(hidden_states) value = attn.to_v(hidden_states) - # `context` projections. - encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - - # attention - query = torch.cat([query, encoder_hidden_states_query_proj], dim=1) - key = torch.cat([key, encoder_hidden_states_key_proj], dim=1) - value = torch.cat([value, encoder_hidden_states_value_proj], dim=1) - inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # `context` projections. + if encoder_hidden_states is not None: + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + query = torch.cat([query, encoder_hidden_states_query_proj], dim=2) + key = torch.cat([key, encoder_hidden_states_key_proj], dim=2) + value = torch.cat([value, encoder_hidden_states_value_proj], dim=2) + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) - # Split the attention outputs. - hidden_states, encoder_hidden_states = ( - hidden_states[:, : residual.shape[1]], - hidden_states[:, residual.shape[1] :], - ) + if encoder_hidden_states is not None: + # Split the attention outputs. + hidden_states, encoder_hidden_states = ( + hidden_states[:, : residual.shape[1]], + hidden_states[:, residual.shape[1] :], + ) + if not attn.context_pre_only: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) - if not attn.context_pre_only: - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - if context_input_ndim == 4: - encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - return hidden_states, encoder_hidden_states + if encoder_hidden_states is not None: + return hidden_states, encoder_hidden_states + else: + return hidden_states class PAGJointAttnProcessor2_0: diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 21e9d3cd6fc5..6c6f00edb8a4 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -97,6 +97,40 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: ).to(origin_dtype) +class SD35AdaLayerNormZeroX(nn.Module): + r""" + Norm layer adaptive layer norm zero (adaLN-Zero). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`): The size of the embeddings dictionary. + """ + + def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 9 * embedding_dim, bias=bias) + if norm_type == "layer_norm": + self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) + else: + raise ValueError(f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm'.") + + def forward( + self, + x: torch.Tensor, + emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + emb = self.linear(self.silu(emb)) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.chunk( + 9, dim=1 + ) + normed_x = self.norm(x) + x = normed_x * (1 + scale_msa[:, None]) + shift_msa[:, None] + x2 = normed_x * (1 + scale_msa2[:, None]) + shift_msa2[:, None] + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, x2, gate_msa2 + + class AdaLayerNormZero(nn.Module): r""" Norm layer adaptive layer norm zero (adaLN-Zero). diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index 9376c91d0756..424644b4a179 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -13,7 +13,7 @@ # limitations under the License. -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -69,6 +69,8 @@ def __init__( pooled_projection_dim: int = 2048, out_channels: int = 16, pos_embed_max_size: int = 96, + add_attn2_layers: Tuple[int, ...] = (), # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5 + qk_norm: Optional[str] = None, ): super().__init__() default_out_channels = in_channels @@ -97,6 +99,8 @@ def __init__( num_attention_heads=self.config.num_attention_heads, attention_head_dim=self.config.attention_head_dim, context_pre_only=i == num_layers - 1, + qk_norm=qk_norm, + add_attn2=True if i in add_attn2_layers else False, ) for i in range(self.config.num_layers) ] From 3ab805a658e88d89b8b43bd536d430acd17414eb Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 30 Sep 2024 07:50:11 +0200 Subject: [PATCH 2/7] update --- scripts/convert_sd3_to_diffusers.py | 57 ++++++++++++++++++----------- 1 file changed, 36 insertions(+), 21 deletions(-) diff --git a/scripts/convert_sd3_to_diffusers.py b/scripts/convert_sd3_to_diffusers.py index 2abefde372e3..ac836c176c8f 100644 --- a/scripts/convert_sd3_to_diffusers.py +++ b/scripts/convert_sd3_to_diffusers.py @@ -20,16 +20,6 @@ args = parser.parse_args() -# if dtype is not specified, use the dtype of the original checkpoint(recommended) -if args.dtype == "fp16": - dtype = torch.float16 -elif args.dtype == "bf16": - dtype = torch.bfloat16 -elif args.dtype == "fp32": - dtype = torch.float32 -else: - dtype = None - def load_original_checkpoint(ckpt_path): original_state_dict = safetensors.torch.load_file(ckpt_path) @@ -245,9 +235,6 @@ def convert_sd3_transformer_checkpoint_to_diffusers( original_state_dict.pop("final_layer.adaLN_modulation.1.bias"), dim=caption_projection_dim ) - if len(original_state_dict) > 0: - raise ValueError(f"{len(original_state_dict)} keys are not converted: {original_state_dict.keys()}") - return converted_state_dict @@ -260,29 +247,57 @@ def is_vae_in_checkpoint(original_state_dict): def get_add_attn2_layers(state_dict): add_attn2_layers = [] for key in state_dict.keys(): - if "attn2.to_q.weight" in key: + if "attn2." in key: # Extract the layer number from the key layer_num = int(key.split(".")[1]) add_attn2_layers.append(layer_num) return tuple(sorted(add_attn2_layers)) +def get_pos_embed_max_size(state_dict): + num_patches = state_dict["pos_embed"].shape[1] + pos_embed_max_size = int(num_patches**0.5) + return pos_embed_max_size + + +def get_caption_projection_dim(state_dict): + caption_projection_dim = state_dict["context_embedder.weight"].shape[0] + return caption_projection_dim + + def main(args): original_ckpt = load_original_checkpoint(args.checkpoint_path) original_dtype = next(iter(original_ckpt.values())).dtype - if dtype is None: + + # Initialize dtype with a default value + dtype = None + + if args.dtype is None: dtype = original_dtype - elif dtype != original_dtype: + elif args.dtype == "fp16": + dtype = torch.float16 + elif args.dtype == "bf16": + dtype = torch.bfloat16 + elif args.dtype == "fp32": + dtype = torch.float32 + else: + raise ValueError(f"Unsupported dtype: {args.dtype}") + + if dtype != original_dtype: print(f"Checkpoint dtype {original_dtype} does not match requested dtype {dtype}") num_layers = list(set(int(k.split(".", 2)[1]) for k in original_ckpt if "joint_blocks" in k))[-1] + 1 # noqa: C401 - caption_projection_dim = 1536 + + caption_projection_dim = get_caption_projection_dim(original_ckpt) + # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5 add_attn2_layers = get_add_attn2_layers(original_ckpt) + # sd3.5 use qk norm("rms_norm") has_qk_norm = any("ln_q" in key for key in original_ckpt.keys()) - # sd3.5 use pox_embed_max_size=384 and sd3.0 use 192 - pos_embed_max_size = 384 if has_qk_norm else 192 + + # sd3.5 2b use pox_embed_max_size=384 and sd3.0 and sd3.5 8b use 192 + pos_embed_max_size = get_pos_embed_max_size(original_ckpt) converted_transformer_state_dict = convert_sd3_transformer_checkpoint_to_diffusers( original_ckpt, num_layers, caption_projection_dim, add_attn2_layers, has_qk_norm @@ -290,13 +305,13 @@ def main(args): with CTX(): transformer = SD3Transformer2DModel( - sample_size=64, + sample_size=128, patch_size=2, in_channels=16, joint_attention_dim=4096, num_layers=num_layers, caption_projection_dim=caption_projection_dim, - num_attention_heads=24, + num_attention_heads=num_layers, pos_embed_max_size=pos_embed_max_size, qk_norm="rms_norm" if has_qk_norm else None, add_attn2_layers=add_attn2_layers, From 79a1d31c2b50a7dc5710161d24d2170455fc0bd1 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 30 Sep 2024 08:10:41 +0200 Subject: [PATCH 3/7] fix --- scripts/convert_sd3_to_diffusers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/convert_sd3_to_diffusers.py b/scripts/convert_sd3_to_diffusers.py index ac836c176c8f..3dafc21487bd 100644 --- a/scripts/convert_sd3_to_diffusers.py +++ b/scripts/convert_sd3_to_diffusers.py @@ -251,7 +251,7 @@ def get_add_attn2_layers(state_dict): # Extract the layer number from the key layer_num = int(key.split(".")[1]) add_attn2_layers.append(layer_num) - return tuple(sorted(add_attn2_layers)) + return tuple(sorted(set(add_attn2_layers))) def get_pos_embed_max_size(state_dict): From db47e4200214b6b585a6b5bdf912842539ca9fca Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 19 Oct 2024 02:39:11 +0530 Subject: [PATCH 4/7] Update scripts/convert_sd3_to_diffusers.py Co-authored-by: Sayak Paul --- scripts/convert_sd3_to_diffusers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/convert_sd3_to_diffusers.py b/scripts/convert_sd3_to_diffusers.py index 3dafc21487bd..2a0bb5cc489d 100644 --- a/scripts/convert_sd3_to_diffusers.py +++ b/scripts/convert_sd3_to_diffusers.py @@ -284,7 +284,7 @@ def main(args): raise ValueError(f"Unsupported dtype: {args.dtype}") if dtype != original_dtype: - print(f"Checkpoint dtype {original_dtype} does not match requested dtype {dtype}") + print(f"Checkpoint dtype {original_dtype} does not match requested dtype {dtype}. This can lead to unexpected results, proceed with caution.") num_layers = list(set(int(k.split(".", 2)[1]) for k in original_ckpt if "joint_blocks" in k))[-1] + 1 # noqa: C401 From e96d5ad599e62d44d3d1335863b28a2f39ee9f2b Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 19 Oct 2024 02:39:26 +0530 Subject: [PATCH 5/7] Update src/diffusers/models/attention_processor.py Co-authored-by: Sayak Paul --- src/diffusers/models/attention_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index b90efeb46ddc..e735c4ee7d17 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -252,7 +252,7 @@ def __init__( self.norm_added_k = RMSNorm(dim_head, eps=eps) else: raise ValueError( - f"unknown qk_norm: {qk_norm}. Should be None,'layer_norm','fp32_layer_norm','rms_norm'" + f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`" ) else: self.norm_added_q = None From afc60a001bdcc1db70de417e35ad935ce91bfcdd Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 18 Oct 2024 23:43:46 +0200 Subject: [PATCH 6/7] apply suggestions from review; update docs --- .../stable_diffusion/stable_diffusion_3.md | 5 ++ scripts/convert_sd3_to_diffusers.py | 22 +++---- src/diffusers/models/attention.py | 20 +++++-- src/diffusers/models/normalization.py | 16 ++--- .../models/transformers/transformer_sd3.py | 6 +- .../test_models_transformer_sd3.py | 59 +++++++++++++++++++ 6 files changed, 102 insertions(+), 26 deletions(-) diff --git a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md index 388c2d1a87cd..1bcb2282410e 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md +++ b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md @@ -54,6 +54,11 @@ image = pipe( image.save("sd3_hello_world.png") ``` +**Note:** Stable Diffusion 3.5 can also be run using the SD3 pipeline, and all mentioned optimizations and techniques apply to it as well. In total there are three official models in the SD3 family: +- [`stabilityai/stable-diffusion-3-medium-diffusers`](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers) +- [`stabilityai/stable-diffusion-3.5-medium-diffusers`](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium-diffusers) +- [`stabilityai/stable-diffusion-3.5-large-diffusers`](https://huggingface.co/stabilityai/stable-diffusion-3.5-large-diffusers) + ## Memory Optimisations for SD3 SD3 uses three text encoders, one if which is the very large T5-XXL model. This makes it challenging to run the model on GPUs with less than 24GB of VRAM, even when using `fp16` precision. The following section outlines a few memory optimizations in Diffusers that make it easier to run SD3 on low resource hardware. diff --git a/scripts/convert_sd3_to_diffusers.py b/scripts/convert_sd3_to_diffusers.py index 2a0bb5cc489d..1f9c434b39d0 100644 --- a/scripts/convert_sd3_to_diffusers.py +++ b/scripts/convert_sd3_to_diffusers.py @@ -40,7 +40,7 @@ def swap_scale_shift(weight, dim): def convert_sd3_transformer_checkpoint_to_diffusers( - original_state_dict, num_layers, caption_projection_dim, add_attn2_layers, has_qk_norm + original_state_dict, num_layers, caption_projection_dim, dual_attention_layers, has_qk_norm ): converted_state_dict = {} @@ -142,7 +142,7 @@ def convert_sd3_transformer_checkpoint_to_diffusers( ) # attn2 - if i in add_attn2_layers: + if i in dual_attention_layers: # Q, K, V sample_q2, sample_k2, sample_v2 = torch.chunk( original_state_dict.pop(f"joint_blocks.{i}.x_block.attn2.qkv.weight"), 3, dim=0 @@ -244,14 +244,14 @@ def is_vae_in_checkpoint(original_state_dict): ) -def get_add_attn2_layers(state_dict): - add_attn2_layers = [] +def get_attn2_layers(state_dict): + attn2_layers = [] for key in state_dict.keys(): if "attn2." in key: # Extract the layer number from the key layer_num = int(key.split(".")[1]) - add_attn2_layers.append(layer_num) - return tuple(sorted(set(add_attn2_layers))) + attn2_layers.append(layer_num) + return tuple(sorted(set(attn2_layers))) def get_pos_embed_max_size(state_dict): @@ -284,14 +284,16 @@ def main(args): raise ValueError(f"Unsupported dtype: {args.dtype}") if dtype != original_dtype: - print(f"Checkpoint dtype {original_dtype} does not match requested dtype {dtype}. This can lead to unexpected results, proceed with caution.") + print( + f"Checkpoint dtype {original_dtype} does not match requested dtype {dtype}. This can lead to unexpected results, proceed with caution." + ) num_layers = list(set(int(k.split(".", 2)[1]) for k in original_ckpt if "joint_blocks" in k))[-1] + 1 # noqa: C401 caption_projection_dim = get_caption_projection_dim(original_ckpt) # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5 - add_attn2_layers = get_add_attn2_layers(original_ckpt) + attn2_layers = get_attn2_layers(original_ckpt) # sd3.5 use qk norm("rms_norm") has_qk_norm = any("ln_q" in key for key in original_ckpt.keys()) @@ -300,7 +302,7 @@ def main(args): pos_embed_max_size = get_pos_embed_max_size(original_ckpt) converted_transformer_state_dict = convert_sd3_transformer_checkpoint_to_diffusers( - original_ckpt, num_layers, caption_projection_dim, add_attn2_layers, has_qk_norm + original_ckpt, num_layers, caption_projection_dim, attn2_layers, has_qk_norm ) with CTX(): @@ -314,7 +316,7 @@ def main(args): num_attention_heads=num_layers, pos_embed_max_size=pos_embed_max_size, qk_norm="rms_norm" if has_qk_norm else None, - add_attn2_layers=add_attn2_layers, + dual_attention_layers=attn2_layers, ) if is_accelerate_available(): load_model_dict_into_meta(transformer, converted_transformer_state_dict) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 78fb5b9d1842..02ed1f965abf 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -101,15 +101,21 @@ class JointTransformerBlock(nn.Module): """ def __init__( - self, dim, num_attention_heads, attention_head_dim, context_pre_only=False, qk_norm=None, add_attn2=False + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + context_pre_only: bool = False, + qk_norm: Optional[str] = None, + use_dual_attention: bool = False, ): super().__init__() - self.add_attn2 = add_attn2 + self.use_dual_attention = use_dual_attention self.context_pre_only = context_pre_only context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero" - if add_attn2: + if use_dual_attention: self.norm1 = SD35AdaLayerNormZeroX(dim) else: self.norm1 = AdaLayerNormZero(dim) @@ -124,12 +130,14 @@ def __init__( raise ValueError( f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`" ) + if hasattr(F, "scaled_dot_product_attention"): processor = JointAttnProcessor2_0() else: raise ValueError( "The current PyTorch version does not support the `scaled_dot_product_attention` function." ) + self.attn = Attention( query_dim=dim, cross_attention_dim=None, @@ -144,7 +152,7 @@ def __init__( eps=1e-6, ) - if add_attn2: + if use_dual_attention: self.attn2 = Attention( query_dim=dim, cross_attention_dim=None, @@ -182,7 +190,7 @@ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): def forward( self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor ): - if self.add_attn2: + if self.use_dual_attention: norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1( hidden_states, emb=temb ) @@ -205,7 +213,7 @@ def forward( attn_output = gate_msa.unsqueeze(1) * attn_output hidden_states = hidden_states + attn_output - if self.add_attn2: + if self.use_dual_attention: attn_output2 = self.attn2(hidden_states=norm_hidden_states2) attn_output2 = gate_msa2.unsqueeze(1) * attn_output2 hidden_states = hidden_states + attn_output2 diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 6c6f00edb8a4..029c147fcbac 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -99,14 +99,14 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: class SD35AdaLayerNormZeroX(nn.Module): r""" - Norm layer adaptive layer norm zero (adaLN-Zero). + Norm layer adaptive layer norm zero (AdaLN-Zero). Parameters: embedding_dim (`int`): The size of each embedding vector. num_embeddings (`int`): The size of the embeddings dictionary. """ - def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True): + def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True) -> None: super().__init__() self.silu = nn.SiLU() @@ -118,17 +118,17 @@ def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True): def forward( self, - x: torch.Tensor, + hidden_states: torch.Tensor, emb: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, ...]: emb = self.linear(self.silu(emb)) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.chunk( 9, dim=1 ) - normed_x = self.norm(x) - x = normed_x * (1 + scale_msa[:, None]) + shift_msa[:, None] - x2 = normed_x * (1 + scale_msa2[:, None]) + shift_msa2[:, None] - return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, x2, gate_msa2 + norm_hidden_states = self.norm(hidden_states) + hidden_states = norm_hidden_states * (1 + scale_msa[:, None]) + shift_msa[:, None] + norm_hidden_states2 = norm_hidden_states * (1 + scale_msa2[:, None]) + shift_msa2[:, None] + return hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 class AdaLayerNormZero(nn.Module): diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index 424644b4a179..b28350b8ed9c 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -69,7 +69,9 @@ def __init__( pooled_projection_dim: int = 2048, out_channels: int = 16, pos_embed_max_size: int = 96, - add_attn2_layers: Tuple[int, ...] = (), # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5 + dual_attention_layers: Tuple[ + int, ... + ] = (), # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5 qk_norm: Optional[str] = None, ): super().__init__() @@ -100,7 +102,7 @@ def __init__( attention_head_dim=self.config.attention_head_dim, context_pre_only=i == num_layers - 1, qk_norm=qk_norm, - add_attn2=True if i in add_attn2_layers else False, + use_dual_attention=True if i in dual_attention_layers else False, ) for i in range(self.config.num_layers) ] diff --git a/tests/models/transformers/test_models_transformer_sd3.py b/tests/models/transformers/test_models_transformer_sd3.py index 2b9084327289..2be4744c5ac4 100644 --- a/tests/models/transformers/test_models_transformer_sd3.py +++ b/tests/models/transformers/test_models_transformer_sd3.py @@ -73,6 +73,65 @@ def prepare_init_args_and_inputs_for_common(self): "joint_attention_dim": 32, "pooled_projection_dim": 64, "out_channels": 4, + "pos_embed_max_size": 96, + "dual_attention_layers": (), + "qk_norm": None, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + @unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply") + def test_set_attn_processor_for_determinism(self): + pass + + +class SD35TransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = SD3Transformer2DModel + main_input_name = "hidden_states" + + @property + def dummy_input(self): + batch_size = 2 + num_channels = 4 + height = width = embedding_dim = 32 + pooled_embedding_dim = embedding_dim * 2 + sequence_length = 154 + + hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) + pooled_prompt_embeds = torch.randn((batch_size, pooled_embedding_dim)).to(torch_device) + timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "pooled_projections": pooled_prompt_embeds, + "timestep": timestep, + } + + @property + def input_shape(self): + return (4, 32, 32) + + @property + def output_shape(self): + return (4, 32, 32) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "sample_size": 32, + "patch_size": 1, + "in_channels": 4, + "num_layers": 2, + "attention_head_dim": 8, + "num_attention_heads": 4, + "caption_projection_dim": 32, + "joint_attention_dim": 32, + "pooled_projection_dim": 64, + "out_channels": 4, + "pos_embed_max_size": 96, + "dual_attention_layers": (0,), + "qk_norm": "rms_norm", } inputs_dict = self.dummy_input return init_dict, inputs_dict From 6f486ee6581434a13fd0d10e8783a34c9469da14 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Mon, 21 Oct 2024 09:42:38 -1000 Subject: [PATCH 7/7] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: apolinário --- .../en/api/pipelines/stable_diffusion/stable_diffusion_3.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md index 1bcb2282410e..fd026f07c923 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md +++ b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md @@ -56,8 +56,8 @@ image.save("sd3_hello_world.png") **Note:** Stable Diffusion 3.5 can also be run using the SD3 pipeline, and all mentioned optimizations and techniques apply to it as well. In total there are three official models in the SD3 family: - [`stabilityai/stable-diffusion-3-medium-diffusers`](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers) -- [`stabilityai/stable-diffusion-3.5-medium-diffusers`](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium-diffusers) -- [`stabilityai/stable-diffusion-3.5-large-diffusers`](https://huggingface.co/stabilityai/stable-diffusion-3.5-large-diffusers) +- [`stabilityai/stable-diffusion-3.5-large`](https://huggingface.co/stabilityai/stable-diffusion-3-5-large) +- [`stabilityai/stable-diffusion-3.5-large-turbo`](https://huggingface.co/stabilityai/stable-diffusion-3-5-large-turbo) ## Memory Optimisations for SD3