4
4
import gc
5
5
from collections import namedtuple
6
6
import torch
7
- from safetensors .torch import load_file , save_file
8
7
import re
9
8
from omegaconf import OmegaConf
10
9
@@ -149,6 +148,10 @@ def torch_load(model_filename, model_info, map_override=None):
149
148
# safely load weights
150
149
# TODO: safetensors supports zero copy fast load to gpu, see issue #684.
151
150
# GPU only for now, see https://github.com/huggingface/safetensors/issues/95
151
+ try :
152
+ from safetensors .torch import load_file
153
+ except ImportError as e :
154
+ raise ImportError (f"The model is in safetensors format and it is not installed, use `pip install safetensors`: { e } " )
152
155
return load_file (model_filename , device = 'cuda' )
153
156
else :
154
157
return torch .load (model_filename , map_location = map_override )
@@ -157,6 +160,10 @@ def torch_save(model, output_filename):
157
160
basename , exttype = os .path .splitext (output_filename )
158
161
if (checkpoint_types [exttype ] == 'safetensors' ):
159
162
# [===== >] Reticulating brines...
163
+ try :
164
+ from safetensors .torch import save_file
165
+ except ImportError as e :
166
+ raise ImportError (f"Export as safetensors selected, yet it is not installed, use `pip install safetensors`: { e } " )
160
167
save_file (model , output_filename , metadata = {"format" : "pt" })
161
168
else :
162
169
torch .save (model , output_filename )
0 commit comments