@@ -764,7 +764,7 @@ def convert_vae_22():
764764        "conv2.weight" : "post_quant_conv.weight" ,
765765        "conv2.bias" : "post_quant_conv.bias" ,
766766    }
767-          
767+ 
768768    # Process each key in the state dict 
769769    for  key , value  in  old_state_dict .items ():
770770        # Handle middle block keys using the mapping 
@@ -797,12 +797,12 @@ def convert_vae_22():
797797        elif  key .startswith ("encoder.downsamples." ):
798798            # Change encoder.downsamples to encoder.down_blocks 
799799            new_key  =  key .replace ("encoder.downsamples." , "encoder.down_blocks." )
800-              
800+ 
801801            # Handle residual blocks - change downsamples to resnets and rename components 
802802            if  "residual"  in  new_key  or  "shortcut"  in  new_key :
803803                # Change the second downsamples to resnets 
804804                new_key  =  new_key .replace (".downsamples." , ".resnets." )
805-                  
805+ 
806806                # Rename residual components 
807807                if  ".residual.0.gamma"  in  new_key :
808808                    new_key  =  new_key .replace (".residual.0.gamma" , ".norm1.gamma" )
@@ -820,7 +820,7 @@ def convert_vae_22():
820820                    new_key  =  new_key .replace (".shortcut.weight" , ".conv_shortcut.weight" )
821821                elif  ".shortcut.bias"  in  new_key :
822822                    new_key  =  new_key .replace (".shortcut.bias" , ".conv_shortcut.bias" )
823-              
823+ 
824824            # Handle resample blocks - change downsamples to downsampler and remove index 
825825            elif  "resample"  in  new_key  or  "time_conv"  in  new_key :
826826                # Change the second downsamples to downsampler and remove the index 
@@ -831,19 +831,19 @@ def convert_vae_22():
831831                    # Remove the index (parts[4]) and change downsamples to downsampler 
832832                    new_parts  =  parts [:3 ] +  ["downsampler" ] +  parts [5 :]
833833                    new_key  =  "." .join (new_parts )
834-              
834+ 
835835            new_state_dict [new_key ] =  value 
836836
837837        # Handle decoder upsamples 
838838        elif  key .startswith ("decoder.upsamples." ):
839839            # Change decoder.upsamples to decoder.up_blocks 
840840            new_key  =  key .replace ("decoder.upsamples." , "decoder.up_blocks." )
841-              
841+ 
842842            # Handle residual blocks - change upsamples to resnets and rename components 
843843            if  "residual"  in  new_key  or  "shortcut"  in  new_key :
844844                # Change the second upsamples to resnets 
845845                new_key  =  new_key .replace (".upsamples." , ".resnets." )
846-                  
846+ 
847847                # Rename residual components 
848848                if  ".residual.0.gamma"  in  new_key :
849849                    new_key  =  new_key .replace (".residual.0.gamma" , ".norm1.gamma" )
@@ -861,7 +861,7 @@ def convert_vae_22():
861861                    new_key  =  new_key .replace (".shortcut.weight" , ".conv_shortcut.weight" )
862862                elif  ".shortcut.bias"  in  new_key :
863863                    new_key  =  new_key .replace (".shortcut.bias" , ".conv_shortcut.bias" )
864-              
864+ 
865865            # Handle resample blocks - change upsamples to upsampler and remove index 
866866            elif  "resample"  in  new_key  or  "time_conv"  in  new_key :
867867                # Change the second upsamples to upsampler and remove the index 
@@ -872,7 +872,7 @@ def convert_vae_22():
872872                    # Remove the index (parts[4]) and change upsamples to upsampler 
873873                    new_parts  =  parts [:3 ] +  ["upsampler" ] +  parts [5 :]
874874                    new_key  =  "." .join (new_parts )
875-              
875+ 
876876            new_state_dict [new_key ] =  value 
877877        else :
878878            # Keep other keys unchanged 
0 commit comments