
Commit 752f7a1

align behavior with old tokenize function

1 parent 44fe868 commit 752f7a1

File tree: 1 file changed, +18 -18 lines


comfy/sd1_clip.py

Lines changed: 18 additions & 18 deletions
@@ -226,12 +226,11 @@ def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedd
         self.max_word_length = 8
         self.embedding_identifier = "embedding:"
 
-    def _try_get_embedding(self, name:str):
+    def _try_get_embedding(self, embedding_name:str):
         '''
         Takes a potential embedding name and tries to retrieve it.
         Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
         '''
-        embedding_name = name[len(self.embedding_identifier):].strip('\n')
         embed = load_embed(embedding_name, self.embedding_directory)
         if embed is None:
             stripped = embedding_name.strip(',')
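With this hunk, _try_get_embedding expects the "embedding:" prefix to already be removed; the stripping moves to the caller in the next hunk. A minimal sketch of the new contract, using a made-up embedding name:

    # Illustration only; "myStyle" is a made-up embedding name.
    embedding_identifier = "embedding:"
    word = "embedding:myStyle"

    # The caller now removes the identifier before the lookup:
    embedding_name = word[len(embedding_identifier):].strip('\n')
    assert embedding_name == "myStyle"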
@@ -259,9 +258,10 @@ def tokenize_with_weights(self, text:str):
             for word in to_tokenize:
                 #if we find an embedding, deal with the embedding
                 if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
-                    embed, leftover = self._try_get_embedding(word)
+                    embedding_name = word[len(self.embedding_identifier):].strip('\n')
+                    embed, leftover = self._try_get_embedding(embedding_name)
                     if embed is None:
-                        print(f"warning, embedding:{word} does not exist, ignoring")
+                        print(f"warning, embedding:{embedding_name} does not exist, ignoring")
                     else:
                         if len(embed.shape) == 1:
                             tokens.append([(embed, weight)])
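Besides matching the new _try_get_embedding signature, passing the stripped name also fixes the warning text: word still carries the "embedding:" prefix, so the old f-string printed it twice. A small sketch of the difference, again with a made-up name:

    # Hypothetical name for illustration.
    word = "embedding:myStyle"
    embedding_name = word[len("embedding:"):].strip('\n')

    # old: "warning, embedding:embedding:myStyle does not exist, ignoring"
    print(f"warning, embedding:{word} does not exist, ignoring")
    # new: "warning, embedding:myStyle does not exist, ignoring"
    print(f"warning, embedding:{embedding_name} does not exist, ignoring")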
@@ -280,21 +280,21 @@ def tokenize_with_weights(self, text:str):
         batch = []
         batched_tokens.append(batch)
         for i, t_group in enumerate(tokens):
-            #start a new batch if there is not enough room
-            if len(t_group) + len(batch) > self.max_tokens_per_section:
-                remaining_length = self.max_tokens_per_section - len(batch)
-                #fill remaining space depending on length of tokens
-                if len(t_group) > self.max_word_length:
-                    #put part of group of tokens in the batch
-                    batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])
-                    t_group = t_group[remaining_length:]
+            #determine if we're going to try and keep the tokens in a single batch
+            is_large = len(t_group) >= self.max_word_length
+            while len(t_group) > 0:
+                if len(t_group) + len(batch) > self.max_tokens_per_section:
+                    remaining_length = self.max_tokens_per_section - len(batch)
+                    if is_large:
+                        batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])
+                        t_group = t_group[remaining_length:]
+                    else:
+                        batch.extend([(self.end_token, 1.0, 0)] * remaining_length)
+                    batch = []
+                    batched_tokens.append(batch)
                 else:
-                    #filler tokens
-                    batch.extend([(self.end_token, 1.0, 0)] * remaining_length)
-                    batch = []
-                    batched_tokens.append(batch)
-            #put current group of tokens in the batch
-            batch.extend([(t,w,i+1) for t,w in t_group])
+                    batch.extend([(t,w,i+1) for t,w in t_group])
+                    t_group = []
 
         #fill last batch
         batch.extend([(self.end_token, 1.0, 0)] * (self.max_tokens_per_section - len(batch)))
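The reworked loop keeps consuming each token group until it is fully placed: a group of at least max_word_length tokens may be split across batches, while a shorter group that no longer fits is deferred whole to a fresh batch after the current one is padded with end tokens. The following self-contained sketch mirrors that logic; the constants and the flat integer tokens are made up for illustration and stand in for the tokenizer attributes and (token, weight, word_id) tuples:

    # Stand-ins for self.max_tokens_per_section, self.max_word_length,
    # and self.end_token; values chosen small so the split is visible.
    MAX_TOKENS_PER_SECTION = 5
    MAX_WORD_LENGTH = 3
    END_TOKEN = 0

    def batch_groups(groups):
        batch = []
        batched = [batch]
        for group in groups:
            # groups of at least MAX_WORD_LENGTH tokens may be split across batches
            is_large = len(group) >= MAX_WORD_LENGTH
            while len(group) > 0:
                if len(group) + len(batch) > MAX_TOKENS_PER_SECTION:
                    remaining = MAX_TOKENS_PER_SECTION - len(batch)
                    if is_large:
                        # split the group: fill the batch, keep the rest
                        batch.extend(group[:remaining])
                        group = group[remaining:]
                    else:
                        # defer the whole group: pad the batch with end tokens
                        batch.extend([END_TOKEN] * remaining)
                    # either way, start a fresh batch
                    batch = []
                    batched.append(batch)
                else:
                    batch.extend(group)
                    group = []
        # pad the final batch
        batch.extend([END_TOKEN] * (MAX_TOKENS_PER_SECTION - len(batch)))
        return batched

    print(batch_groups([[1, 2], [3, 4], [5, 6]]))
    # -> [[1, 2, 3, 4, 0], [5, 6, 0, 0, 0]]

The small group [5, 6] is not split: the first batch is padded with an end token and the group moves intact to the next batch, which is the behavior the commit message says it realigns with the old tokenize function.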
