@@ -466,7 +466,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
466466 return embed_out
467467
468468class SDTokenizer :
469- def __init__ (self , tokenizer_path = None , max_length = 77 , pad_with_end = True , embedding_directory = None , embedding_size = 768 , embedding_key = 'clip_l' , tokenizer_class = CLIPTokenizer , has_start_token = True , has_end_token = True , pad_to_max_length = True , min_length = None , pad_token = None , end_token = None , min_padding = None , pad_left = False , disable_weights = False , tokenizer_data = {}, tokenizer_args = {}):
469+ def __init__ (self , tokenizer_path = None , max_length = 77 , pad_with_end = True , embedding_directory = None , embedding_size = 768 , embedding_key = 'clip_l' , tokenizer_class = CLIPTokenizer , has_start_token = True , has_end_token = True , pad_to_max_length = True , min_length = None , pad_token = None , end_token = None , start_token = None , min_padding = None , pad_left = False , disable_weights = False , tokenizer_data = {}, tokenizer_args = {}):
470470 if tokenizer_path is None :
471471 tokenizer_path = os .path .join (os .path .dirname (os .path .realpath (__file__ )), "sd1_tokenizer" )
472472 self .tokenizer = tokenizer_class .from_pretrained (tokenizer_path , ** tokenizer_args )
@@ -479,16 +479,23 @@ def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedd
479479 empty = self .tokenizer ('' )["input_ids" ]
480480 self .tokenizer_adds_end_token = has_end_token
481481 if has_start_token :
482- self .tokens_start = 1
483- self .start_token = empty [0 ]
482+ if len (empty ) > 0 :
483+ self .tokens_start = 1
484+ self .start_token = empty [0 ]
485+ else :
486+ self .tokens_start = 0
487+ self .start_token = start_token
488+ if start_token is None :
489+                 logging .warning ("WARNING: There's something wrong with your tokenizers." )
490+
484491 if end_token is not None :
485492 self .end_token = end_token
486493 else :
487494 if has_end_token :
488495 self .end_token = empty [1 ]
489496 else :
490497 self .tokens_start = 0
491- self .start_token = None
498+ self .start_token = start_token
492499 if end_token is not None :
493500 self .end_token = end_token
494501 else :
0 commit comments