
Commit d0b1b6c

fixed improper padding
1 parent da115bd commit d0b1b6c

File tree

1 file changed: +17 -9 lines changed

comfy/sd1_clip.py

Lines changed: 17 additions & 9 deletions
@@ -247,6 +247,11 @@ def tokenize_with_weights(self, text:str, return_word_ids=False):
         Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens.
         Returned list has the dimensions NxM where M is the input size of CLIP
         '''
+        if self.pad_with_end:
+            pad_token = self.end_token
+        else:
+            pad_token = 0
+
         text = escape_important(text)
         parsed_weights = token_weights(text, 1.0)
 
@@ -277,30 +282,33 @@ def tokenize_with_weights(self, text:str, return_word_ids=False):
 
         #reshape token array to CLIP input size
         batched_tokens = []
-        batch = []
+        batch = [(self.start_token, 1.0, 0)]
         batched_tokens.append(batch)
         for i, t_group in enumerate(tokens):
             #determine if we're going to try and keep the tokens in a single batch
             is_large = len(t_group) >= self.max_word_length
+
             while len(t_group) > 0:
-                if len(t_group) + len(batch) > self.max_tokens_per_section:
-                    remaining_length = self.max_tokens_per_section - len(batch)
+                if len(t_group) + len(batch) > self.max_length - 1:
+                    remaining_length = self.max_length - len(batch) - 1
+                    #break word in two and add end token
                     if is_large:
                         batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])
+                        batch.append((self.end_token, 1.0, 0))
                         t_group = t_group[remaining_length:]
+                    #add end token and pad
                     else:
-                        batch.extend([(self.end_token, 1.0, 0)] * remaining_length)
-                    batch = []
+                        batch.append((self.end_token, 1.0, 0))
+                        batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
+                    #start new batch
+                    batch = [(self.start_token, 1.0, 0)]
                     batched_tokens.append(batch)
                 else:
                     batch.extend([(t,w,i+1) for t,w in t_group])
                     t_group = []
 
         #fill last batch
-        batch.extend([(self.end_token, 1.0, 0)] * (self.max_tokens_per_section - len(batch)))
-
-        #add start and end tokens
-        batched_tokens = [[(self.start_token, 1.0, 0)] + x + [(self.end_token, 1.0, 0)] for x in batched_tokens]
+        batch.extend([(self.end_token, 1.0, 0)] + [(pad_token, 1.0, 0)] * (self.max_length - len(batch) - 1))
 
         if not return_word_ids:
             batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
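
For context, here is a minimal sketch of the section layout this commit produces: a start token, the prompt tokens, a single end token, and then padding with either the end token (pad_with_end) or 0. The helper below is illustrative only, assumes the standard CLIP ViT-L ids (start 49406, end 49407, length 77), and drops the (token, weight, word_id) tuples that sd1_clip.py actually builds.

# Illustrative sketch only (not the repository's API): shows the padded section
# layout after this commit, assuming the standard CLIP ViT-L token ids.
START_TOKEN = 49406   # assumed start-of-text id
END_TOKEN = 49407     # assumed end-of-text id
MAX_LENGTH = 77       # assumed CLIP context length

def pad_section(token_ids, pad_with_end=True):
    # pick the pad token the same way the diff does
    pad_token = END_TOKEN if pad_with_end else 0
    # one start token, the prompt tokens, then exactly one end token
    section = [START_TOKEN] + list(token_ids) + [END_TOKEN]
    # fill the remainder of the section with the chosen pad token
    section += [pad_token] * (MAX_LENGTH - len(section))
    return section

section = pad_section([320, 1125, 539])          # hypothetical token ids
assert len(section) == MAX_LENGTH
assert section[4] == END_TOKEN                   # end token directly after the content
assert all(t == END_TOKEN for t in section[5:])  # padding uses the end token here

Before this change each section was filled with repeated end tokens and the start/end pair was only wrapped around the section afterwards; the commit instead places exactly one end token after the content and pads with the configurable pad token, which is the improper padding the commit message refers to.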
