
Commit 04d9bc1

Safely load pickled embeds that don't load with weights_only=True.
1 parent 334aab0 commit 04d9bc1

File tree

1 file changed, +34 −6 lines changed


comfy/sd1_clip.py

Lines changed: 34 additions & 6 deletions
@@ -3,6 +3,7 @@
 from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig
 import torch
 import traceback
+import zipfile

 class ClipTokenWeightEncoder:
     def encode_token_weights(self, token_weight_pairs):
@@ -171,6 +172,26 @@ def unescape_important(text):
     text = text.replace("\0\2", "(")
     return text

+def safe_load_embed_zip(embed_path):
+    with zipfile.ZipFile(embed_path) as myzip:
+        names = list(filter(lambda a: "data/" in a, myzip.namelist()))
+        names.reverse()
+        for n in names:
+            with myzip.open(n) as myfile:
+                data = myfile.read()
+                number = len(data) // 4
+                length_embed = 1024 #sd2.x
+                if number < 768:
+                    continue
+                if number % 768 == 0:
+                    length_embed = 768 #sd1.x
+                num_embeds = number // length_embed
+                embed = torch.frombuffer(data, dtype=torch.float)
+                out = embed.reshape((num_embeds, length_embed)).clone()
+                del embed
+                return out
+
+
 def load_embed(embedding_name, embedding_directory):
     if isinstance(embedding_directory, str):
         embedding_directory = [embedding_directory]
@@ -195,13 +216,18 @@ def load_embed(embedding_name, embedding_directory):

     embed_path = valid_file

+    embed_out = None
+
     try:
         if embed_path.lower().endswith(".safetensors"):
             import safetensors.torch
             embed = safetensors.torch.load_file(embed_path, device="cpu")
         else:
             if 'weights_only' in torch.load.__code__.co_varnames:
-                embed = torch.load(embed_path, weights_only=True, map_location="cpu")
+                try:
+                    embed = torch.load(embed_path, weights_only=True, map_location="cpu")
+                except:
+                    embed_out = safe_load_embed_zip(embed_path)
             else:
                 embed = torch.load(embed_path, map_location="cpu")
     except Exception as e:
@@ -210,11 +236,13 @@
         print("error loading embedding, skipping loading:", embedding_name)
         return None

-    if 'string_to_param' in embed:
-        values = embed['string_to_param'].values()
-    else:
-        values = embed.values()
-    return next(iter(values))
+    if embed_out is None:
+        if 'string_to_param' in embed:
+            values = embed['string_to_param'].values()
+        else:
+            values = embed.values()
+        embed_out = next(iter(values))
+    return embed_out

 class SD1Tokenizer:
     def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None):
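For readers unfamiliar with the trick the commit relies on: a zip-serialized .pt file written by torch.save() stores each tensor's raw little-endian float32 storage in a "data/<n>" archive entry, which is what safe_load_embed_zip() reads back without ever invoking the pickle machinery. Below is a minimal sketch of that layout; the file name toy_embed.pt, the tensor contents, and the printout are invented for illustration, and torch.frombuffer requires PyTorch >= 1.10.

import zipfile

import torch

# Demo only: fabricate a small "embedding" checkpoint. A 2x768 tensor
# mimics an SD1.x textual-inversion embed (768 floats per vector).
emb = torch.randn(2, 768)
torch.save({'string_to_param': {'*': emb}}, "toy_embed.pt")

# torch.save() writes a zip archive; the raw float32 bytes of each tensor
# live in a "data/<n>" entry, which is exactly what the commit's filter
# picks up.
with zipfile.ZipFile("toy_embed.pt") as myzip:
    names = [n for n in myzip.namelist() if "data/" in n]
    with myzip.open(names[-1]) as myfile:
        data = myfile.read()

number = len(data) // 4        # 4 bytes per float32 -> 1536 floats here
num_embeds = number // 768     # divisible by 768 -> classified as sd1.x
# bytearray() makes the buffer writable, avoiding frombuffer's warning on
# read-only bytes; the commit passes the bytes object directly instead.
recovered = torch.frombuffer(bytearray(data), dtype=torch.float)
recovered = recovered.reshape((num_embeds, 768))
print(torch.equal(recovered, emb))  # True: the raw bytes round-trip

The design choice this illustrates: because nothing is unpickled, a malicious payload in the archive's data.pkl is never executed, so the fallback is safe for files that fail the weights_only=True check. The 768/1024 width heuristic in the commit then distinguishes SD1.x from SD2.x embedding dimensions.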
