
Commit 719c26c

Merge branch 'master' of https://github.com/BlenderNeko/ComfyUI
2 parents: a404998 + d0b1b6c

2 files changed: 87 additions & 44 deletions


comfy/sd.py

Lines changed: 8 additions & 2 deletions
@@ -372,10 +372,16 @@ def add_patches(self, patches, strength=1.0):
     def clip_layer(self, layer_idx):
         self.layer_idx = layer_idx
 
-    def encode(self, text):
+    def tokenize(self, text, return_word_ids=False):
+        return self.tokenizer.tokenize_with_weights(text, return_word_ids)
+
+    def encode(self, text, from_tokens=False):
         if self.layer_idx is not None:
             self.cond_stage_model.clip_layer(self.layer_idx)
-        tokens = self.tokenizer.tokenize_with_weights(text)
+        if from_tokens:
+            tokens = text
+        else:
+            tokens = self.tokenizer.tokenize_with_weights(text)
         try:
             self.patcher.patch_model()
             cond = self.cond_stage_model.encode_token_weights(tokens)
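This change splits tokenization out of encoding: tokenize() returns the weighted token batches, and encode() accepts either a raw prompt or, with from_tokens=True, tokens produced (and possibly edited) beforehand. A minimal usage sketch under that assumption, where clip stands for an instance of the CLIP wrapper in comfy/sd.py and the prompt is purely illustrative:

# Hypothetical usage of the new two-step API; `clip` and the prompt are
# placeholders, not part of this commit.
tokens = clip.tokenize("a photo of a cat")        # batches of (token, weight) pairs
# ...the token batches could be inspected or modified here...
cond = clip.encode(tokens, from_tokens=True)      # should match clip.encode("a photo of a cat")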

comfy/sd1_clip.py

Lines changed: 79 additions & 42 deletions
@@ -260,60 +260,97 @@ def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedd
         self.inv_vocab = {v: k for k, v in vocab.items()}
         self.embedding_directory = embedding_directory
         self.max_word_length = 8
+        self.embedding_identifier = "embedding:"
+
+    def _try_get_embedding(self, embedding_name:str):
+        '''
+        Takes a potential embedding name and tries to retrieve it.
+        Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
+        '''
+        embed = load_embed(embedding_name, self.embedding_directory)
+        if embed is None:
+            stripped = embedding_name.strip(',')
+            if len(stripped) < len(embedding_name):
+                embed = load_embed(stripped, self.embedding_directory)
+                return (embed, embedding_name[len(stripped):])
+        return (embed, "")
+
+
+    def tokenize_with_weights(self, text:str, return_word_ids=False):
+        '''
+        Takes a prompt and converts it to a list of (token, weight, word id) elements.
+        Tokens can both be integer tokens and pre computed CLIP tensors.
+        Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens.
+        Returned list has the dimensions NxM where M is the input size of CLIP
+        '''
+        if self.pad_with_end:
+            pad_token = self.end_token
+        else:
+            pad_token = 0
 
-    def tokenize_with_weights(self, text):
         text = escape_important(text)
         parsed_weights = token_weights(text, 1.0)
 
+        #tokenize words
         tokens = []
-        for t in parsed_weights:
-            to_tokenize = unescape_important(t[0]).replace("\n", " ").split(' ')
-            while len(to_tokenize) > 0:
-                word = to_tokenize.pop(0)
-                temp_tokens = []
-                embedding_identifier = "embedding:"
-                if word.startswith(embedding_identifier) and self.embedding_directory is not None:
-                    embedding_name = word[len(embedding_identifier):].strip('\n')
-                    embed = load_embed(embedding_name, self.embedding_directory)
+        for weighted_segment, weight in parsed_weights:
+            to_tokenize = unescape_important(weighted_segment).replace("\n", " ").split(' ')
+            to_tokenize = [x for x in to_tokenize if x != ""]
+            for word in to_tokenize:
+                #if we find an embedding, deal with the embedding
+                if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
+                    embedding_name = word[len(self.embedding_identifier):].strip('\n')
+                    embed, leftover = self._try_get_embedding(embedding_name)
                     if embed is None:
-                        stripped = embedding_name.strip(',')
-                        if len(stripped) < len(embedding_name):
-                            embed = load_embed(stripped, self.embedding_directory)
-                            if embed is not None:
-                                to_tokenize.insert(0, embedding_name[len(stripped):])
-
-                    if embed is not None:
+                        print(f"warning, embedding:{embedding_name} does not exist, ignoring")
+                    else:
                         if len(embed.shape) == 1:
-                            temp_tokens += [(embed, t[1])]
+                            tokens.append([(embed, weight)])
                         else:
-                            for x in range(embed.shape[0]):
-                                temp_tokens += [(embed[x], t[1])]
+                            tokens.append([(embed[x], weight) for x in range(embed.shape[0])])
+                    #if we accidentally have leftover text, continue parsing using leftover, else move on to next word
+                    if leftover != "":
+                        word = leftover
+                    else:
+                        continue
+                #parse word
+                tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][1:-1]])
+
+        #reshape token array to CLIP input size
+        batched_tokens = []
+        batch = [(self.start_token, 1.0, 0)]
+        batched_tokens.append(batch)
+        for i, t_group in enumerate(tokens):
+            #determine if we're going to try and keep the tokens in a single batch
+            is_large = len(t_group) >= self.max_word_length
+
+            while len(t_group) > 0:
+                if len(t_group) + len(batch) > self.max_length - 1:
+                    remaining_length = self.max_length - len(batch) - 1
+                    #break word in two and add end token
+                    if is_large:
+                        batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])
+                        batch.append((self.end_token, 1.0, 0))
+                        t_group = t_group[remaining_length:]
+                    #add end token and pad
                     else:
-                        print("warning, embedding:{} does not exist, ignoring".format(embedding_name))
-                elif len(word) > 0:
-                    tt = self.tokenizer(word)["input_ids"][1:-1]
-                    for x in tt:
-                        temp_tokens += [(x, t[1])]
-                tokens_left = self.max_tokens_per_section - (len(tokens) % self.max_tokens_per_section)
-
-                #try not to split words in different sections
-                if tokens_left < len(temp_tokens) and len(temp_tokens) < (self.max_word_length):
-                    for x in range(tokens_left):
-                        tokens += [(self.end_token, 1.0)]
-                tokens += temp_tokens
+                        batch.append((self.end_token, 1.0, 0))
+                        batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
+                    #start new batch
+                    batch = [(self.start_token, 1.0, 0)]
+                    batched_tokens.append(batch)
+                else:
+                    batch.extend([(t,w,i+1) for t,w in t_group])
+                    t_group = []
+
+        #fill last batch
+        batch.extend([(self.end_token, 1.0, 0)] + [(pad_token, 1.0, 0)] * (self.max_length - len(batch) - 1))
 
-        out_tokens = []
-        for x in range(0, len(tokens), self.max_tokens_per_section):
-            o_token = [(self.start_token, 1.0)] + tokens[x:min(self.max_tokens_per_section + x, len(tokens))]
-            o_token += [(self.end_token, 1.0)]
-            if self.pad_with_end:
-                o_token +=[(self.end_token, 1.0)] * (self.max_length - len(o_token))
-            else:
-                o_token +=[(0, 1.0)] * (self.max_length - len(o_token))
+        if not return_word_ids:
+            batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
 
-            out_tokens += [o_token]
+        return batched_tokens
 
-        return out_tokens
 
     def untokenize(self, token_weight_pair):
         return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
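For orientation, a rough sketch of what the rewritten tokenize_with_weights() returns, assuming tokenizer is an instance of the tokenizer class patched above and the prompt is illustrative: every batch is padded out to the CLIP input size (max_length, 77 by default), opens with the start token, and, when return_word_ids=True, carries a word id of 0 for the start/end/pad tokens and i+1 for tokens belonging to the i-th word or embedding.

# Illustrative only; `tokenizer` and the prompt are assumptions, not part of the commit.
batches = tokenizer.tokenize_with_weights("a photo of a cat", return_word_ids=True)
first = batches[0]
assert first[0] == (tokenizer.start_token, 1.0, 0)     # every batch opens with the start token (word id 0)
assert len(first) == tokenizer.max_length              # and is padded out to the CLIP input size
# Dropping return_word_ids (the default) yields (token, weight) pairs instead of triples.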
