@@ -2603,7 +2603,12 @@ def __call__(
 
         image_urls = self.get_image_urls(messages)
         template = jinja2.Template(self.CHAT_FORMAT)
-        text = template.render(messages=messages, add_generation_prompt=True)
+        text = template.render(
+            messages=messages,
+            add_generation_prompt=True,
+            eos_token=llama.detokenize([llama.token_eos()]),
+            bos_token=llama.detokenize([llama.token_bos()]),
+        )
         split_text = self.split_text_on_image_urls(text, image_urls)
 
         def embed_image_bytes(image_bytes: bytes):
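The render() change matters because the handler's chat template can reference bos_token and eos_token; with a plain jinja2.Template, an undefined variable renders as an empty string, so those special tokens would silently drop out of the prompt. A minimal sketch of that effect, using a toy template rather than the handler's CHAT_FORMAT:

import jinja2

# Toy template (hypothetical, only to illustrate the undefined-variable behaviour).
toy = jinja2.Template(
    "{{ bos_token }}{% for m in messages %}"
    "{{ m['role'] }}: {{ m['content'] }}{{ eos_token }}\n"
    "{% endfor %}"
)
messages = [{"role": "user", "content": "hi"}]

# Without the tokens, jinja2's default Undefined renders them as empty strings.
print(repr(toy.render(messages=messages)))
# With the tokens passed explicitly, they appear in the rendered prompt.
print(repr(toy.render(messages=messages, bos_token="<s>", eos_token="</s>")))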
@@ -2624,9 +2629,9 @@ def embed_image_bytes(image_bytes: bytes):
 
         # Evaluate prompt
         llama.reset()
-        for i, (type_, value) in enumerate(split_text):
+        for type_, value in split_text:
             if type_ == "text":
-                tokens = llama.tokenize(value.encode("utf8"), add_bos=i == 0)
+                tokens = llama.tokenize(value.encode("utf8"), add_bos=False, special=True)
                 if llama.n_tokens + len(tokens) > llama.n_ctx():
                     raise ValueError("Prompt exceeds n_ctx")  # TODO: Fix
                 llama.eval(tokens)
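Because the rendered text now carries the special tokens itself (the template emits the BOS token, and subclass templates use markers such as <|im_start|>/<|im_end|>), the tokenize call switches to special=True so those strings are encoded as single special-token ids instead of literal text, and add_bos=False so the BOS is not inserted a second time. A rough sketch of the difference, assuming a local GGUF model at ./model.gguf (path is illustrative):

from llama_cpp import Llama

# vocab_only loads just the tokenizer, which is enough for this comparison.
llama = Llama(model_path="./model.gguf", vocab_only=True)

text = llama.detokenize([llama.token_bos()]) + b"Hello"

# special=False: the BOS string is treated as ordinary text and split into pieces.
print(llama.tokenize(text, add_bos=False, special=False))
# special=True: the BOS string maps back to the single BOS token id.
print(llama.tokenize(text, add_bos=False, special=True))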
@@ -2644,6 +2649,8 @@ def embed_image_bytes(image_bytes: bytes):
                     llama.n_batch,
                     n_past_p,
                 )
+                # Required to avoid issues with hf tokenizer
+                llama.input_ids[llama.n_tokens : n_past.value] = -1
                 llama.n_tokens = n_past.value
 
         # Get prompt tokens to avoid a cache miss
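llava_eval_image_embed advances the KV cache by the image's embedding positions, but nothing is written into llama.input_ids for that span, so the new lines fill it with -1. That keeps stale buffer contents from being mistaken for real token ids when the prompt prefix is later compared or detokenized (the in-diff comment points at hf tokenizer issues). A standalone numpy sketch of the sentinel fill (illustrative values, not the library's internals):

import numpy as np

n_ctx = 16
input_ids = np.zeros(n_ctx, dtype=np.intc)   # per-position token ids, like llama.input_ids

n_tokens = 4                                  # text tokens evaluated so far
input_ids[:n_tokens] = [1, 315, 9047, 13]     # arbitrary example ids

n_image_pos = 6                               # positions occupied by the image embedding
n_past = n_tokens + n_image_pos

# Mark the image-occupied span so no stale token id can alias a real one.
input_ids[n_tokens:n_past] = -1
n_tokens = n_past

print(input_ids[:n_tokens])                   # [1 315 9047 13 -1 -1 -1 -1 -1 -1]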
@@ -3033,6 +3040,7 @@ class NanoLlavaChatHandler(Llava15ChatHandler):
     # Answer the question<|im_end|><|im_start|>user
     # <image>
     # What is the picture about?<|im_end|><|im_start|>assistant
+    DEFAULT_SYSTEM_MESSAGE = "Answer the question"
 
     CHAT_FORMAT = (
         "{% for message in messages %}"