@@ -313,6 +313,26 @@ image = pipe("a picture of a cat holding a sign that says hello world").images[0]
image.save('sd3-single-file-t5-fp8.png')
```

### Loading the single file checkpoint for the Stable Diffusion 3.5 Transformer Model

```python
import torch
from diffusers import SD3Transformer2DModel, StableDiffusion3Pipeline

transformer = SD3Transformer2DModel.from_single_file(
    "https://huggingface.co/stabilityai/stable-diffusion-3.5-large/blob/main/sd3.5_large.safetensors",
    torch_dtype=torch.bfloat16,
)
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()
image = pipe("a cat holding a sign that says hello world").images[0]
image.save("sd35.png")
```
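The same approach should carry over to the Large Turbo checkpoint. The snippet below is a sketch, not part of this PR: the `sd3.5_large_turbo.safetensors` filename and the 4-step, zero-guidance sampler settings follow Stability AI's published usage notes for the Turbo model and are assumptions here.

```python
import torch
from diffusers import SD3Transformer2DModel, StableDiffusion3Pipeline

# Sketch only: filename and sampler settings are assumptions, not taken from this PR.
transformer = SD3Transformer2DModel.from_single_file(
    "https://huggingface.co/stabilityai/stable-diffusion-3.5-large-turbo/blob/main/sd3.5_large_turbo.safetensors",
    torch_dtype=torch.bfloat16,
)
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large-turbo",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()
image = pipe(
    "a cat holding a sign that says hello world",
    num_inference_steps=4,   # Turbo is distilled for few-step sampling
    guidance_scale=0.0,      # Turbo is used without classifier-free guidance
).images[0]
image.save("sd35-turbo.png")
```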

## StableDiffusion3Pipeline

[[autodoc]] StableDiffusion3Pipeline
77 changes: 75 additions & 2 deletions src/diffusers/loaders/single_file_utils.py
@@ -75,6 +75,7 @@
"stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
"stable_cascade_stage_c": "clip_txt_mapper.weight",
"sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
"sd35_large": "model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weight",
"animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe",
"animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias",
"animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
@@ -113,6 +114,9 @@
"sd3": {
"pretrained_model_name_or_path": "stabilityai/stable-diffusion-3-medium-diffusers",
},
"sd35_large": {
"pretrained_model_name_or_path": "stabilityai/stable-diffusion-3.5-large",
},
"animatediff_v1": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5"},
"animatediff_v2": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-2"},
"animatediff_v3": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-3"},
@@ -504,9 +508,12 @@ def infer_diffusers_model_type(checkpoint):
    ):
        model_type = "stable_cascade_stage_b"

    elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint:
    elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["sd3"]].shape[-1] == 9216:
        model_type = "sd3"

    elif CHECKPOINT_KEY_NAMES["sd35_large"] in checkpoint:
        model_type = "sd35_large"

    elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint:
        if CHECKPOINT_KEY_NAMES["animatediff_scribble"] in checkpoint:
            model_type = "animatediff_scribble"
@@ -1670,6 +1677,22 @@ def swap_scale_shift(weight, dim):
    return new_weight


def get_attn2_layers(state_dict):
    attn2_layers = []
    for key in state_dict.keys():
        if "attn2." in key:
            # Extract the layer number from the key
            layer_num = int(key.split(".")[1])
            attn2_layers.append(layer_num)

    return tuple(sorted(set(attn2_layers)))


def get_caption_projection_dim(state_dict):
    caption_projection_dim = state_dict["context_embedder.weight"].shape[0]
    return caption_projection_dim


def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
    converted_state_dict = {}
    keys = list(checkpoint.keys())
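A toy illustration of the two new helpers (not part of the PR), assuming the `model.diffusion_model.` prefix has already been stripped from the keys, as the converter does just below:

```python
import torch

# Toy state dict with the prefix-stripped key layout the converter expects.
toy_checkpoint = {
    "context_embedder.weight": torch.zeros(2432, 4096),
    "joint_blocks.0.x_block.attn2.qkv.weight": torch.zeros(3 * 2432, 2432),
    "joint_blocks.12.x_block.attn2.qkv.weight": torch.zeros(3 * 2432, 2432),
    "joint_blocks.13.x_block.attn.qkv.weight": torch.zeros(3 * 2432, 2432),
}

print(get_attn2_layers(toy_checkpoint))            # (0, 12) -> blocks that carry a second self-attention
print(get_caption_projection_dim(toy_checkpoint))  # 2432    -> output width of context_embedder
```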
@@ -1678,7 +1701,10 @@ def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)

    num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "joint_blocks" in k))[-1] + 1  # noqa: C401
    caption_projection_dim = 1536
    dual_attention_layers = get_attn2_layers(checkpoint)

    caption_projection_dim = get_caption_projection_dim(checkpoint)
    has_qk_norm = any("ln_q" in key for key in checkpoint.keys())

    # Positional and patch embeddings.
    converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("pos_embed")
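The layer count and qk-norm flag above fall straight out of the raw key names. A small illustration (equivalent logic, written with `max` for brevity; the keys are made up for the example):

```python
keys = [
    "joint_blocks.0.x_block.attn.ln_q.weight",
    "joint_blocks.0.x_block.attn.qkv.weight",
    "joint_blocks.37.context_block.mlp.fc1.weight",
]
num_layers = max(int(k.split(".", 2)[1]) for k in keys if "joint_blocks" in k) + 1  # 38
has_qk_norm = any("ln_q" in k for k in keys)  # True for SD3.5 checkpoints, absent in SD3 Medium
```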
@@ -1735,6 +1761,21 @@
        converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v])
        converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias])

        # qk norm
        if has_qk_norm:
            converted_state_dict[f"transformer_blocks.{i}.attn.norm_q.weight"] = checkpoint.pop(
                f"joint_blocks.{i}.x_block.attn.ln_q.weight"
            )
            converted_state_dict[f"transformer_blocks.{i}.attn.norm_k.weight"] = checkpoint.pop(
                f"joint_blocks.{i}.x_block.attn.ln_k.weight"
            )
            converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_q.weight"] = checkpoint.pop(
                f"joint_blocks.{i}.context_block.attn.ln_q.weight"
            )
            converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_k.weight"] = checkpoint.pop(
                f"joint_blocks.{i}.context_block.attn.ln_k.weight"
            )

        # output projections.
        converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = checkpoint.pop(
            f"joint_blocks.{i}.x_block.attn.proj.weight"
@@ -1750,6 +1791,38 @@
f"joint_blocks.{i}.context_block.attn.proj.bias"
)

if i in dual_attention_layers:
# Q, K, V
sample_q2, sample_k2, sample_v2 = torch.chunk(
checkpoint.pop(f"joint_blocks.{i}.x_block.attn2.qkv.weight"), 3, dim=0
)
sample_q2_bias, sample_k2_bias, sample_v2_bias = torch.chunk(
checkpoint.pop(f"joint_blocks.{i}.x_block.attn2.qkv.bias"), 3, dim=0
)
converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = torch.cat([sample_q2])
converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = torch.cat([sample_q2_bias])
converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = torch.cat([sample_k2])
converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = torch.cat([sample_k2_bias])
converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = torch.cat([sample_v2])
converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = torch.cat([sample_v2_bias])

# qk norm
if has_qk_norm:
converted_state_dict[f"transformer_blocks.{i}.attn2.norm_q.weight"] = checkpoint.pop(
f"joint_blocks.{i}.x_block.attn2.ln_q.weight"
)
converted_state_dict[f"transformer_blocks.{i}.attn2.norm_k.weight"] = checkpoint.pop(
f"joint_blocks.{i}.x_block.attn2.ln_k.weight"
)

# output projections.
converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop(
f"joint_blocks.{i}.x_block.attn2.proj.weight"
)
converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop(
f"joint_blocks.{i}.x_block.attn2.proj.bias"
)

# norms.
converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop(
f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight"
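The recurring pattern in this converter is splitting the original fused `qkv` tensors into the separate `to_q`/`to_k`/`to_v` projections diffusers expects. A standalone sketch of that step (2432 is the SD3.5 Large hidden width, used here only to make the shapes concrete):

```python
import torch

hidden = 2432  # SD3.5 Large: 38 attention heads * 64 head dim
fused_qkv_weight = torch.randn(3 * hidden, hidden)  # e.g. joint_blocks.{i}.x_block.attn2.qkv.weight
fused_qkv_bias = torch.randn(3 * hidden)            # e.g. joint_blocks.{i}.x_block.attn2.qkv.bias

q_w, k_w, v_w = torch.chunk(fused_qkv_weight, 3, dim=0)  # each (hidden, hidden)
q_b, k_b, v_b = torch.chunk(fused_qkv_bias, 3, dim=0)    # each (hidden,)
# These map onto transformer_blocks.{i}.attn2.to_q / to_k / to_v weights and biases.
```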