diff --git a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md
index fd026f07c923..8170c5280d38 100644
--- a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md
+++ b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md
@@ -313,6 +313,26 @@ image = pipe("a picture of a cat holding a sign that says hello world").images[0
 image.save('sd3-single-file-t5-fp8.png')
 ```
 
+### Loading the single file checkpoint for the Stable Diffusion 3.5 Transformer Model
+
+```python
+import torch
+from diffusers import SD3Transformer2DModel, StableDiffusion3Pipeline
+
+transformer = SD3Transformer2DModel.from_single_file(
+    "https://huggingface.co/stabilityai/stable-diffusion-3.5-large-turbo/blob/main/sd3.5_large.safetensors",
+    torch_dtype=torch.bfloat16,
+)
+pipe = StableDiffusion3Pipeline.from_pretrained(
+    "stabilityai/stable-diffusion-3.5-large",
+    transformer=transformer,
+    torch_dtype=torch.bfloat16,
+)
+pipe.enable_model_cpu_offload()
+image = pipe("a cat holding a sign that says hello world").images[0]
+image.save("sd35.png")
+```
+
 ## StableDiffusion3Pipeline
 
 [[autodoc]] StableDiffusion3Pipeline
diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py
index 236fbd0c2295..d1bad8b5a7cd 100644
--- a/src/diffusers/loaders/single_file_utils.py
+++ b/src/diffusers/loaders/single_file_utils.py
@@ -75,6 +75,7 @@
     "stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
     "stable_cascade_stage_c": "clip_txt_mapper.weight",
     "sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
+    "sd35_large": "model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weight",
     "animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe",
     "animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias",
     "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
@@ -113,6 +114,9 @@
     "sd3": {
         "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3-medium-diffusers",
     },
+    "sd35_large": {
+        "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3.5-large",
+    },
     "animatediff_v1": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5"},
     "animatediff_v2": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-2"},
     "animatediff_v3": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-3"},
@@ -504,9 +508,12 @@ def infer_diffusers_model_type(checkpoint):
     ):
         model_type = "stable_cascade_stage_b"
 
-    elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint:
+    elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["sd3"]].shape[-1] == 9216:
         model_type = "sd3"
 
+    elif CHECKPOINT_KEY_NAMES["sd35_large"] in checkpoint:
+        model_type = "sd35_large"
+
     elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint:
         if CHECKPOINT_KEY_NAMES["animatediff_scribble"] in checkpoint:
             model_type = "animatediff_scribble"
@@ -1670,6 +1677,22 @@ def swap_scale_shift(weight, dim):
     return new_weight
 
 
+def get_attn2_layers(state_dict):
+    attn2_layers = []
+    for key in state_dict.keys():
+        if "attn2." in key:
+            # Extract the layer number from the key
+            layer_num = int(key.split(".")[1])
+            attn2_layers.append(layer_num)
+
+    return tuple(sorted(set(attn2_layers)))
+
+
+def get_caption_projection_dim(state_dict):
+    caption_projection_dim = state_dict["context_embedder.weight"].shape[0]
+    return caption_projection_dim
+
+
 def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
     converted_state_dict = {}
     keys = list(checkpoint.keys())
@@ -1678,7 +1701,10 @@ def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
             checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
 
     num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "joint_blocks" in k))[-1] + 1  # noqa: C401
-    caption_projection_dim = 1536
+    dual_attention_layers = get_attn2_layers(checkpoint)
+
+    caption_projection_dim = get_caption_projection_dim(checkpoint)
+    has_qk_norm = any("ln_q" in key for key in checkpoint.keys())
 
     # Positional and patch embeddings.
     converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("pos_embed")
@@ -1735,6 +1761,21 @@ def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
         converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v])
         converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias])
 
+        # qk norm
+        if has_qk_norm:
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_q.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.x_block.attn.ln_q.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_k.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.x_block.attn.ln_k.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_q.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.context_block.attn.ln_q.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_k.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.context_block.attn.ln_k.weight"
+            )
+
         # output projections.
         converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = checkpoint.pop(
             f"joint_blocks.{i}.x_block.attn.proj.weight"
@@ -1750,6 +1791,38 @@ def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
                 f"joint_blocks.{i}.context_block.attn.proj.bias"
             )
 
+        if i in dual_attention_layers:
+            # Q, K, V
+            sample_q2, sample_k2, sample_v2 = torch.chunk(
+                checkpoint.pop(f"joint_blocks.{i}.x_block.attn2.qkv.weight"), 3, dim=0
+            )
+            sample_q2_bias, sample_k2_bias, sample_v2_bias = torch.chunk(
+                checkpoint.pop(f"joint_blocks.{i}.x_block.attn2.qkv.bias"), 3, dim=0
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = torch.cat([sample_q2])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = torch.cat([sample_q2_bias])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = torch.cat([sample_k2])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = torch.cat([sample_k2_bias])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = torch.cat([sample_v2])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = torch.cat([sample_v2_bias])
+
+            # qk norm
+            if has_qk_norm:
+                converted_state_dict[f"transformer_blocks.{i}.attn2.norm_q.weight"] = checkpoint.pop(
+                    f"joint_blocks.{i}.x_block.attn2.ln_q.weight"
+                )
+                converted_state_dict[f"transformer_blocks.{i}.attn2.norm_k.weight"] = checkpoint.pop(
+                    f"joint_blocks.{i}.x_block.attn2.ln_k.weight"
+                )
+
+            # output projections.
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.x_block.attn2.proj.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop(
+                f"joint_blocks.{i}.x_block.attn2.proj.bias"
+            )
+
         # norms.
         converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop(
             f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight"