Skip to content

Commit ac90cf3

Browse files
committed
safetensors optional for now
1 parent 210cb4c commit ac90cf3

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

modules/sd_models.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import gc
55
from collections import namedtuple
66
import torch
7-
from safetensors.torch import load_file, save_file
87
import re
98
from omegaconf import OmegaConf
109

@@ -149,6 +148,10 @@ def torch_load(model_filename, model_info, map_override=None):
149148
# safely load weights
150149
# TODO: safetensors supports zero copy fast load to gpu, see issue #684.
151150
# GPU only for now, see https://github.com/huggingface/safetensors/issues/95
151+
try:
152+
from safetensors.torch import load_file
153+
except ImportError as e:
154+
raise ImportError(f"The model is in safetensors format and it is not installed, use `pip install safetensors`: {e}")
152155
return load_file(model_filename, device='cuda')
153156
else:
154157
return torch.load(model_filename, map_location=map_override)
@@ -157,6 +160,10 @@ def torch_save(model, output_filename):
157160
basename, exttype = os.path.splitext(output_filename)
158161
if(checkpoint_types[exttype] == 'safetensors'):
159162
# [===== >] Reticulating brines...
163+
try:
164+
from safetensors.torch import save_file
165+
except ImportError as e:
166+
raise ImportError(f"Export as safetensors selected, yet it is not installed, use `pip install safetensors`: {e}")
160167
save_file(model, output_filename, metadata={"format": "pt"})
161168
else:
162169
torch.save(model, output_filename)

requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,3 @@ kornia
2828
lark
2929
inflection
3030
GitPython
31-
safetensors

0 commit comments

Comments
 (0)