Commit fb02316

Add AnimateDiff conversion scripts (#6340)
* add scripts
* update
1 parent 98a2b3d commit fb02316

File tree: 2 files changed, +102 -0 lines changed
Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
import argparse

import torch
from safetensors.torch import save_file


def convert_motion_module(original_state_dict):
    # Drop positional-encoding entries and remap the original AnimateDiff
    # key names to the diffusers motion-module layout.
    converted_state_dict = {}
    for k, v in original_state_dict.items():
        if "pos_encoder" in k:
            continue

        else:
            converted_state_dict[
                k.replace(".norms.0", ".norm1")
                .replace(".norms.1", ".norm2")
                .replace(".ff_norm", ".norm3")
                .replace(".attention_blocks.0", ".attn1")
                .replace(".attention_blocks.1", ".attn2")
                .replace(".temporal_transformer", "")
            ] = v

    return converted_state_dict


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, required=True)

    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()

    state_dict = torch.load(args.ckpt_path, map_location="cpu")

    if "state_dict" in state_dict.keys():
        state_dict = state_dict["state_dict"]

    conv_state_dict = convert_motion_module(state_dict)

    # convert to new format: prefix every tensor key with "unet." and write
    # everything to a single safetensors file
    output_dict = {}
    for module_name, params in conv_state_dict.items():
        if type(params) is not torch.Tensor:
            continue
        output_dict.update({f"unet.{module_name}": params})

    save_file(output_dict, f"{args.output_path}/diffusion_pytorch_model.safetensors")
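As a quick sanity check (not part of this commit), the converted file can be opened with the standard safetensors loader; the output directory name below is a placeholder, and every key should carry the "unet." prefix added by the script above.

# Minimal sketch, assuming the script was run with --output_path converted-motion-weights
from safetensors.torch import load_file

converted = load_file("converted-motion-weights/diffusion_pytorch_model.safetensors")
print(len(converted), "tensors")
print(sorted(converted)[0])  # keys are prefixed with "unet."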
Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
import argparse

import torch

from diffusers import MotionAdapter


def convert_motion_module(original_state_dict):
    converted_state_dict = {}
    for k, v in original_state_dict.items():
        if "pos_encoder" in k:
            continue

        else:
            converted_state_dict[
                k.replace(".norms.0", ".norm1")
                .replace(".norms.1", ".norm2")
                .replace(".ff_norm", ".norm3")
                .replace(".attention_blocks.0", ".attn1")
                .replace(".attention_blocks.1", ".attn2")
                .replace(".temporal_transformer", "")
            ] = v

    return converted_state_dict


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, required=True)
    parser.add_argument("--use_motion_mid_block", action="store_true")
    parser.add_argument("--motion_max_seq_length", type=int, default=32)

    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()

    state_dict = torch.load(args.ckpt_path, map_location="cpu")
    if "state_dict" in state_dict.keys():
        state_dict = state_dict["state_dict"]

    conv_state_dict = convert_motion_module(state_dict)
    adapter = MotionAdapter(
        use_motion_mid_block=args.use_motion_mid_block, motion_max_seq_length=args.motion_max_seq_length
    )
    # skip loading position embeddings
    adapter.load_state_dict(conv_state_dict, strict=False)
    adapter.save_pretrained(args.output_path)
    # cast the adapter to fp16 before writing the "fp16" variant
    # (save_pretrained does not cast weights itself)
    adapter.to(torch.float16)
    adapter.save_pretrained(args.output_path, variant="fp16")
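For context (not part of this commit), the adapter saved by this script could be loaded back with MotionAdapter.from_pretrained and plugged into AnimateDiffPipeline; the sketch below assumes a recent diffusers release with AnimateDiff support, and the Stable Diffusion 1.5 base checkpoint ID is only a placeholder.

# Minimal sketch: load the converted adapter and run it with an SD 1.5 base model.
import torch
from diffusers import AnimateDiffPipeline, MotionAdapter

adapter = MotionAdapter.from_pretrained("converted-motion-adapter", torch_dtype=torch.float16)
pipe = AnimateDiffPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # placeholder base model
    motion_adapter=adapter,
    torch_dtype=torch.float16,
).to("cuda")
frames = pipe("a rocket lifting off, cinematic", num_frames=16).frames[0]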
