
Commit 3100bc9

[Vae and AutoencoderKL] Final clean of LDM checkpoints (#137)
* [Vae and AutoencoderKL clean]
* save intermediate finished work
* more progress
* more progress
* finish modeling code
* save intermediate
* finish
* Correct tests
1 parent e05f03a commit 3100bc9

File tree

6 files changed, +490 -312 lines changed

scripts/convert_ddpm_original_checkpoint_to_diffusers.py

Lines changed: 153 additions & 28 deletions
@@ -1,4 +1,4 @@
-from diffusers import UNet2DModel, DDPMScheduler, DDPMPipeline
+from diffusers import UNet2DModel, DDPMScheduler, DDPMPipeline, VQModel, AutoencoderKL
 import argparse
 import json
 import torch
@@ -64,7 +64,7 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s

             target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)

-            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
+            num_heads = old_tensor.shape[0] // config.get("num_head_channels", 1) // 3

             old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
             query, key, value = old_tensor.split(channels // num_heads, dim=1)
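
Note: the switch to `config.get` matters because the newly supported autoencoder configs may not define "num_head_channels", so the converter falls back to a single head instead of raising. A minimal sketch (the config contents are illustrative, not from the commit):

config = {"layers_per_block": 2}                         # hypothetical config without "num_head_channels"
# config["num_head_channels"]                            # would raise KeyError here
num_head_channels = config.get("num_head_channels", 1)   # falls back to 1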
@@ -79,7 +79,7 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
         if attention_paths_to_split is not None and new_path in attention_paths_to_split:
             continue

-        new_path = new_path.replace('down.', 'downsample_blocks.')
+        new_path = new_path.replace('down.', 'down_blocks.')
         new_path = new_path.replace('up.', 'up_blocks.')

         if additional_replacements is not None:
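
For reference, the renaming is a plain substring replacement on checkpoint key paths. A toy sketch (the key below is made up):

new_path = 'down.0.block.1.conv1.weight'
new_path = new_path.replace('down.', 'down_blocks.').replace('up.', 'up_blocks.')
assert new_path == 'down_blocks.0.block.1.conv1.weight'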
@@ -111,36 +111,36 @@ def convert_ddpm_checkpoint(checkpoint, config):
     new_checkpoint['conv_out.weight'] = checkpoint['conv_out.weight']
     new_checkpoint['conv_out.bias'] = checkpoint['conv_out.bias']

-    num_downsample_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'down' in layer})
-    downsample_blocks = {layer_id: [key for key in checkpoint if f'down.{layer_id}' in key] for layer_id in range(num_downsample_blocks)}
+    num_down_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'down' in layer})
+    down_blocks = {layer_id: [key for key in checkpoint if f'down.{layer_id}' in key] for layer_id in range(num_down_blocks)}

     num_up_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'up' in layer})
     up_blocks = {layer_id: [key for key in checkpoint if f'up.{layer_id}' in key] for layer_id in range(num_up_blocks)}

-    for i in range(num_downsample_blocks):
-        block_id = (i - 1) // (config['num_res_blocks'] + 1)
+    for i in range(num_down_blocks):
+        block_id = (i - 1) // (config['layers_per_block'] + 1)

-        if any('downsample' in layer for layer in downsample_blocks[i]):
-            new_checkpoint[f'downsample_blocks.{i}.downsamplers.0.conv.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']
-            new_checkpoint[f'downsample_blocks.{i}.downsamplers.0.conv.bias'] = checkpoint[f'down.{i}.downsample.conv.bias']
-            new_checkpoint[f'downsample_blocks.{i}.downsamplers.0.op.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']
-            new_checkpoint[f'downsample_blocks.{i}.downsamplers.0.op.bias'] = checkpoint[f'down.{i}.downsample.conv.bias']
+        if any('downsample' in layer for layer in down_blocks[i]):
+            new_checkpoint[f'down_blocks.{i}.downsamplers.0.conv.weight'] = checkpoint[f'down.{i}.downsample.op.weight']
+            new_checkpoint[f'down_blocks.{i}.downsamplers.0.conv.bias'] = checkpoint[f'down.{i}.downsample.op.bias']
+            # new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']
+            # new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.bias'] = checkpoint[f'down.{i}.downsample.conv.bias']

-        if any('block' in layer for layer in downsample_blocks[i]):
-            num_blocks = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in downsample_blocks[i] if 'block' in layer})
-            blocks = {layer_id: [key for key in downsample_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
+        if any('block' in layer for layer in down_blocks[i]):
+            num_blocks = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in down_blocks[i] if 'block' in layer})
+            blocks = {layer_id: [key for key in down_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}

             if num_blocks > 0:
-                for j in range(config['num_res_blocks']):
+                for j in range(config['layers_per_block']):
                     paths = renew_resnet_paths(blocks[j])
                     assign_to_checkpoint(paths, new_checkpoint, checkpoint)

-        if any('attn' in layer for layer in downsample_blocks[i]):
-            num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in downsample_blocks[i] if 'attn' in layer})
-            attns = {layer_id: [key for key in downsample_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
+        if any('attn' in layer for layer in down_blocks[i]):
+            num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in down_blocks[i] if 'attn' in layer})
+            attns = {layer_id: [key for key in down_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}

             if num_attn > 0:
-                for j in range(config['num_res_blocks']):
+                for j in range(config['layers_per_block']):
                     paths = renew_attention_paths(attns[j])
                     assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)

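Note: block discovery counts the distinct `down.<i>` prefixes in the flat state dict. A self-contained sketch with made-up keys:

checkpoint = {
    'down.0.block.0.conv1.weight': None,    # hypothetical keys, for illustration only
    'down.0.downsample.op.weight': None,
    'down.1.block.0.conv1.weight': None,
}
num_down_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'down' in layer})
assert num_down_blocks == 2                 # distinct prefixes: {'down.0', 'down.1'}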
@@ -176,7 +176,7 @@ def convert_ddpm_checkpoint(checkpoint, config):
             blocks = {layer_id: [key for key in up_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}

             if num_blocks > 0:
-                for j in range(config['num_res_blocks'] + 1):
+                for j in range(config['layers_per_block'] + 1):
                     replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
                     paths = renew_resnet_paths(blocks[j])
                     assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
@@ -186,7 +186,7 @@ def convert_ddpm_checkpoint(checkpoint, config):
             attns = {layer_id: [key for key in up_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}

             if num_attn > 0:
-                for j in range(config['num_res_blocks'] + 1):
+                for j in range(config['layers_per_block'] + 1):
                     replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
                     paths = renew_attention_paths(attns[j])
                     assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
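
Note: `replace_indices` remaps the legacy enumeration index `i` to the target block index `block_id`. Assuming `assign_to_checkpoint` applies `additional_replacements` as plain string replacements (as the `down.`/`up.` hunk above suggests), the effect is:

replace_indices = {'old': 'up_blocks.3', 'new': 'up_blocks.0'}   # illustrative indices
path = 'up_blocks.3.resnets.0.conv1.weight'
assert path.replace(replace_indices['old'], replace_indices['new']) == 'up_blocks.0.resnets.0.conv1.weight'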
@@ -195,6 +195,117 @@ def convert_ddpm_checkpoint(checkpoint, config):
     return new_checkpoint


+def convert_vq_autoenc_checkpoint(checkpoint, config):
+    """
+    Takes a state dict and a config, and returns a converted checkpoint.
+    """
+    new_checkpoint = {}
+
+    new_checkpoint['encoder.conv_norm_out.weight'] = checkpoint['encoder.norm_out.weight']
+    new_checkpoint['encoder.conv_norm_out.bias'] = checkpoint['encoder.norm_out.bias']
+
+    new_checkpoint['encoder.conv_in.weight'] = checkpoint['encoder.conv_in.weight']
+    new_checkpoint['encoder.conv_in.bias'] = checkpoint['encoder.conv_in.bias']
+    new_checkpoint['encoder.conv_out.weight'] = checkpoint['encoder.conv_out.weight']
+    new_checkpoint['encoder.conv_out.bias'] = checkpoint['encoder.conv_out.bias']
+
+    new_checkpoint['decoder.conv_norm_out.weight'] = checkpoint['decoder.norm_out.weight']
+    new_checkpoint['decoder.conv_norm_out.bias'] = checkpoint['decoder.norm_out.bias']
+
+    new_checkpoint['decoder.conv_in.weight'] = checkpoint['decoder.conv_in.weight']
+    new_checkpoint['decoder.conv_in.bias'] = checkpoint['decoder.conv_in.bias']
+    new_checkpoint['decoder.conv_out.weight'] = checkpoint['decoder.conv_out.weight']
+    new_checkpoint['decoder.conv_out.bias'] = checkpoint['decoder.conv_out.bias']
+
+    num_down_blocks = len({'.'.join(layer.split('.')[:3]) for layer in checkpoint if 'down' in layer})
+    down_blocks = {layer_id: [key for key in checkpoint if f'down.{layer_id}' in key] for layer_id in range(num_down_blocks)}
+
+    num_up_blocks = len({'.'.join(layer.split('.')[:3]) for layer in checkpoint if 'up' in layer})
+    up_blocks = {layer_id: [key for key in checkpoint if f'up.{layer_id}' in key] for layer_id in range(num_up_blocks)}
+
+    for i in range(num_down_blocks):
+        block_id = (i - 1) // (config['layers_per_block'] + 1)
+
+        if any('downsample' in layer for layer in down_blocks[i]):
+            new_checkpoint[f'encoder.down_blocks.{i}.downsamplers.0.conv.weight'] = checkpoint[f'encoder.down.{i}.downsample.conv.weight']
+            new_checkpoint[f'encoder.down_blocks.{i}.downsamplers.0.conv.bias'] = checkpoint[f'encoder.down.{i}.downsample.conv.bias']
+
+        if any('block' in layer for layer in down_blocks[i]):
+            num_blocks = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in down_blocks[i] if 'block' in layer})
+            blocks = {layer_id: [key for key in down_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
+
+            if num_blocks > 0:
+                for j in range(config['layers_per_block']):
+                    paths = renew_resnet_paths(blocks[j])
+                    assign_to_checkpoint(paths, new_checkpoint, checkpoint)
+
+        if any('attn' in layer for layer in down_blocks[i]):
+            num_attn = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in down_blocks[i] if 'attn' in layer})
+            attns = {layer_id: [key for key in down_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
+
+            if num_attn > 0:
+                for j in range(config['layers_per_block']):
+                    paths = renew_attention_paths(attns[j])
+                    assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)
+
+    mid_block_1_layers = [key for key in checkpoint if "mid.block_1" in key]
+    mid_block_2_layers = [key for key in checkpoint if "mid.block_2" in key]
+    mid_attn_1_layers = [key for key in checkpoint if "mid.attn_1" in key]
+
+    # Mid new 2
+    paths = renew_resnet_paths(mid_block_1_layers)
+    assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
+        {'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_1', 'new': 'resnets.0'}
+    ])
+
+    paths = renew_resnet_paths(mid_block_2_layers)
+    assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
+        {'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_2', 'new': 'resnets.1'}
+    ])
+
+    paths = renew_attention_paths(mid_attn_1_layers, in_mid=True)
+    assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
+        {'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'attn_1', 'new': 'attentions.0'}
+    ])
+
+    for i in range(num_up_blocks):
+        block_id = num_up_blocks - 1 - i
+
+        if any('upsample' in layer for layer in up_blocks[i]):
+            new_checkpoint[f'decoder.up_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'decoder.up.{i}.upsample.conv.weight']
+            new_checkpoint[f'decoder.up_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'decoder.up.{i}.upsample.conv.bias']
+
+        if any('block' in layer for layer in up_blocks[i]):
+            num_blocks = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in up_blocks[i] if 'block' in layer})
+            blocks = {layer_id: [key for key in up_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
+
+            if num_blocks > 0:
+                for j in range(config['layers_per_block'] + 1):
+                    replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
+                    paths = renew_resnet_paths(blocks[j])
+                    assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
+
+        if any('attn' in layer for layer in up_blocks[i]):
+            num_attn = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in up_blocks[i] if 'attn' in layer})
+            attns = {layer_id: [key for key in up_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
+
+            if num_attn > 0:
+                for j in range(config['layers_per_block'] + 1):
+                    replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
+                    paths = renew_attention_paths(attns[j])
+                    assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
+
+    new_checkpoint = {k.replace('mid_new_2', 'mid_block'): v for k, v in new_checkpoint.items()}
+    new_checkpoint["quant_conv.weight"] = checkpoint["quant_conv.weight"]
+    new_checkpoint["quant_conv.bias"] = checkpoint["quant_conv.bias"]
+    if "quantize.embedding.weight" in checkpoint:
+        new_checkpoint["quantize.embedding.weight"] = checkpoint["quantize.embedding.weight"]
+    new_checkpoint["post_quant_conv.weight"] = checkpoint["post_quant_conv.weight"]
+    new_checkpoint["post_quant_conv.bias"] = checkpoint["post_quant_conv.bias"]
+
+    return new_checkpoint
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()

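A hypothetical standalone invocation of the new converter, assuming the checkpoint file holds the flat state dict the script expects (both paths below are placeholders):

import json
import torch

checkpoint = torch.load('vqvae.ckpt', map_location='cpu')   # placeholder path
with open('config.json') as f:                              # placeholder path
    config = json.load(f)
converted = convert_vq_autoenc_checkpoint(checkpoint, config)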
@@ -220,15 +331,29 @@ def convert_ddpm_checkpoint(checkpoint, config):
     with open(args.config_file) as f:
         config = json.loads(f.read())

-    converted_checkpoint = convert_ddpm_checkpoint(checkpoint, config)
+    # unet case
+    key_prefix_set = set(key.split(".")[0] for key in checkpoint.keys())
+    if "encoder" in key_prefix_set and "decoder" in key_prefix_set:
+        converted_checkpoint = convert_vq_autoenc_checkpoint(checkpoint, config)
+    else:
+        converted_checkpoint = convert_ddpm_checkpoint(checkpoint, config)

     if "ddpm" in config:
         del config["ddpm"]

-    model = UNet2DModel(**config)
-    model.load_state_dict(converted_checkpoint)
+    if config["_class_name"] == "VQModel":
+        model = VQModel(**config)
+        model.load_state_dict(converted_checkpoint)
+        model.save_pretrained(args.dump_path)
+    elif config["_class_name"] == "AutoencoderKL":
+        model = AutoencoderKL(**config)
+        model.load_state_dict(converted_checkpoint)
+        model.save_pretrained(args.dump_path)
+    else:
+        model = UNet2DModel(**config)
+        model.load_state_dict(converted_checkpoint)

-    scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))
+        scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))

-    pipe = DDPMPipeline(unet=model, scheduler=scheduler)
-    pipe.save_pretrained(args.dump_path)
+        pipe = DDPMPipeline(unet=model, scheduler=scheduler)
+        pipe.save_pretrained(args.dump_path)
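
Note: the dispatch never inspects the config to choose a converter; a VQ/KL checkpoint is recognized purely by its top-level key prefixes, and `config["_class_name"]` then decides which diffusers class to instantiate. A sketch of the heuristic with illustrative keys:

keys = ['encoder.conv_in.weight', 'decoder.conv_in.weight', 'quant_conv.weight']
key_prefix_set = set(key.split('.')[0] for key in keys)
assert 'encoder' in key_prefix_set and 'decoder' in key_prefix_set   # -> VQ/KL converter is used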

src/diffusers/models/resnet.py

Lines changed: 7 additions & 3 deletions
@@ -288,7 +288,10 @@ def __init__(

         self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

-        self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
+        if temb_channels is not None:
+            self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
+        else:
+            self.time_emb_proj = None

         self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
         self.dropout = torch.nn.Dropout(dropout)
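
A minimal standalone sketch of the guard (names and sizes are assumed, not the actual class): autoencoder resnet blocks carry no time embedding, so the projection is skipped entirely.

import torch

temb_channels, out_channels = None, 64   # assumed values for illustration
time_emb_proj = torch.nn.Linear(temb_channels, out_channels) if temb_channels is not None else None
assert time_emb_proj is None             # no time-embedding projection when temb_channels is None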
@@ -364,8 +367,9 @@ def set_weight(self, resnet):
         self.conv1.weight.data = resnet.conv1.weight.data
         self.conv1.bias.data = resnet.conv1.bias.data

-        self.time_emb_proj.weight.data = resnet.temb_proj.weight.data
-        self.time_emb_proj.bias.data = resnet.temb_proj.bias.data
+        if self.time_emb_proj is not None:
+            self.time_emb_proj.weight.data = resnet.temb_proj.weight.data
+            self.time_emb_proj.bias.data = resnet.temb_proj.bias.data

         self.norm2.weight.data = resnet.norm2.weight.data
         self.norm2.bias.data = resnet.norm2.bias.data
