@@ -52,7 +52,8 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
5252 }
5353
5454 for i in range (num_down_blocks ):
55- resnets = [key for key in down_blocks [i ] if f"down.{ i } " in key and f"down.{ i } .downsample" not in key ]
55+ resnets = [key for key in down_blocks [i ] if f"down.{ i } " in key and f"down.{ i } .downsample" not in key and "attn" not in key ]
56+ attentions = [key for key in down_blocks [i ] if f"down.{ i } .attn" in key ]
5657
5758 if f"encoder.down.{ i } .downsample.conv.weight" in vae_state_dict :
5859 new_checkpoint [f"encoder.down_blocks.{ i } .downsamplers.0.conv.weight" ] = vae_state_dict .pop (
@@ -66,6 +67,10 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
6667 meta_path = {"old" : f"down.{ i } .block" , "new" : f"down_blocks.{ i } .resnets" }
6768 assign_to_checkpoint (paths , new_checkpoint , vae_state_dict , additional_replacements = [meta_path ], config = config )
6869
70+ paths = renew_vae_attention_paths (attentions )
71+ meta_path = {"old" : f"down.{ i } .attn" , "new" : f"down_blocks.{ i } .attentions" }
72+ assign_to_checkpoint (paths , new_checkpoint , vae_state_dict , additional_replacements = [meta_path ], config = config )
73+
6974 mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key ]
7075 num_mid_res_blocks = 2
7176 for i in range (1 , num_mid_res_blocks + 1 ):
@@ -84,7 +89,10 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
8489 for i in range (num_up_blocks ):
8590 block_id = num_up_blocks - 1 - i
8691 resnets = [
87- key for key in up_blocks [block_id ] if f"up.{ block_id } " in key and f"up.{ block_id } .upsample" not in key
92+ key for key in up_blocks [block_id ] if f"up.{ block_id } " in key and f"up.{ block_id } .upsample" not in key and "attn" not in key
93+ ]
94+ attentions = [
95+ key for key in up_blocks [block_id ] if f"up.{ block_id } .attn" in key
8896 ]
8997
9098 if f"decoder.up.{ block_id } .upsample.conv.weight" in vae_state_dict :
@@ -99,6 +107,10 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
99107 meta_path = {"old" : f"up.{ block_id } .block" , "new" : f"up_blocks.{ i } .resnets" }
100108 assign_to_checkpoint (paths , new_checkpoint , vae_state_dict , additional_replacements = [meta_path ], config = config )
101109
110+ paths = renew_vae_attention_paths (attentions )
111+ meta_path = {"old" : f"up.{ block_id } .attn" , "new" : f"up_blocks.{ i } .attentions" }
112+ assign_to_checkpoint (paths , new_checkpoint , vae_state_dict , additional_replacements = [meta_path ], config = config )
113+
102114 mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key ]
103115 num_mid_res_blocks = 2
104116 for i in range (1 , num_mid_res_blocks + 1 ):
0 commit comments