@@ -226,12 +226,11 @@ def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedd
         self.max_word_length = 8
         self.embedding_identifier = "embedding:"
 
-    def _try_get_embedding(self, name:str):
+    def _try_get_embedding(self, embedding_name:str):
         '''
         Takes a potential embedding name and tries to retrieve it.
         Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
         '''
-        embedding_name = name[len(self.embedding_identifier):].strip('\n')
         embed = load_embed(embedding_name, self.embedding_directory)
         if embed is None:
             stripped = embedding_name.strip(',')
@@ -259,9 +258,10 @@ def tokenize_with_weights(self, text:str):
             for word in to_tokenize:
                 #if we find an embedding, deal with the embedding
                 if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
-                    embed, leftover = self._try_get_embedding(word)
+                    embedding_name = word[len(self.embedding_identifier):].strip('\n')
+                    embed, leftover = self._try_get_embedding(embedding_name)
                     if embed is None:
-                        print(f"warning, embedding:{word} does not exist, ignoring")
+                        print(f"warning, embedding:{embedding_name} does not exist, ignoring")
                     else:
                         if len(embed.shape) == 1:
                             tokens.append([(embed, weight)])
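For context beyond the hunk, here is a minimal standalone sketch of the new calling contract: `tokenize_with_weights` now strips the `embedding:` prefix itself and hands the bare name to `_try_get_embedding`, which returns an `(embed, leftover)` pair. The stub loader and the exact leftover computation below are illustrative assumptions, not the repository's code.

```python
from typing import Optional, Tuple

EMBEDDING_IDENTIFIER = "embedding:"

def fake_load_embed(name: str) -> Optional[list]:
    # hypothetical stand-in for load_embed(): pretend only "mystyle" exists on disk
    return [0.1, 0.2, 0.3] if name == "mystyle" else None

def try_get_embedding(embedding_name: str) -> Tuple[Optional[list], str]:
    # the name arrives already stripped of its "embedding:" prefix
    embed = fake_load_embed(embedding_name)
    if embed is None:
        # retry without trailing commas; how the leftover text is computed
        # here is an assumption for the example
        stripped = embedding_name.strip(',')
        embed = fake_load_embed(stripped)
        return embed, embedding_name[len(stripped):]
    return embed, ""

word = "embedding:mystyle,"
name = word[len(EMBEDDING_IDENTIFIER):].strip('\n')  # caller strips the prefix
embed, leftover = try_get_embedding(name)
print(embed is not None, repr(leftover))             # True ','
```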
@@ -280,21 +280,21 @@ def tokenize_with_weights(self, text:str):
         batch = []
         batched_tokens.append(batch)
         for i, t_group in enumerate(tokens):
-            #start a new batch if there is not enough room
-            if len(t_group) + len(batch) > self.max_tokens_per_section:
-                remaining_length = self.max_tokens_per_section - len(batch)
-                #fill remaining space depending on length of tokens
-                if len(t_group) > self.max_word_length:
-                    #put part of group of tokens in the batch
-                    batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])
-                    t_group = t_group[remaining_length:]
+            #determine if we're going to try and keep the tokens in a single batch
+            is_large = len(t_group) >= self.max_word_length
+            while len(t_group) > 0:
+                if len(t_group) + len(batch) > self.max_tokens_per_section:
+                    remaining_length = self.max_tokens_per_section - len(batch)
+                    if is_large:
+                        batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])
+                        t_group = t_group[remaining_length:]
+                    else:
+                        batch.extend([(self.end_token, 1.0, 0)] * remaining_length)
+                    batch = []
+                    batched_tokens.append(batch)
                 else:
-                    #filler tokens
-                    batch.extend([(self.end_token, 1.0, 0)] * remaining_length)
-                batch = []
-                batched_tokens.append(batch)
-            #put current group of tokens in the batch
-            batch.extend([(t,w,i+1) for t,w in t_group])
+                    batch.extend([(t,w,i+1) for t,w in t_group])
+                    t_group = []
 
         #fill last batch
         batch.extend([(self.end_token, 1.0, 0)] * (self.max_tokens_per_section - len(batch)))
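To see the rewritten packing loop in isolation, the sketch below packs toy `(token, weight)` groups into fixed-size sections the same way: a group at or above the word-length threshold may be split across sections, while a shorter group that does not fit forces padding and a fresh section so it stays whole. The constants and the `pack_sections` name are assumptions for the example, not values taken from the diff.

```python
MAX_TOKENS_PER_SECTION = 75  # assumed section size for the example
MAX_WORD_LENGTH = 8          # assumed threshold, mirrors max_word_length
END_TOKEN = 49407            # assumed end-of-text token id

def pack_sections(token_groups):
    batched_tokens = []
    batch = []
    batched_tokens.append(batch)
    for i, t_group in enumerate(token_groups):
        # keep short groups whole; allow long ones to be split across sections
        is_large = len(t_group) >= MAX_WORD_LENGTH
        while len(t_group) > 0:
            if len(t_group) + len(batch) > MAX_TOKENS_PER_SECTION:
                remaining_length = MAX_TOKENS_PER_SECTION - len(batch)
                if is_large:
                    # fill the current section with the head of the group
                    batch.extend([(t, w, i + 1) for t, w in t_group[:remaining_length]])
                    t_group = t_group[remaining_length:]
                else:
                    # pad the current section so the short group stays whole
                    batch.extend([(END_TOKEN, 1.0, 0)] * remaining_length)
                # start a new section either way
                batch = []
                batched_tokens.append(batch)
            else:
                batch.extend([(t, w, i + 1) for t, w in t_group])
                t_group = []
    # fill the last section
    batch.extend([(END_TOKEN, 1.0, 0)] * (MAX_TOKENS_PER_SECTION - len(batch)))
    return batched_tokens

groups = [[(100 + n, 1.0) for n in range(80)],  # large group: split across sections
          [(200 + n, 1.0) for n in range(68)],  # large group: fits after the split remainder
          [(300 + n, 1.0) for n in range(6)]]   # short group: forces padding + new section
sections = pack_sections(groups)
print([len(s) for s in sections])               # [75, 75, 75]
```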