-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathtool_add_vae.py
More file actions
34 lines (24 loc) · 1.04 KB
/
tool_add_vae.py
File metadata and controls
34 lines (24 loc) · 1.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import os
# Checkpoint paths consumed by the load/merge/save steps later in this script.
# NOTE: the spelling ``orginal`` is kept on purpose; later lines reference
# these exact names.
orginal_weight_path = 'models/control_sd15_ini.ckpt'
# NOTE(review): '***' looks like a placeholder for the real checkpoint name
# under logs/ -- replace it before running.
trained_weight_path = 'logs/***.ckpt'
output_path = 'models/DODA-coco-wvae.ckpt'

# Validate up front with explicit exceptions instead of ``assert``: assert
# statements are removed under ``python -O``, which would skip these checks
# and let the script fail late or overwrite an existing output file.
if not os.path.exists(orginal_weight_path):
    raise FileNotFoundError('Original model does not exist.')
if not os.path.exists(trained_weight_path):
    raise FileNotFoundError('Trained model does not exist.')
if os.path.exists(output_path):
    raise FileExistsError('Output filename already exists.')
# os.path.dirname() returns '' for a bare filename; treat that as the current
# directory rather than rejecting a perfectly valid output path.
if not os.path.isdir(os.path.dirname(output_path) or '.'):
    raise FileNotFoundError('Output path is not valid.')
import torch
from cldm.model import create_model
def get_node_name(name, parent_name):
    """Strip ``parent_name`` from the front of ``name`` if it is a strict prefix.

    Args:
        name: The full name to inspect (e.g. a state-dict key).
        parent_name: The prefix expected at the start of ``name``.

    Returns:
        A ``(matched, remainder)`` tuple. ``matched`` is True only when
        ``name`` starts with ``parent_name`` AND is strictly longer than it;
        ``remainder`` is then the part of ``name`` after the prefix.
        Otherwise the result is ``(False, '')`` -- this includes the case
        where ``name`` equals ``parent_name`` exactly.
    """
    # The length guard keeps the original contract: an exact match (empty
    # remainder) is reported as "not a child" rather than (True, '').
    if len(name) <= len(parent_name) or not name.startswith(parent_name):
        return False, ''
    return True, name[len(parent_name):]
# Build the model from the training config; its state dict defines the set of
# keys that end up in the output checkpoint.
model = create_model(config_path='./configs/controlnet/coco_train.yaml')

# SECURITY NOTE: torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints from sources you trust.
# map_location='cpu' lets checkpoints saved from a GPU load on machines
# without CUDA and avoids holding extra copies in GPU memory.
orginal_weight = torch.load(orginal_weight_path, map_location='cpu')
trained_weight = torch.load(trained_weight_path, map_location='cpu')

# Some trainers (e.g. PyTorch Lightning) nest the weights under a
# 'state_dict' key. Without unwrapping, the strict=False load below would
# match nothing and silently leave the base weights untouched. A flat state
# dict has no such key and passes through unchanged.
orginal_weight = orginal_weight.get('state_dict', orginal_weight)
trained_weight = trained_weight.get('state_dict', trained_weight)

# The base checkpoint must cover every model parameter (strict=True); the
# trained checkpoint is then overlaid and may cover only a subset of them
# (strict=False).
model.load_state_dict(orginal_weight, strict=True)
load_result = model.load_state_dict(trained_weight, strict=False)

# strict=False hides key mismatches, so verify that the overlay actually
# replaced something before writing an output that looks merged but is not.
unexpected_count = len(load_result.unexpected_keys)
matched_count = len(trained_weight) - unexpected_count
if matched_count == 0:
    raise RuntimeError(
        'No parameters from the trained checkpoint matched the model.'
    )
print('Overlaid', matched_count, 'trained entries;',
      unexpected_count, 'unexpected keys ignored.')

torch.save(model.state_dict(), output_path)
print('Done.')