1
- from diffusers import UNet2DModel , DDPMScheduler , DDPMPipeline
1
+ from diffusers import UNet2DModel , DDPMScheduler , DDPMPipeline , VQModel , AutoencoderKL
2
2
import argparse
3
3
import json
4
4
import torch
@@ -64,7 +64,7 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
64
64
65
65
target_shape = (- 1 , channels ) if len (old_tensor .shape ) == 3 else (- 1 )
66
66
67
- num_heads = old_tensor .shape [0 ] // config [ "num_head_channels" ] // 3
67
+ num_heads = old_tensor .shape [0 ] // config . get ( "num_head_channels" , 1 ) // 3
68
68
69
69
old_tensor = old_tensor .reshape ((num_heads , 3 * channels // num_heads ) + old_tensor .shape [1 :])
70
70
query , key , value = old_tensor .split (channels // num_heads , dim = 1 )
@@ -79,7 +79,7 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
79
79
if attention_paths_to_split is not None and new_path in attention_paths_to_split :
80
80
continue
81
81
82
- new_path = new_path .replace ('down.' , 'downsample_blocks .' )
82
+ new_path = new_path .replace ('down.' , 'down_blocks .' )
83
83
new_path = new_path .replace ('up.' , 'up_blocks.' )
84
84
85
85
if additional_replacements is not None :
@@ -111,36 +111,36 @@ def convert_ddpm_checkpoint(checkpoint, config):
111
111
new_checkpoint ['conv_out.weight' ] = checkpoint ['conv_out.weight' ]
112
112
new_checkpoint ['conv_out.bias' ] = checkpoint ['conv_out.bias' ]
113
113
114
- num_downsample_blocks = len ({'.' .join (layer .split ('.' )[:2 ]) for layer in checkpoint if 'down' in layer })
115
- downsample_blocks = {layer_id : [key for key in checkpoint if f'down.{ layer_id } ' in key ] for layer_id in range (num_downsample_blocks )}
114
+ num_down_blocks = len ({'.' .join (layer .split ('.' )[:2 ]) for layer in checkpoint if 'down' in layer })
115
+ down_blocks = {layer_id : [key for key in checkpoint if f'down.{ layer_id } ' in key ] for layer_id in range (num_down_blocks )}
116
116
117
117
num_up_blocks = len ({'.' .join (layer .split ('.' )[:2 ]) for layer in checkpoint if 'up' in layer })
118
118
up_blocks = {layer_id : [key for key in checkpoint if f'up.{ layer_id } ' in key ] for layer_id in range (num_up_blocks )}
119
119
120
- for i in range (num_downsample_blocks ):
121
- block_id = (i - 1 ) // (config ['num_res_blocks ' ] + 1 )
120
+ for i in range (num_down_blocks ):
121
+ block_id = (i - 1 ) // (config ['layers_per_block ' ] + 1 )
122
122
123
- if any ('downsample' in layer for layer in downsample_blocks [i ]):
124
- new_checkpoint [f'downsample_blocks .{ i } .downsamplers.0.conv.weight' ] = checkpoint [f'down.{ i } .downsample.conv .weight' ]
125
- new_checkpoint [f'downsample_blocks .{ i } .downsamplers.0.conv.bias' ] = checkpoint [f'down.{ i } .downsample.conv .bias' ]
126
- new_checkpoint [f'downsample_blocks .{ i } .downsamplers.0.op.weight' ] = checkpoint [f'down.{ i } .downsample.conv.weight' ]
127
- new_checkpoint [f'downsample_blocks .{ i } .downsamplers.0.op.bias' ] = checkpoint [f'down.{ i } .downsample.conv.bias' ]
123
+ if any ('downsample' in layer for layer in down_blocks [i ]):
124
+ new_checkpoint [f'down_blocks .{ i } .downsamplers.0.conv.weight' ] = checkpoint [f'down.{ i } .downsample.op .weight' ]
125
+ new_checkpoint [f'down_blocks .{ i } .downsamplers.0.conv.bias' ] = checkpoint [f'down.{ i } .downsample.op .bias' ]
126
+ # new_checkpoint[f'down_blocks .{i}.downsamplers.0.op.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']
127
+ # new_checkpoint[f'down_blocks .{i}.downsamplers.0.op.bias'] = checkpoint[f'down.{i}.downsample.conv.bias']
128
128
129
- if any ('block' in layer for layer in downsample_blocks [i ]):
130
- num_blocks = len ({'.' .join (shave_segments (layer , 2 ).split ('.' )[:2 ]) for layer in downsample_blocks [i ] if 'block' in layer })
131
- blocks = {layer_id : [key for key in downsample_blocks [i ] if f'block.{ layer_id } ' in key ] for layer_id in range (num_blocks )}
129
+ if any ('block' in layer for layer in down_blocks [i ]):
130
+ num_blocks = len ({'.' .join (shave_segments (layer , 2 ).split ('.' )[:2 ]) for layer in down_blocks [i ] if 'block' in layer })
131
+ blocks = {layer_id : [key for key in down_blocks [i ] if f'block.{ layer_id } ' in key ] for layer_id in range (num_blocks )}
132
132
133
133
if num_blocks > 0 :
134
- for j in range (config ['num_res_blocks ' ]):
134
+ for j in range (config ['layers_per_block ' ]):
135
135
paths = renew_resnet_paths (blocks [j ])
136
136
assign_to_checkpoint (paths , new_checkpoint , checkpoint )
137
137
138
- if any ('attn' in layer for layer in downsample_blocks [i ]):
139
- num_attn = len ({'.' .join (shave_segments (layer , 2 ).split ('.' )[:2 ]) for layer in downsample_blocks [i ] if 'attn' in layer })
140
- attns = {layer_id : [key for key in downsample_blocks [i ] if f'attn.{ layer_id } ' in key ] for layer_id in range (num_blocks )}
138
+ if any ('attn' in layer for layer in down_blocks [i ]):
139
+ num_attn = len ({'.' .join (shave_segments (layer , 2 ).split ('.' )[:2 ]) for layer in down_blocks [i ] if 'attn' in layer })
140
+ attns = {layer_id : [key for key in down_blocks [i ] if f'attn.{ layer_id } ' in key ] for layer_id in range (num_blocks )}
141
141
142
142
if num_attn > 0 :
143
- for j in range (config ['num_res_blocks ' ]):
143
+ for j in range (config ['layers_per_block ' ]):
144
144
paths = renew_attention_paths (attns [j ])
145
145
assign_to_checkpoint (paths , new_checkpoint , checkpoint , config = config )
146
146
@@ -176,7 +176,7 @@ def convert_ddpm_checkpoint(checkpoint, config):
176
176
blocks = {layer_id : [key for key in up_blocks [i ] if f'block.{ layer_id } ' in key ] for layer_id in range (num_blocks )}
177
177
178
178
if num_blocks > 0 :
179
- for j in range (config ['num_res_blocks ' ] + 1 ):
179
+ for j in range (config ['layers_per_block ' ] + 1 ):
180
180
replace_indices = {'old' : f'up_blocks.{ i } ' , 'new' : f'up_blocks.{ block_id } ' }
181
181
paths = renew_resnet_paths (blocks [j ])
182
182
assign_to_checkpoint (paths , new_checkpoint , checkpoint , additional_replacements = [replace_indices ])
@@ -186,7 +186,7 @@ def convert_ddpm_checkpoint(checkpoint, config):
186
186
attns = {layer_id : [key for key in up_blocks [i ] if f'attn.{ layer_id } ' in key ] for layer_id in range (num_blocks )}
187
187
188
188
if num_attn > 0 :
189
- for j in range (config ['num_res_blocks ' ] + 1 ):
189
+ for j in range (config ['layers_per_block ' ] + 1 ):
190
190
replace_indices = {'old' : f'up_blocks.{ i } ' , 'new' : f'up_blocks.{ block_id } ' }
191
191
paths = renew_attention_paths (attns [j ])
192
192
assign_to_checkpoint (paths , new_checkpoint , checkpoint , additional_replacements = [replace_indices ])
@@ -195,6 +195,117 @@ def convert_ddpm_checkpoint(checkpoint, config):
195
195
return new_checkpoint
196
196
197
197
198
+ def convert_vq_autoenc_checkpoint (checkpoint , config ):
199
+ """
200
+ Takes a state dict and a config, and returns a converted checkpoint.
201
+ """
202
+ new_checkpoint = {}
203
+
204
+ new_checkpoint ['encoder.conv_norm_out.weight' ] = checkpoint ['encoder.norm_out.weight' ]
205
+ new_checkpoint ['encoder.conv_norm_out.bias' ] = checkpoint ['encoder.norm_out.bias' ]
206
+
207
+ new_checkpoint ['encoder.conv_in.weight' ] = checkpoint ['encoder.conv_in.weight' ]
208
+ new_checkpoint ['encoder.conv_in.bias' ] = checkpoint ['encoder.conv_in.bias' ]
209
+ new_checkpoint ['encoder.conv_out.weight' ] = checkpoint ['encoder.conv_out.weight' ]
210
+ new_checkpoint ['encoder.conv_out.bias' ] = checkpoint ['encoder.conv_out.bias' ]
211
+
212
+ new_checkpoint ['decoder.conv_norm_out.weight' ] = checkpoint ['decoder.norm_out.weight' ]
213
+ new_checkpoint ['decoder.conv_norm_out.bias' ] = checkpoint ['decoder.norm_out.bias' ]
214
+
215
+ new_checkpoint ['decoder.conv_in.weight' ] = checkpoint ['decoder.conv_in.weight' ]
216
+ new_checkpoint ['decoder.conv_in.bias' ] = checkpoint ['decoder.conv_in.bias' ]
217
+ new_checkpoint ['decoder.conv_out.weight' ] = checkpoint ['decoder.conv_out.weight' ]
218
+ new_checkpoint ['decoder.conv_out.bias' ] = checkpoint ['decoder.conv_out.bias' ]
219
+
220
+ num_down_blocks = len ({'.' .join (layer .split ('.' )[:3 ]) for layer in checkpoint if 'down' in layer })
221
+ down_blocks = {layer_id : [key for key in checkpoint if f'down.{ layer_id } ' in key ] for layer_id in range (num_down_blocks )}
222
+
223
+ num_up_blocks = len ({'.' .join (layer .split ('.' )[:3 ]) for layer in checkpoint if 'up' in layer })
224
+ up_blocks = {layer_id : [key for key in checkpoint if f'up.{ layer_id } ' in key ] for layer_id in range (num_up_blocks )}
225
+
226
+ for i in range (num_down_blocks ):
227
+ block_id = (i - 1 ) // (config ['layers_per_block' ] + 1 )
228
+
229
+ if any ('downsample' in layer for layer in down_blocks [i ]):
230
+ new_checkpoint [f'encoder.down_blocks.{ i } .downsamplers.0.conv.weight' ] = checkpoint [f'encoder.down.{ i } .downsample.conv.weight' ]
231
+ new_checkpoint [f'encoder.down_blocks.{ i } .downsamplers.0.conv.bias' ] = checkpoint [f'encoder.down.{ i } .downsample.conv.bias' ]
232
+
233
+ if any ('block' in layer for layer in down_blocks [i ]):
234
+ num_blocks = len ({'.' .join (shave_segments (layer , 3 ).split ('.' )[:3 ]) for layer in down_blocks [i ] if 'block' in layer })
235
+ blocks = {layer_id : [key for key in down_blocks [i ] if f'block.{ layer_id } ' in key ] for layer_id in range (num_blocks )}
236
+
237
+ if num_blocks > 0 :
238
+ for j in range (config ['layers_per_block' ]):
239
+ paths = renew_resnet_paths (blocks [j ])
240
+ assign_to_checkpoint (paths , new_checkpoint , checkpoint )
241
+
242
+ if any ('attn' in layer for layer in down_blocks [i ]):
243
+ num_attn = len ({'.' .join (shave_segments (layer , 3 ).split ('.' )[:3 ]) for layer in down_blocks [i ] if 'attn' in layer })
244
+ attns = {layer_id : [key for key in down_blocks [i ] if f'attn.{ layer_id } ' in key ] for layer_id in range (num_blocks )}
245
+
246
+ if num_attn > 0 :
247
+ for j in range (config ['layers_per_block' ]):
248
+ paths = renew_attention_paths (attns [j ])
249
+ assign_to_checkpoint (paths , new_checkpoint , checkpoint , config = config )
250
+
251
+ mid_block_1_layers = [key for key in checkpoint if "mid.block_1" in key ]
252
+ mid_block_2_layers = [key for key in checkpoint if "mid.block_2" in key ]
253
+ mid_attn_1_layers = [key for key in checkpoint if "mid.attn_1" in key ]
254
+
255
+ # Mid new 2
256
+ paths = renew_resnet_paths (mid_block_1_layers )
257
+ assign_to_checkpoint (paths , new_checkpoint , checkpoint , additional_replacements = [
258
+ {'old' : 'mid.' , 'new' : 'mid_new_2.' }, {'old' : 'block_1' , 'new' : 'resnets.0' }
259
+ ])
260
+
261
+ paths = renew_resnet_paths (mid_block_2_layers )
262
+ assign_to_checkpoint (paths , new_checkpoint , checkpoint , additional_replacements = [
263
+ {'old' : 'mid.' , 'new' : 'mid_new_2.' }, {'old' : 'block_2' , 'new' : 'resnets.1' }
264
+ ])
265
+
266
+ paths = renew_attention_paths (mid_attn_1_layers , in_mid = True )
267
+ assign_to_checkpoint (paths , new_checkpoint , checkpoint , additional_replacements = [
268
+ {'old' : 'mid.' , 'new' : 'mid_new_2.' }, {'old' : 'attn_1' , 'new' : 'attentions.0' }
269
+ ])
270
+
271
+ for i in range (num_up_blocks ):
272
+ block_id = num_up_blocks - 1 - i
273
+
274
+ if any ('upsample' in layer for layer in up_blocks [i ]):
275
+ new_checkpoint [f'decoder.up_blocks.{ block_id } .upsamplers.0.conv.weight' ] = checkpoint [f'decoder.up.{ i } .upsample.conv.weight' ]
276
+ new_checkpoint [f'decoder.up_blocks.{ block_id } .upsamplers.0.conv.bias' ] = checkpoint [f'decoder.up.{ i } .upsample.conv.bias' ]
277
+
278
+ if any ('block' in layer for layer in up_blocks [i ]):
279
+ num_blocks = len ({'.' .join (shave_segments (layer , 3 ).split ('.' )[:3 ]) for layer in up_blocks [i ] if 'block' in layer })
280
+ blocks = {layer_id : [key for key in up_blocks [i ] if f'block.{ layer_id } ' in key ] for layer_id in range (num_blocks )}
281
+
282
+ if num_blocks > 0 :
283
+ for j in range (config ['layers_per_block' ] + 1 ):
284
+ replace_indices = {'old' : f'up_blocks.{ i } ' , 'new' : f'up_blocks.{ block_id } ' }
285
+ paths = renew_resnet_paths (blocks [j ])
286
+ assign_to_checkpoint (paths , new_checkpoint , checkpoint , additional_replacements = [replace_indices ])
287
+
288
+ if any ('attn' in layer for layer in up_blocks [i ]):
289
+ num_attn = len ({'.' .join (shave_segments (layer , 3 ).split ('.' )[:3 ]) for layer in up_blocks [i ] if 'attn' in layer })
290
+ attns = {layer_id : [key for key in up_blocks [i ] if f'attn.{ layer_id } ' in key ] for layer_id in range (num_blocks )}
291
+
292
+ if num_attn > 0 :
293
+ for j in range (config ['layers_per_block' ] + 1 ):
294
+ replace_indices = {'old' : f'up_blocks.{ i } ' , 'new' : f'up_blocks.{ block_id } ' }
295
+ paths = renew_attention_paths (attns [j ])
296
+ assign_to_checkpoint (paths , new_checkpoint , checkpoint , additional_replacements = [replace_indices ])
297
+
298
+ new_checkpoint = {k .replace ('mid_new_2' , 'mid_block' ): v for k , v in new_checkpoint .items ()}
299
+ new_checkpoint ["quant_conv.weight" ] = checkpoint ["quant_conv.weight" ]
300
+ new_checkpoint ["quant_conv.bias" ] = checkpoint ["quant_conv.bias" ]
301
+ if "quantize.embedding.weight" in checkpoint :
302
+ new_checkpoint ["quantize.embedding.weight" ] = checkpoint ["quantize.embedding.weight" ]
303
+ new_checkpoint ["post_quant_conv.weight" ] = checkpoint ["post_quant_conv.weight" ]
304
+ new_checkpoint ["post_quant_conv.bias" ] = checkpoint ["post_quant_conv.bias" ]
305
+
306
+ return new_checkpoint
307
+
308
+
198
309
if __name__ == "__main__" :
199
310
parser = argparse .ArgumentParser ()
200
311
@@ -220,15 +331,29 @@ def convert_ddpm_checkpoint(checkpoint, config):
220
331
with open (args .config_file ) as f :
221
332
config = json .loads (f .read ())
222
333
223
- converted_checkpoint = convert_ddpm_checkpoint (checkpoint , config )
334
+ # unet case
335
+ key_prefix_set = set (key .split ("." )[0 ] for key in checkpoint .keys ())
336
+ if "encoder" in key_prefix_set and "decoder" in key_prefix_set :
337
+ converted_checkpoint = convert_vq_autoenc_checkpoint (checkpoint , config )
338
+ else :
339
+ converted_checkpoint = convert_ddpm_checkpoint (checkpoint , config )
224
340
225
341
if "ddpm" in config :
226
342
del config ["ddpm" ]
227
343
228
- model = UNet2DModel (** config )
229
- model .load_state_dict (converted_checkpoint )
344
+ if config ["_class_name" ] == "VQModel" :
345
+ model = VQModel (** config )
346
+ model .load_state_dict (converted_checkpoint )
347
+ model .save_pretrained (args .dump_path )
348
+ elif config ["_class_name" ] == "AutoencoderKL" :
349
+ model = AutoencoderKL (** config )
350
+ model .load_state_dict (converted_checkpoint )
351
+ model .save_pretrained (args .dump_path )
352
+ else :
353
+ model = UNet2DModel (** config )
354
+ model .load_state_dict (converted_checkpoint )
230
355
231
- scheduler = DDPMScheduler .from_config ("/" .join (args .checkpoint_path .split ("/" )[:- 1 ]))
356
+ scheduler = DDPMScheduler .from_config ("/" .join (args .checkpoint_path .split ("/" )[:- 1 ]))
232
357
233
- pipe = DDPMPipeline (unet = model , scheduler = scheduler )
234
- pipe .save_pretrained (args .dump_path )
358
+ pipe = DDPMPipeline (unet = model , scheduler = scheduler )
359
+ pipe .save_pretrained (args .dump_path )
0 commit comments