@@ -260,60 +260,97 @@ def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedd
         self.inv_vocab = {v: k for k, v in vocab.items()}
         self.embedding_directory = embedding_directory
         self.max_word_length = 8
+        self.embedding_identifier = "embedding:"
+
+    def _try_get_embedding(self, embedding_name:str):
+        '''
+        Takes a potential embedding name and tries to retrieve it.
+        Returns a tuple of the embedding and any leftover string; the embedding can be None.
+        '''
+        embed = load_embed(embedding_name, self.embedding_directory)
+        if embed is None:
+            stripped = embedding_name.strip(',')
+            if len(stripped) < len(embedding_name):
+                embed = load_embed(stripped, self.embedding_directory)
+                return (embed, embedding_name[len(stripped):])
+        return (embed, "")
+
+
+    def tokenize_with_weights(self, text:str, return_word_ids=False):
+        '''
+        Takes a prompt and converts it to a list of (token, weight, word id) elements.
+        Tokens can be integer tokens or pre-computed CLIP tensors.
+        Word id values are unique per word and embedding, with id 0 reserved for non-word tokens.
+        The returned list has dimensions NxM, where M is the input size of CLIP.
+        '''
+        if self.pad_with_end:
+            pad_token = self.end_token
+        else:
+            pad_token = 0
 
-    def tokenize_with_weights(self, text):
         text = escape_important(text)
         parsed_weights = token_weights(text, 1.0)
 
+        #tokenize words
         tokens = []
-        for t in parsed_weights:
-            to_tokenize = unescape_important(t[0]).replace("\n", " ").split(' ')
-            while len(to_tokenize) > 0:
-                word = to_tokenize.pop(0)
-                temp_tokens = []
-                embedding_identifier = "embedding:"
-                if word.startswith(embedding_identifier) and self.embedding_directory is not None:
-                    embedding_name = word[len(embedding_identifier):].strip('\n')
-                    embed = load_embed(embedding_name, self.embedding_directory)
+        for weighted_segment, weight in parsed_weights:
+            to_tokenize = unescape_important(weighted_segment).replace("\n", " ").split(' ')
+            to_tokenize = [x for x in to_tokenize if x != ""]
+            for word in to_tokenize:
+                #if we find an embedding, deal with the embedding
+                if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
+                    embedding_name = word[len(self.embedding_identifier):].strip('\n')
+                    embed, leftover = self._try_get_embedding(embedding_name)
                     if embed is None:
-                        stripped = embedding_name.strip(',')
-                        if len(stripped) < len(embedding_name):
-                            embed = load_embed(stripped, self.embedding_directory)
-                            if embed is not None:
-                                to_tokenize.insert(0, embedding_name[len(stripped):])
-
-                    if embed is not None:
+                        print(f"warning, embedding:{embedding_name} does not exist, ignoring")
+                    else:
                         if len(embed.shape) == 1:
-                            temp_tokens += [(embed, t[1])]
+                            tokens.append([(embed, weight)])
                         else:
-                            for x in range(embed.shape[0]):
-                                temp_tokens += [(embed[x], t[1])]
+                            tokens.append([(embed[x], weight) for x in range(embed.shape[0])])
+                    #if we accidentally have leftover text, continue parsing using leftover, else move on to next word
+                    if leftover != "":
+                        word = leftover
+                    else:
+                        continue
+                #parse word
+                tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][1:-1]])
+
+        #reshape token array to CLIP input size
+        batched_tokens = []
+        batch = [(self.start_token, 1.0, 0)]
+        batched_tokens.append(batch)
+        for i, t_group in enumerate(tokens):
+            #determine if we're going to try and keep the tokens in a single batch
+            is_large = len(t_group) >= self.max_word_length
+
+            while len(t_group) > 0:
+                if len(t_group) + len(batch) > self.max_length - 1:
+                    remaining_length = self.max_length - len(batch) - 1
+                    #break word in two and add end token
+                    if is_large:
+                        batch.extend([(t, w, i + 1) for t, w in t_group[:remaining_length]])
+                        batch.append((self.end_token, 1.0, 0))
+                        t_group = t_group[remaining_length:]
+                    #add end token and pad
                     else:
-                        print("warning, embedding:{} does not exist, ignoring".format(embedding_name))
-                elif len(word) > 0:
-                    tt = self.tokenizer(word)["input_ids"][1:-1]
-                    for x in tt:
-                        temp_tokens += [(x, t[1])]
-                tokens_left = self.max_tokens_per_section - (len(tokens) % self.max_tokens_per_section)
-
-                #try not to split words in different sections
-                if tokens_left < len(temp_tokens) and len(temp_tokens) < (self.max_word_length):
-                    for x in range(tokens_left):
-                        tokens += [(self.end_token, 1.0)]
-                tokens += temp_tokens
+                        batch.append((self.end_token, 1.0, 0))
+                        batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
+                    #start new batch
+                    batch = [(self.start_token, 1.0, 0)]
+                    batched_tokens.append(batch)
+                else:
+                    batch.extend([(t, w, i + 1) for t, w in t_group])
+                    t_group = []
+
+        #fill last batch
+        batch.extend([(self.end_token, 1.0, 0)] + [(pad_token, 1.0, 0)] * (self.max_length - len(batch) - 1))
 
-        out_tokens = []
-        for x in range(0, len(tokens), self.max_tokens_per_section):
-            o_token = [(self.start_token, 1.0)] + tokens[x:min(self.max_tokens_per_section + x, len(tokens))]
-            o_token += [(self.end_token, 1.0)]
-            if self.pad_with_end:
-                o_token += [(self.end_token, 1.0)] * (self.max_length - len(o_token))
-            else:
-                o_token += [(0, 1.0)] * (self.max_length - len(o_token))
+        if not return_word_ids:
+            batched_tokens = [[(t, w) for t, w, _ in x] for x in batched_tokens]
 
-            out_tokens += [o_token]
+        return batched_tokens
 
-        return out_tokens
 
     def untokenize(self, token_weight_pair):
         return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
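
For readers skimming the hunk, the packing step at the end of the new tokenize_with_weights can be summarized with a short self-contained sketch. This is illustrative code only, not part of the commit: pack_word_groups, its parameter names, and the demo token ids are invented, and embedding/weight parsing is left out.

# Illustrative sketch (not from this commit): per-word token groups are packed into
# fixed-size batches with start/end/pad tokens and 1-based word ids (0 marks special
# and padding tokens). Words with at least max_word_length tokens may be split across
# batches; shorter words are carried whole into the next batch.
def pack_word_groups(word_groups, start_token, end_token, pad_token,
                     max_length=77, max_word_length=8):
    batched = []
    batch = [(start_token, 1.0, 0)]
    batched.append(batch)
    for i, group in enumerate(word_groups):
        is_large = len(group) >= max_word_length
        while len(group) > 0:
            if len(group) + len(batch) > max_length - 1:
                remaining = max_length - len(batch) - 1
                if is_large:
                    # split the oversized word and close this batch
                    batch.extend([(t, w, i + 1) for t, w in group[:remaining]])
                    batch.append((end_token, 1.0, 0))
                    group = group[remaining:]
                else:
                    # keep the word whole: close and pad this batch, retry it in the next one
                    batch.append((end_token, 1.0, 0))
                    batch.extend([(pad_token, 1.0, 0)] * remaining)
                batch = [(start_token, 1.0, 0)]
                batched.append(batch)
            else:
                batch.extend([(t, w, i + 1) for t, w in group])
                group = []
    # close and pad the final batch
    batch.extend([(end_token, 1.0, 0)] + [(pad_token, 1.0, 0)] * (max_length - len(batch) - 1))
    return batched

# Tiny demo with arbitrary example token ids: a one-token word and a two-token weighted word.
groups = [[(320, 1.0)], [(1125, 1.2), (7890, 1.2)]]
batches = pack_word_groups(groups, start_token=49406, end_token=49407, pad_token=49407, max_length=10)
print(batches[0])
# [(49406, 1.0, 0), (320, 1.0, 1), (1125, 1.2, 2), (7890, 1.2, 2), (49407, 1.0, 0),
#  (49407, 1.0, 0), (49407, 1.0, 0), (49407, 1.0, 0), (49407, 1.0, 0), (49407, 1.0, 0)]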
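
A hedged example of calling the new interface (the instance name, prompt, and embedding file below are hypothetical, not taken from the commit):

# Assumes an already-constructed instance of the tokenizer class this diff modifies,
# here called clip_tokenizer, with an embedding file "my_embed" in its embedding directory.
batches = clip_tokenizer.tokenize_with_weights("a photo of embedding:my_embed", return_word_ids=True)
for batch in batches:
    assert len(batch) == clip_tokenizer.max_length   # every batch is one full CLIP input
    token, weight, word_id = batch[1]                # first entry after the start token
    # token is an int for ordinary words, or a tensor row when it came from an embedding

With return_word_ids=False (the default) the same call returns plain (token, weight) pairs per batch, which keeps the return shape closer to what callers of the old method expected.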