Skip to content

Commit 6074175

Browse files
committed
add safetensors to requirements
1 parent f108782 commit 6074175

File tree

3 files changed

+7
-6
lines changed

3 files changed

+7
-6
lines changed

modules/sd_models.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from collections import namedtuple
66
import torch
77
import re
8+
import safetensors.torch
89
from omegaconf import OmegaConf
910

1011
from ldm.util import instantiate_from_config
@@ -173,14 +174,12 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
173174
# load from file
174175
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
175176

176-
if checkpoint_file.endswith(".safetensors"):
177-
try:
178-
from safetensors.torch import load_file
179-
except ImportError as e:
180-
raise ImportError(f"The model is in safetensors format and it is not installed, use `pip install safetensors`: {e}")
181-
pl_sd = load_file(checkpoint_file, device=shared.weight_load_location)
177+
_, extension = os.path.splitext(checkpoint_file)
178+
if extension.lower() == ".safetensors":
179+
pl_sd = safetensors.torch.load_file(checkpoint_file, device=shared.weight_load_location)
182180
else:
183181
pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
182+
184183
if "global_step" in pl_sd:
185184
print(f"Global Step: {pl_sd['global_step']}")
186185

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,4 @@ lark
2929
inflection
3030
GitPython
3131
torchsde
32+
safetensors

requirements_versions.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,4 @@ lark==1.1.2
2626
inflection==0.5.1
2727
GitPython==3.1.27
2828
torchsde==0.2.5
29+
safetensors==0.2.5

0 commit comments

Comments
 (0)