@@ -224,60 +224,85 @@ def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedd
         self.inv_vocab = {v: k for k, v in vocab.items()}
         self.embedding_directory = embedding_directory
         self.max_word_length = 8
-
-    def tokenize_with_weights(self, text):
+        self.embedding_identifier = "embedding:"
+
+    def _try_get_embedding(self, name:str):
+        '''
+        Takes a potential embedding name and tries to retrieve it.
+        Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
+        '''
+        embedding_name = name[len(self.embedding_identifier):].strip('\n')
+        embed = load_embed(embedding_name, self.embedding_directory)
+        if embed is None:
+            stripped = embedding_name.strip(',')
+            if len(stripped) < len(embedding_name):
+                embed = load_embed(stripped, self.embedding_directory)
+                return (embed, embedding_name[len(stripped):])
+        return (embed, "")
+
+
+    def tokenize_with_weights(self, text:str):
+        '''
+        Takes a prompt and converts it to a list of (token, weight, word id) elements.
+        Tokens can be both integer tokens and pre-computed CLIP tensors.
+        Word id values are unique per word and embedding, where id 0 is reserved for non-word tokens.
+        The returned list has dimensions NxM, where M is the input size of CLIP.
+        '''
         text = escape_important(text)
         parsed_weights = token_weights(text, 1.0)
 
+        #tokenize words
         tokens = []
-        for t in parsed_weights:
-            to_tokenize = unescape_important(t[0]).replace("\n", " ").split(' ')
-            while len(to_tokenize) > 0:
-                word = to_tokenize.pop(0)
-                temp_tokens = []
-                embedding_identifier = "embedding:"
-                if word.startswith(embedding_identifier) and self.embedding_directory is not None:
-                    embedding_name = word[len(embedding_identifier):].strip('\n')
-                    embed = load_embed(embedding_name, self.embedding_directory)
+        for weighted_segment, weight in parsed_weights:
+            to_tokenize = unescape_important(weighted_segment).replace("\n", " ").split(' ')
+            to_tokenize = [x for x in to_tokenize if x != ""]
+            for word in to_tokenize:
+                #if we find an embedding, deal with the embedding
+                if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
+                    embed, leftover = self._try_get_embedding(word)
                     if embed is None:
-                        stripped = embedding_name.strip(',')
-                        if len(stripped) < len(embedding_name):
-                            embed = load_embed(stripped, self.embedding_directory)
-                            if embed is not None:
-                                to_tokenize.insert(0, embedding_name[len(stripped):])
-
-                    if embed is not None:
+                        print(f"warning, embedding:{word} does not exist, ignoring")
+                    else:
                         if len(embed.shape) == 1:
-                            temp_tokens += [(embed, t[1])]
+                            tokens.append([(embed, weight)])
                         else:
-                            for x in range(embed.shape[0]):
-                                temp_tokens += [(embed[x], t[1])]
+                            tokens.append([(embed[x], weight) for x in range(embed.shape[0])])
+                    #if we accidentally have leftover text, continue parsing using leftover, else move on to next word
+                    if leftover != "":
+                        word = leftover
                     else:
-                        print("warning, embedding:{} does not exist, ignoring".format(embedding_name))
-                elif len(word) > 0:
-                    tt = self.tokenizer(word)["input_ids"][1:-1]
-                    for x in tt:
-                        temp_tokens += [(x, t[1])]
-                tokens_left = self.max_tokens_per_section - (len(tokens) % self.max_tokens_per_section)
-
-                #try not to split words in different sections
-                if tokens_left < len(temp_tokens) and len(temp_tokens) < (self.max_word_length):
-                    for x in range(tokens_left):
-                        tokens += [(self.end_token, 1.0)]
-                tokens += temp_tokens
-
-        out_tokens = []
-        for x in range(0, len(tokens), self.max_tokens_per_section):
-            o_token = [(self.start_token, 1.0)] + tokens[x:min(self.max_tokens_per_section + x, len(tokens))]
-            o_token += [(self.end_token, 1.0)]
-            if self.pad_with_end:
-                o_token += [(self.end_token, 1.0)] * (self.max_length - len(o_token))
-            else:
-                o_token += [(0, 1.0)] * (self.max_length - len(o_token))
-
-            out_tokens += [o_token]
+                        continue
+                #parse word
+                tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][1:-1]])
+
+        #reshape token array to CLIP input size
+        batched_tokens = []
+        batch = []
+        batched_tokens.append(batch)
+        for i, t_group in enumerate(tokens):
+            #start a new batch if there is not enough room
+            if len(t_group) + len(batch) > self.max_tokens_per_section:
+                remaining_length = self.max_tokens_per_section - len(batch)
+                #fill remaining space depending on length of tokens
+                if len(t_group) > self.max_word_length:
+                    #put part of group of tokens in the batch
+                    batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])
+                    t_group = t_group[remaining_length:]
+                else:
+                    #filler tokens
+                    batch.extend([(self.end_token, 1.0, 0)] * remaining_length)
+                batch = []
+                batched_tokens.append(batch)
+            #put current group of tokens in the batch
+            batch.extend([(t,w,i+1) for t,w in t_group])
+
+        #fill last batch
+        batch.extend([(self.end_token, 1.0, 0)] * (self.max_tokens_per_section - len(batch)))
+
+        #add start and end tokens
+        batched_tokens = [[(self.start_token, 1.0, 0)] + x + [(self.end_token, 1.0, 0)] for x in batched_tokens]
+        return batched_tokens
 
-        return out_tokens
 
     def untokenize(self, token_weight_pair):
         return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
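
For orientation, here is a minimal usage sketch of the new return format. It is not part of the commit: the `tok` instance name, the prompt text, and the embedding name are made up for illustration.

# Illustrative only: `tok` is assumed to be an instance of the tokenizer class patched above.
batched = tok.tokenize_with_weights("a photo of embedding:my_style, on a beach")
for batch in batched:
    # Each batch begins with the start token and is padded out with end tokens,
    # all carrying word id 0. Real words and embeddings get ids >= 1, one id per
    # word, so tokens split from the same word can be regrouped downstream.
    for token, weight, word_id in batch:
        print(token, weight, word_id)

Each inner list is one CLIP-sized section, so prompts longer than a single section come back as multiple batches, matching the NxM shape described in the docstring.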