Skip to content

Commit 8d45c2f

Browse files
committed
fix attention
1 parent 1001425 commit 8d45c2f

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

scripts/convert_vae_pt_to_diffusers.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
5252
}
5353

5454
for i in range(num_down_blocks):
55-
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
55+
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key and "attn" not in key]
56+
attentions = [key for key in down_blocks[i] if f"down.{i}.attn" in key]
5657

5758
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
5859
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
@@ -66,6 +67,10 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
6667
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
6768
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
6869

70+
paths = renew_vae_attention_paths(attentions)
71+
meta_path = {"old": f"down.{i}.attn", "new": f"down_blocks.{i}.attentions"}
72+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
73+
6974
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
7075
num_mid_res_blocks = 2
7176
for i in range(1, num_mid_res_blocks + 1):
@@ -84,7 +89,10 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
8489
for i in range(num_up_blocks):
8590
block_id = num_up_blocks - 1 - i
8691
resnets = [
87-
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
92+
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key and "attn" not in key
93+
]
94+
attentions = [
95+
key for key in up_blocks[block_id] if f"up.{block_id}.attn" in key
8896
]
8997

9098
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
@@ -99,6 +107,10 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
99107
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
100108
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
101109

110+
paths = renew_vae_attention_paths(attentions)
111+
meta_path = {"old": f"up.{block_id}.attn", "new": f"up_blocks.{i}.attentions"}
112+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
113+
102114
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
103115
num_mid_res_blocks = 2
104116
for i in range(1, num_mid_res_blocks + 1):

src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -349,8 +349,8 @@ def create_vae_diffusers_config(original_config, image_size: int):
349349
_ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"]
350350

351351
block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
352-
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
353-
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
352+
down_block_types = ["DownEncoderBlock2D" if image_size // 2 ** i not in vae_params["attn_resolutions"] else "AttnDownEncoderBlock2D" for i, _ in enumerate(block_out_channels)]
353+
up_block_types = ["UpDecoderBlock2D" if image_size // 2 ** i not in vae_params["attn_resolutions"] else "AttnUpDecoderBlock2D" for i, _ in enumerate(block_out_channels)][::-1]
354354

355355
config = {
356356
"sample_size": image_size,

0 commit comments

Comments
 (0)