@@ -1,4 +1,4 @@
-from diffusers import UNet2DModel, DDPMScheduler, DDPMPipeline
+from diffusers import UNet2DModel, DDPMScheduler, DDPMPipeline, VQModel, AutoencoderKL
 import argparse
 import json
 import torch
@@ -64,7 +64,7 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
 
             target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
 
-            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
+            num_heads = old_tensor.shape[0] // config.get("num_head_channels", 1) // 3
 
             old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
             query, key, value = old_tensor.split(channels // num_heads, dim=1)
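Context for the hunk above: `num_heads` feeds the reshape-and-split of the fused qkv projection a few lines below. A minimal, self-contained sketch of that step with dummy tensors (the shapes and the `num_head_channels` value here are made up purely for illustration):

```python
import torch

# Toy fused qkv weight: 3 * channels rows, mirroring the checkpoint tensors being converted.
channels = 64
num_head_channels = 32                      # illustrative config value; the script falls back to 1 if absent
old_tensor = torch.randn(3 * channels, channels, 1)

num_heads = old_tensor.shape[0] // num_head_channels // 3
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
query, key, value = old_tensor.split(channels // num_heads, dim=1)

print(query.shape, key.shape, value.shape)  # each: (num_heads, channels // num_heads, channels, 1)
```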
@@ -79,7 +79,7 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
         if attention_paths_to_split is not None and new_path in attention_paths_to_split:
             continue
 
-        new_path = new_path.replace('down.', 'downsample_blocks.')
+        new_path = new_path.replace('down.', 'down_blocks.')
         new_path = new_path.replace('up.', 'up_blocks.')
 
         if additional_replacements is not None:
@@ -111,36 +111,36 @@ def convert_ddpm_checkpoint(checkpoint, config):
     new_checkpoint['conv_out.weight'] = checkpoint['conv_out.weight']
     new_checkpoint['conv_out.bias'] = checkpoint['conv_out.bias']
 
-    num_downsample_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'down' in layer})
-    downsample_blocks = {layer_id: [key for key in checkpoint if f'down.{layer_id}' in key] for layer_id in range(num_downsample_blocks)}
+    num_down_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'down' in layer})
+    down_blocks = {layer_id: [key for key in checkpoint if f'down.{layer_id}' in key] for layer_id in range(num_down_blocks)}
 
     num_up_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'up' in layer})
     up_blocks = {layer_id: [key for key in checkpoint if f'up.{layer_id}' in key] for layer_id in range(num_up_blocks)}
 
-    for i in range(num_downsample_blocks):
-        block_id = (i - 1) // (config['num_res_blocks'] + 1)
+    for i in range(num_down_blocks):
+        block_id = (i - 1) // (config['layers_per_block'] + 1)
 
-        if any('downsample' in layer for layer in downsample_blocks[i]):
-            new_checkpoint[f'downsample_blocks.{i}.downsamplers.0.conv.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']
-            new_checkpoint[f'downsample_blocks.{i}.downsamplers.0.conv.bias'] = checkpoint[f'down.{i}.downsample.conv.bias']
-            new_checkpoint[f'downsample_blocks.{i}.downsamplers.0.op.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']
-            new_checkpoint[f'downsample_blocks.{i}.downsamplers.0.op.bias'] = checkpoint[f'down.{i}.downsample.conv.bias']
+        if any('downsample' in layer for layer in down_blocks[i]):
+            new_checkpoint[f'down_blocks.{i}.downsamplers.0.conv.weight'] = checkpoint[f'down.{i}.downsample.op.weight']
+            new_checkpoint[f'down_blocks.{i}.downsamplers.0.conv.bias'] = checkpoint[f'down.{i}.downsample.op.bias']
+            # new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']
+            # new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.bias'] = checkpoint[f'down.{i}.downsample.conv.bias']
 
-        if any('block' in layer for layer in downsample_blocks[i]):
-            num_blocks = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in downsample_blocks[i] if 'block' in layer})
-            blocks = {layer_id: [key for key in downsample_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
+        if any('block' in layer for layer in down_blocks[i]):
+            num_blocks = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in down_blocks[i] if 'block' in layer})
+            blocks = {layer_id: [key for key in down_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
 
             if num_blocks > 0:
-                for j in range(config['num_res_blocks']):
+                for j in range(config['layers_per_block']):
                     paths = renew_resnet_paths(blocks[j])
                     assign_to_checkpoint(paths, new_checkpoint, checkpoint)
 
-        if any('attn' in layer for layer in downsample_blocks[i]):
-            num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in downsample_blocks[i] if 'attn' in layer})
-            attns = {layer_id: [key for key in downsample_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
+        if any('attn' in layer for layer in down_blocks[i]):
+            num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in down_blocks[i] if 'attn' in layer})
+            attns = {layer_id: [key for key in down_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
 
             if num_attn > 0:
-                for j in range(config['num_res_blocks']):
+                for j in range(config['layers_per_block']):
                     paths = renew_attention_paths(attns[j])
                     assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)
 
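An aside on the counting idiom used throughout `convert_ddpm_checkpoint`: taking the first two dot-separated segments of each matching key and collecting them in a set yields the number of distinct `down.{i}` (or `up.{i}`) blocks. A toy illustration with made-up keys:

```python
# Illustrative keys only, not from a real DDPM checkpoint.
keys = [
    "down.0.block.0.conv1.weight",
    "down.0.block.1.conv1.weight",
    "down.1.block.0.conv1.weight",
    "down.1.downsample.op.weight",
]

# Same idiom as above: keep the first two dot-separated segments and count distinct prefixes.
prefixes = {".".join(k.split(".")[:2]) for k in keys if "down" in k}
print(sorted(prefixes))  # ['down.0', 'down.1']
print(len(prefixes))     # 2 -> num_down_blocks
```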
@@ -176,7 +176,7 @@ def convert_ddpm_checkpoint(checkpoint, config):
             blocks = {layer_id: [key for key in up_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
 
             if num_blocks > 0:
-                for j in range(config['num_res_blocks'] + 1):
+                for j in range(config['layers_per_block'] + 1):
                     replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
                     paths = renew_resnet_paths(blocks[j])
                     assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
@@ -186,7 +186,7 @@ def convert_ddpm_checkpoint(checkpoint, config):
             attns = {layer_id: [key for key in up_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
 
             if num_attn > 0:
-                for j in range(config['num_res_blocks'] + 1):
+                for j in range(config['layers_per_block'] + 1):
                     replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
                     paths = renew_attention_paths(attns[j])
                     assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
@@ -195,6 +195,117 @@ def convert_ddpm_checkpoint(checkpoint, config):
     return new_checkpoint
 
 
+def convert_vq_autoenc_checkpoint(checkpoint, config):
+    """
+    Takes a state dict and a config, and returns a converted checkpoint.
+    """
+    new_checkpoint = {}
+
+    new_checkpoint['encoder.conv_norm_out.weight'] = checkpoint['encoder.norm_out.weight']
+    new_checkpoint['encoder.conv_norm_out.bias'] = checkpoint['encoder.norm_out.bias']
+
+    new_checkpoint['encoder.conv_in.weight'] = checkpoint['encoder.conv_in.weight']
+    new_checkpoint['encoder.conv_in.bias'] = checkpoint['encoder.conv_in.bias']
+    new_checkpoint['encoder.conv_out.weight'] = checkpoint['encoder.conv_out.weight']
+    new_checkpoint['encoder.conv_out.bias'] = checkpoint['encoder.conv_out.bias']
+
+    new_checkpoint['decoder.conv_norm_out.weight'] = checkpoint['decoder.norm_out.weight']
+    new_checkpoint['decoder.conv_norm_out.bias'] = checkpoint['decoder.norm_out.bias']
+
+    new_checkpoint['decoder.conv_in.weight'] = checkpoint['decoder.conv_in.weight']
+    new_checkpoint['decoder.conv_in.bias'] = checkpoint['decoder.conv_in.bias']
+    new_checkpoint['decoder.conv_out.weight'] = checkpoint['decoder.conv_out.weight']
+    new_checkpoint['decoder.conv_out.bias'] = checkpoint['decoder.conv_out.bias']
+
+    num_down_blocks = len({'.'.join(layer.split('.')[:3]) for layer in checkpoint if 'down' in layer})
+    down_blocks = {layer_id: [key for key in checkpoint if f'down.{layer_id}' in key] for layer_id in range(num_down_blocks)}
+
+    num_up_blocks = len({'.'.join(layer.split('.')[:3]) for layer in checkpoint if 'up' in layer})
+    up_blocks = {layer_id: [key for key in checkpoint if f'up.{layer_id}' in key] for layer_id in range(num_up_blocks)}
+
+    for i in range(num_down_blocks):
+        block_id = (i - 1) // (config['layers_per_block'] + 1)
+
+        if any('downsample' in layer for layer in down_blocks[i]):
+            new_checkpoint[f'encoder.down_blocks.{i}.downsamplers.0.conv.weight'] = checkpoint[f'encoder.down.{i}.downsample.conv.weight']
+            new_checkpoint[f'encoder.down_blocks.{i}.downsamplers.0.conv.bias'] = checkpoint[f'encoder.down.{i}.downsample.conv.bias']
+
+        if any('block' in layer for layer in down_blocks[i]):
+            num_blocks = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in down_blocks[i] if 'block' in layer})
+            blocks = {layer_id: [key for key in down_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
+
+            if num_blocks > 0:
+                for j in range(config['layers_per_block']):
+                    paths = renew_resnet_paths(blocks[j])
+                    assign_to_checkpoint(paths, new_checkpoint, checkpoint)
+
+        if any('attn' in layer for layer in down_blocks[i]):
+            num_attn = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in down_blocks[i] if 'attn' in layer})
+            attns = {layer_id: [key for key in down_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
+
+            if num_attn > 0:
+                for j in range(config['layers_per_block']):
+                    paths = renew_attention_paths(attns[j])
+                    assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)
+
+    mid_block_1_layers = [key for key in checkpoint if "mid.block_1" in key]
+    mid_block_2_layers = [key for key in checkpoint if "mid.block_2" in key]
+    mid_attn_1_layers = [key for key in checkpoint if "mid.attn_1" in key]
+
+    # Mid new 2
+    paths = renew_resnet_paths(mid_block_1_layers)
+    assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
+        {'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_1', 'new': 'resnets.0'}
+    ])
+
+    paths = renew_resnet_paths(mid_block_2_layers)
+    assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
+        {'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_2', 'new': 'resnets.1'}
+    ])
+
+    paths = renew_attention_paths(mid_attn_1_layers, in_mid=True)
+    assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
+        {'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'attn_1', 'new': 'attentions.0'}
+    ])
+
+    for i in range(num_up_blocks):
+        block_id = num_up_blocks - 1 - i
+
+        if any('upsample' in layer for layer in up_blocks[i]):
+            new_checkpoint[f'decoder.up_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'decoder.up.{i}.upsample.conv.weight']
+            new_checkpoint[f'decoder.up_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'decoder.up.{i}.upsample.conv.bias']
+
+        if any('block' in layer for layer in up_blocks[i]):
+            num_blocks = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in up_blocks[i] if 'block' in layer})
+            blocks = {layer_id: [key for key in up_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
+
+            if num_blocks > 0:
+                for j in range(config['layers_per_block'] + 1):
+                    replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
+                    paths = renew_resnet_paths(blocks[j])
+                    assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
+
+        if any('attn' in layer for layer in up_blocks[i]):
+            num_attn = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in up_blocks[i] if 'attn' in layer})
+            attns = {layer_id: [key for key in up_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
+
+            if num_attn > 0:
+                for j in range(config['layers_per_block'] + 1):
+                    replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
+                    paths = renew_attention_paths(attns[j])
+                    assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
+
+    new_checkpoint = {k.replace('mid_new_2', 'mid_block'): v for k, v in new_checkpoint.items()}
+    new_checkpoint["quant_conv.weight"] = checkpoint["quant_conv.weight"]
+    new_checkpoint["quant_conv.bias"] = checkpoint["quant_conv.bias"]
+    if "quantize.embedding.weight" in checkpoint:
+        new_checkpoint["quantize.embedding.weight"] = checkpoint["quantize.embedding.weight"]
+    new_checkpoint["post_quant_conv.weight"] = checkpoint["post_quant_conv.weight"]
+    new_checkpoint["post_quant_conv.bias"] = checkpoint["post_quant_conv.bias"]
+
+    return new_checkpoint
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
 
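One way to sanity-check the key remapping performed by the new `convert_vq_autoenc_checkpoint` is a non-strict load into a freshly initialised model, inspecting the reported mismatches. A sketch, assuming `config` and `converted_checkpoint` are already in scope (for an `AutoencoderKL` config, swap the class accordingly):

```python
from diffusers import VQModel

# strict=False returns the keys the remapping missed (missing) or misnamed (unexpected).
model = VQModel(**config)
result = model.load_state_dict(converted_checkpoint, strict=False)
print("missing keys:", result.missing_keys)
print("unexpected keys:", result.unexpected_keys)
```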
@@ -220,15 +331,29 @@ def convert_ddpm_checkpoint(checkpoint, config):
     with open(args.config_file) as f:
         config = json.loads(f.read())
 
-    converted_checkpoint = convert_ddpm_checkpoint(checkpoint, config)
+    # unet case
+    key_prefix_set = set(key.split(".")[0] for key in checkpoint.keys())
+    if "encoder" in key_prefix_set and "decoder" in key_prefix_set:
+        converted_checkpoint = convert_vq_autoenc_checkpoint(checkpoint, config)
+    else:
+        converted_checkpoint = convert_ddpm_checkpoint(checkpoint, config)
 
     if "ddpm" in config:
         del config["ddpm"]
 
-    model = UNet2DModel(**config)
-    model.load_state_dict(converted_checkpoint)
+    if config["_class_name"] == "VQModel":
+        model = VQModel(**config)
+        model.load_state_dict(converted_checkpoint)
+        model.save_pretrained(args.dump_path)
+    elif config["_class_name"] == "AutoencoderKL":
+        model = AutoencoderKL(**config)
+        model.load_state_dict(converted_checkpoint)
+        model.save_pretrained(args.dump_path)
+    else:
+        model = UNet2DModel(**config)
+        model.load_state_dict(converted_checkpoint)
 
-    scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))
+        scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))
 
-    pipe = DDPMPipeline(unet=model, scheduler=scheduler)
-    pipe.save_pretrained(args.dump_path)
+        pipe = DDPMPipeline(unet=model, scheduler=scheduler)
+        pipe.save_pretrained(args.dump_path)
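After a UNet checkpoint has been converted, the saved folder can be loaded back as a pipeline to confirm the round trip. A sketch against a recent diffusers release; the script name and CLI flags in the comment are assumptions inferred from the `argparse` destinations, and all paths are placeholders:

```python
# Hypothetical invocation of the converter (flag names inferred from the args.* attributes above):
#   python convert_ddpm_original_checkpoint_to_diffusers.py \
#       --checkpoint_path ddpm_original.ckpt --config_file config.json --dump_path ./converted
from diffusers import DDPMPipeline

# Load the saved pipeline and draw a single sample to verify the converted weights.
pipe = DDPMPipeline.from_pretrained("./converted")
image = pipe(num_inference_steps=50).images[0]
image.save("sample.png")
```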