66import logging
77from functools import cached_property
88from typing import (
9+ TYPE_CHECKING ,
910 Any ,
1011 Awaitable ,
1112 Callable ,
3031 pass
3132
3233
34+ import base64
3335from importlib .util import find_spec
36+ from io import BytesIO
3437
3538from lm_eval import utils
3639from lm_eval .api .instance import Instance
3740from lm_eval .api .model import TemplateLM
3841from lm_eval .models .utils import Collator , chunks , configure_pad_token
3942
4043
44+ if TYPE_CHECKING :
45+ from PIL import Image
46+
47+
4148eval_logger = logging .getLogger (__name__ )
4249
4350LogLikelihoodInputs = Tuple [Tuple [str , str ], List [int ], List [int ]]
@@ -51,7 +58,52 @@ def encode(self, encoding):
5158 return self .prompt .encode (encoding )
5259
5360
def create_image_prompt(
    imgs: list["Image.Image"], chat: list[dict], fmt: str = "PNG"
) -> list[dict]:
    """Prepend base64-encoded images to the content of the last chat message.

    Each image is serialized with ``img.save(buf, format=fmt)``, base64-encoded
    and wrapped in an ``{"type": "image_url", ...}`` entry whose URL is a
    ``data:`` URI. The entries are placed in front of the existing content of
    ``chat[-1]``.

    Parameters
    ----------
    imgs : list[PIL.Image.Image]
        The images to encode to base64. Only their ``save(fp, format=...)``
        method is used here.
    chat : list[dict]
        The chat history as a list of message dicts. The last message must
        have a ``"content"`` key holding either a string or a list of content
        parts. The list is modified in place.
    fmt : str, optional
        Any format Pillow understands (e.g. "PNG", "JPEG"). It is used both
        for encoding and, lower-cased, as the ``image/<fmt>`` media type of
        the data URI. Defaults to "PNG".

    Returns
    -------
    list[dict]
        The same ``chat`` object, with ``chat[-1]["content"]`` replaced by a
        list (images first, then the original content) and the ``"type"`` key
        removed from the last message if it was present.

    Raises
    ------
    IndexError
        If ``chat`` is empty.
    KeyError
        If the last message has no ``"content"`` key.
    """
    # The media type in the data URI has to describe the bytes actually
    # written by ``img.save``; a fixed "png" would mislabel e.g. JPEG data.
    mime_subtype = fmt.lower()

    images = []
    for img in imgs:
        buf = BytesIO()
        img.save(buf, format=fmt)
        img_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
        images.append(
            {
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/{mime_subtype};base64,{img_b64}",
                    "detail": "auto",
                },
            }
        )

    # chat is in format of list[dict["role": "user"/"system", "content": str, "type": "text"],...]
    # with images, we need "content" to be a list of dicts with "type" and "text"/"image_url"
    # currently we do not support few-shots so only one user message
    # NOTE(review): the text content may still contain <image> placeholders;
    # whether the API needs them is unconfirmed.
    last_message = chat[-1]
    content = last_message["content"]
    if isinstance(content, list):
        # Already a list of content parts: just put the images in front.
        last_message["content"] = images + content
    else:
        # Plain string content: wrap it as a text part after the images.
        last_message["content"] = images + [{"type": "text", "text": content}]

    # The message-level "type" marker is dropped before returning. Use a
    # default so messages that never carried the marker do not raise KeyError.
    last_message.pop("type", None)
    return chat
102+
103+
54104class TemplateAPI (TemplateLM ):
105+ MULTIMODAL = True
106+
55107 def __init__ (
56108 self ,
57109 model : str = None ,
@@ -83,6 +135,7 @@ def __init__(
83135 eos_string : str = None ,
84136 # timeout in seconds
85137 timeout : int = 300 ,
138+ max_images : int = 1 ,
86139 ** kwargs ,
87140 ) -> None :
88141 super ().__init__ ()
@@ -129,6 +182,7 @@ def __init__(
129182 self .verify_certificate = verify_certificate
130183 self ._eos_string = eos_string
131184 self .timeout = int (timeout )
185+ self .max_images = int (max_images )
132186
133187 eval_logger .info (f"Using tokenizer { self .tokenizer_backend } " )
134188 if self .tokenizer_backend is None :
@@ -265,7 +319,12 @@ def apply_chat_template(
265319 )
266320 else :
267321 # bit of a hack. We'll load back before sending to the API
268- return JsonChatStr (json .dumps (chat_history , ensure_ascii = False ))
322+ return JsonChatStr (
323+ json .dumps (
324+ [{** item , "type" : "text" } for item in chat_history ],
325+ ensure_ascii = False ,
326+ )
327+ )
269328
270329 @cached_property
271330 def eot_token_id (self ) -> Optional [int ]:
@@ -578,7 +637,28 @@ def _collate_gen(_requests):
578637 return - len (_requests [0 ])
579638
580639 # Let the API deal with tokenization
581- requests , all_gen_kwargs = zip (* (req .args for req in requests ))
640+ if len (requests [0 ].args ) > 2 :
641+ assert self .tokenizer is None , (
642+ "tokenizer is not supported for multimodal requests yet!"
643+ )
644+ eval_logger .info (
645+ f"Using max_images { self .max_images } . Set in the model args."
646+ )
647+ requests , all_gen_kwargs , auxiliary_args = zip (
648+ * (req .args for req in requests )
649+ )
650+ requests = tuple (
651+ JsonChatStr (
652+ json .dumps (
653+ create_image_prompt (
654+ y ["visual" ][: self .max_images ], json .loads (x .prompt )
655+ )
656+ )
657+ )
658+ for x , y in zip (requests , auxiliary_args )
659+ )
660+ else :
661+ requests , all_gen_kwargs = zip (* (req .args for req in requests ))
582662 if self .tokenized_requests :
583663 encodings_list = self .tok_encode (
584664 requests , add_special_tokens = self .add_bos_token
0 commit comments