
Commit 8489cba

add unique ID per word/embedding for tokenizer
1 parent f5f7013 commit 8489cba


comfy/sd1_clip.py

Lines changed: 70 additions & 45 deletions
@@ -224,60 +224,85 @@ def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedd
         self.inv_vocab = {v: k for k, v in vocab.items()}
         self.embedding_directory = embedding_directory
         self.max_word_length = 8
-
-    def tokenize_with_weights(self, text):
+        self.embedding_identifier = "embedding:"
+
+    def _try_get_embedding(self, name:str):
+        '''
+        Takes a potential embedding name and tries to retrieve it.
+        Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
+        '''
+        embedding_name = name[len(self.embedding_identifier):].strip('\n')
+        embed = load_embed(embedding_name, self.embedding_directory)
+        if embed is None:
+            stripped = embedding_name.strip(',')
+            if len(stripped) < len(embedding_name):
+                embed = load_embed(stripped, self.embedding_directory)
+                return (embed, embedding_name[len(stripped):])
+        return (embed, "")
+
+
+    def tokenize_with_weights(self, text:str):
+        '''
+        Takes a prompt and converts it to a list of (token, weight, word id) elements.
+        Tokens can both be integer tokens and pre computed CLIP tensors.
+        Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens.
+        Returned list has the dimensions NxM where M is the input size of CLIP
+        '''
         text = escape_important(text)
         parsed_weights = token_weights(text, 1.0)
 
+        #tokenize words
         tokens = []
-        for t in parsed_weights:
-            to_tokenize = unescape_important(t[0]).replace("\n", " ").split(' ')
-            while len(to_tokenize) > 0:
-                word = to_tokenize.pop(0)
-                temp_tokens = []
-                embedding_identifier = "embedding:"
-                if word.startswith(embedding_identifier) and self.embedding_directory is not None:
-                    embedding_name = word[len(embedding_identifier):].strip('\n')
-                    embed = load_embed(embedding_name, self.embedding_directory)
+        for weighted_segment, weight in parsed_weights:
+            to_tokenize = unescape_important(weighted_segment).replace("\n", " ").split(' ')
+            to_tokenize = [x for x in to_tokenize if x != ""]
+            for word in to_tokenize:
+                #if we find an embedding, deal with the embedding
+                if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
+                    embed, leftover = self._try_get_embedding(word)
                     if embed is None:
-                        stripped = embedding_name.strip(',')
-                        if len(stripped) < len(embedding_name):
-                            embed = load_embed(stripped, self.embedding_directory)
-                            if embed is not None:
-                                to_tokenize.insert(0, embedding_name[len(stripped):])
-
-                    if embed is not None:
+                        print(f"warning, embedding:{word} does not exist, ignoring")
+                    else:
                         if len(embed.shape) == 1:
-                            temp_tokens += [(embed, t[1])]
+                            tokens.append([(embed, weight)])
                         else:
-                            for x in range(embed.shape[0]):
-                                temp_tokens += [(embed[x], t[1])]
+                            tokens.append([(embed[x], weight) for x in range(embed.shape[0])])
+                    #if we accidentally have leftover text, continue parsing using leftover, else move on to next word
+                    if leftover != "":
+                        word = leftover
                     else:
-                        print("warning, embedding:{} does not exist, ignoring".format(embedding_name))
-                elif len(word) > 0:
-                    tt = self.tokenizer(word)["input_ids"][1:-1]
-                    for x in tt:
-                        temp_tokens += [(x, t[1])]
-                tokens_left = self.max_tokens_per_section - (len(tokens) % self.max_tokens_per_section)
-
-                #try not to split words in different sections
-                if tokens_left < len(temp_tokens) and len(temp_tokens) < (self.max_word_length):
-                    for x in range(tokens_left):
-                        tokens += [(self.end_token, 1.0)]
-                tokens += temp_tokens
-
-        out_tokens = []
-        for x in range(0, len(tokens), self.max_tokens_per_section):
-            o_token = [(self.start_token, 1.0)] + tokens[x:min(self.max_tokens_per_section + x, len(tokens))]
-            o_token += [(self.end_token, 1.0)]
-            if self.pad_with_end:
-                o_token +=[(self.end_token, 1.0)] * (self.max_length - len(o_token))
-            else:
-                o_token +=[(0, 1.0)] * (self.max_length - len(o_token))
-
-            out_tokens += [o_token]
+                        continue
+                #parse word
+                tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][1:-1]])
+
+        #reshape token array to CLIP input size
+        batched_tokens = []
+        batch = []
+        batched_tokens.append(batch)
+        for i, t_group in enumerate(tokens):
+            #start a new batch if there is not enough room
+            if len(t_group) + len(batch) > self.max_tokens_per_section:
+                remaining_length = self.max_tokens_per_section - len(batch)
+                #fill remaining space depending on length of tokens
+                if len(t_group) > self.max_word_length:
+                    #put part of group of tokens in the batch
+                    batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])
+                    t_group = t_group[remaining_length:]
+                else:
+                    #filler tokens
+                    batch.extend([(self.end_token, 1.0, 0)] * remaining_length)
+                batch = []
+                batched_tokens.append(batch)
+            #put current group of tokens in the batch
+            batch.extend([(t,w,i+1) for t,w in t_group])
+
+        #fill last batch
+        batch.extend([(self.end_token, 1.0, 0)] * (self.max_tokens_per_section - len(batch)))
+
+        #add start and end tokens
+        batched_tokens = [[(self.start_token, 1.0, 0)] + x + [(self.end_token, 1.0, 0)] for x in batched_tokens]
+        return batched_tokens
 
-        return out_tokens
 
     def untokenize(self, token_weight_pair):
         return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
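For readers skimming the diff, here is a minimal usage sketch of the new return format. It is an illustration only: the class name SD1Tokenizer, the import path, the embeddings directory, and the prompt string are assumptions not shown in this hunk; what the commit does establish is the embedding_directory keyword in __init__ and that tokenize_with_weights now yields CLIP-sized batches of (token, weight, word id) triples.

# Illustrative sketch only -- class name, prompt, and path below are assumptions,
# not part of this commit; the (token, weight, word_id) shape is from the docstring above.
from comfy.sd1_clip import SD1Tokenizer

tokenizer = SD1Tokenizer(embedding_directory="models/embeddings")  # hypothetical path
batches = tokenizer.tokenize_with_weights("(a photo of a cat:1.2) embedding:my_style")

# Each batch is one CLIP-sized section of (token, weight, word_id) triples:
# word_id == 0 marks start/end/filler tokens, while every parsed word or embedding
# keeps its own id (i+1 for the i-th group), even when a long group spills into
# the next batch.
for batch in batches:
    for token, weight, word_id in batch:
        pass  # e.g. group tokens by word_id downstream

One design note grounded in the diff itself: groups longer than max_word_length are split across batches, while shorter groups are kept whole by padding the current batch with end tokens and starting a new one.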
