@@ -53,7 +53,12 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
5353 }
5454
5555 for i in range (num_down_blocks ):
56- resnets = [key for key in down_blocks [i ] if f"down.{ i } " in key and f"down.{ i } .downsample" not in key ]
56+ resnets = [
57+ key
58+ for key in down_blocks [i ]
59+ if f"down.{ i } " in key and f"down.{ i } .downsample" not in key and "attn" not in key
60+ ]
61+ attentions = [key for key in down_blocks [i ] if f"down.{ i } .attn" in key ]
5762
5863 if f"encoder.down.{ i } .downsample.conv.weight" in vae_state_dict :
5964 new_checkpoint [f"encoder.down_blocks.{ i } .downsamplers.0.conv.weight" ] = vae_state_dict .pop (
@@ -67,6 +72,10 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
6772 meta_path = {"old" : f"down.{ i } .block" , "new" : f"down_blocks.{ i } .resnets" }
6873 assign_to_checkpoint (paths , new_checkpoint , vae_state_dict , additional_replacements = [meta_path ], config = config )
6974
75+ paths = renew_vae_attention_paths (attentions )
76+ meta_path = {"old" : f"down.{ i } .attn" , "new" : f"down_blocks.{ i } .attentions" }
77+ assign_to_checkpoint (paths , new_checkpoint , vae_state_dict , additional_replacements = [meta_path ], config = config )
78+
7079 mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key ]
7180 num_mid_res_blocks = 2
7281 for i in range (1 , num_mid_res_blocks + 1 ):
@@ -85,8 +94,11 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
8594 for i in range (num_up_blocks ):
8695 block_id = num_up_blocks - 1 - i
8796 resnets = [
88- key for key in up_blocks [block_id ] if f"up.{ block_id } " in key and f"up.{ block_id } .upsample" not in key
97+ key
98+ for key in up_blocks [block_id ]
99+ if f"up.{ block_id } " in key and f"up.{ block_id } .upsample" not in key and "attn" not in key
89100 ]
101+ attentions = [key for key in up_blocks [block_id ] if f"up.{ block_id } .attn" in key ]
90102
91103 if f"decoder.up.{ block_id } .upsample.conv.weight" in vae_state_dict :
92104 new_checkpoint [f"decoder.up_blocks.{ i } .upsamplers.0.conv.weight" ] = vae_state_dict [
@@ -100,6 +112,10 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
100112 meta_path = {"old" : f"up.{ block_id } .block" , "new" : f"up_blocks.{ i } .resnets" }
101113 assign_to_checkpoint (paths , new_checkpoint , vae_state_dict , additional_replacements = [meta_path ], config = config )
102114
115+ paths = renew_vae_attention_paths (attentions )
116+ meta_path = {"old" : f"up.{ block_id } .attn" , "new" : f"up_blocks.{ i } .attentions" }
117+ assign_to_checkpoint (paths , new_checkpoint , vae_state_dict , additional_replacements = [meta_path ], config = config )
118+
103119 mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key ]
104120 num_mid_res_blocks = 2
105121 for i in range (1 , num_mid_res_blocks + 1 ):
0 commit comments