@@ -247,6 +247,11 @@ def tokenize_with_weights(self, text:str, return_word_ids=False):
         Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens.
         Returned list has the dimensions NxM where M is the input size of CLIP
         '''
+        if self.pad_with_end:
+            pad_token = self.end_token
+        else:
+            pad_token = 0
+
         text = escape_important(text)
         parsed_weights = token_weights(text, 1.0)
 
@@ -277,30 +282,33 @@ def tokenize_with_weights(self, text:str, return_word_ids=False):
 
         #reshape token array to CLIP input size
         batched_tokens = []
-        batch = []
+        batch = [(self.start_token, 1.0, 0)]
         batched_tokens.append(batch)
         for i, t_group in enumerate(tokens):
             #determine if we're going to try and keep the tokens in a single batch
             is_large = len(t_group) >= self.max_word_length
+
             while len(t_group) > 0:
-                if len(t_group) + len(batch) > self.max_tokens_per_section:
-                    remaining_length = self.max_tokens_per_section - len(batch)
+                if len(t_group) + len(batch) > self.max_length - 1:
+                    remaining_length = self.max_length - len(batch) - 1
+                    #break word in two and add end token
                     if is_large:
                         batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])
+                        batch.append((self.end_token, 1.0, 0))
                         t_group = t_group[remaining_length:]
+                    #add end token and pad
                     else:
-                        batch.extend([(self.end_token, 1.0, 0)] * remaining_length)
-                    batch = []
+                        batch.append((self.end_token, 1.0, 0))
+                        batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
+                    #start new batch
+                    batch = [(self.start_token, 1.0, 0)]
                     batched_tokens.append(batch)
                 else:
                     batch.extend([(t,w,i+1) for t,w in t_group])
                     t_group = []
 
         #fill last batch
-        batch.extend([(self.end_token, 1.0, 0)] * (self.max_tokens_per_section - len(batch)))
-
-        #add start and end tokens
-        batched_tokens = [[(self.start_token, 1.0, 0)] + x + [(self.end_token, 1.0, 0)] for x in batched_tokens]
+        batch.extend([(self.end_token, 1.0, 0)] + [(pad_token, 1.0, 0)] * (self.max_length - len(batch) - 1))
 
         if not return_word_ids:
             batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
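
To illustrate the batching behavior this commit introduces, here is a minimal standalone sketch of the same logic. The constants are assumptions standing in for `self.start_token`, `self.end_token`, `self.max_length`, and `self.max_word_length` (CLIP ViT-L-style values are used for illustration), and `pad_with_end` mirrors the `self.pad_with_end` flag the diff reads; the real code lives inside the tokenizer class.

```python
# Minimal sketch of the new batching logic, assuming CLIP-style token ids and a
# 77-token context. Not the actual class method; constants replace self.* attributes.
START_TOKEN = 49406    # assumed start-of-text id
END_TOKEN = 49407      # assumed end-of-text id
MAX_LENGTH = 77        # assumed CLIP context length
MAX_WORD_LENGTH = 8    # assumed threshold for splitting a long word across batches

def batch_token_groups(tokens, pad_with_end=True):
    """tokens: list of per-word lists of (token, weight) pairs."""
    # pad with the end token, or with id 0, depending on the flag
    pad_token = END_TOKEN if pad_with_end else 0

    batched_tokens = []
    batch = [(START_TOKEN, 1.0, 0)]          # every batch now starts with the start token
    batched_tokens.append(batch)
    for i, t_group in enumerate(tokens):
        is_large = len(t_group) >= MAX_WORD_LENGTH
        while len(t_group) > 0:
            # reserve one slot in each batch for the end token
            if len(t_group) + len(batch) > MAX_LENGTH - 1:
                remaining_length = MAX_LENGTH - len(batch) - 1
                if is_large:
                    # break the long word in two and close the batch with the end token
                    batch.extend([(t, w, i + 1) for t, w in t_group[:remaining_length]])
                    batch.append((END_TOKEN, 1.0, 0))
                    t_group = t_group[remaining_length:]
                else:
                    # keep the word whole: add the end token and pad the leftover slots
                    batch.append((END_TOKEN, 1.0, 0))
                    batch.extend([(pad_token, 1.0, 0)] * remaining_length)
                # start a new batch
                batch = [(START_TOKEN, 1.0, 0)]
                batched_tokens.append(batch)
            else:
                batch.extend([(t, w, i + 1) for t, w in t_group])
                t_group = []

    # fill the last batch: one end token, then pad tokens up to MAX_LENGTH
    batch.extend([(END_TOKEN, 1.0, 0)] + [(pad_token, 1.0, 0)] * (MAX_LENGTH - len(batch) - 1))
    return batched_tokens
```

For a short prompt this produces a single 77-entry batch: a start token, the word tokens, one end token, and then padding (more end tokens when `pad_with_end` is true, otherwise id 0), which is the change in padding behavior the diff makes configurable.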