Skip to content

Commit eb96ff0

Browse files
authored
Safetensor loading in AnimateDiff conversion scripts (#7764)
* update * update
1 parent a38dd79 commit eb96ff0

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

scripts/convert_animatediff_motion_lora_to_diffusers.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import argparse
22

33
import torch
4-
from safetensors.torch import save_file
4+
from safetensors.torch import load_file, save_file
55

66

77
def convert_motion_module(original_state_dict):
@@ -34,7 +34,10 @@ def get_args():
3434
if __name__ == "__main__":
3535
args = get_args()
3636

37-
state_dict = torch.load(args.ckpt_path, map_location="cpu")
37+
if args.ckpt_path.endswith(".safetensors"):
38+
state_dict = load_file(args.ckpt_path)
39+
else:
40+
state_dict = torch.load(args.ckpt_path, map_location="cpu")
3841

3942
if "state_dict" in state_dict.keys():
4043
state_dict = state_dict["state_dict"]

scripts/convert_animatediff_motion_module_to_diffusers.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import argparse
22

33
import torch
4+
from safetensors.torch import load_file
45

56
from diffusers import MotionAdapter
67

@@ -38,7 +39,11 @@ def get_args():
3839
if __name__ == "__main__":
3940
args = get_args()
4041

41-
state_dict = torch.load(args.ckpt_path, map_location="cpu")
42+
if args.ckpt_path.endswith(".safetensors"):
43+
state_dict = load_file(args.ckpt_path)
44+
else:
45+
state_dict = torch.load(args.ckpt_path, map_location="cpu")
46+
4247
if "state_dict" in state_dict.keys():
4348
state_dict = state_dict["state_dict"]
4449

0 commit comments

Comments
 (0)