Skip to content

Commit 2cfdd0a

Browse files
authored
use images with api models (#2981)
* use images with apis
* pacify pre-commit
1 parent 178fa84 commit 2cfdd0a

File tree

1 file changed

+82
-2
lines changed

1 file changed

+82
-2
lines changed

lm_eval/models/api_models.py

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import logging
77
from functools import cached_property
88
from typing import (
9+
TYPE_CHECKING,
910
Any,
1011
Awaitable,
1112
Callable,
@@ -30,14 +31,20 @@
3031
pass
3132

3233

34+
import base64
3335
from importlib.util import find_spec
36+
from io import BytesIO
3437

3538
from lm_eval import utils
3639
from lm_eval.api.instance import Instance
3740
from lm_eval.api.model import TemplateLM
3841
from lm_eval.models.utils import Collator, chunks, configure_pad_token
3942

4043

44+
if TYPE_CHECKING:
45+
from PIL import Image
46+
47+
4148
eval_logger = logging.getLogger(__name__)
4249

4350
LogLikelihoodInputs = Tuple[Tuple[str, str], List[int], List[int]]
@@ -51,7 +58,52 @@ def encode(self, encoding):
5158
return self.prompt.encode(encoding)
5259

5360

61+
def create_image_prompt(
    imgs: list["Image.Image"], chat: list[dict], fmt: str = "PNG"
) -> list[dict]:
    """Encode images to base64 data URLs and prepend them to the last chat message.

    Parameters
    ----------
    imgs : list[PIL.Image.Image]
        The list of images to encode to base64.
    chat : list[dict]
        Chat history in the format
        ``list[dict["role": "user"/"system", "content": str, "type": "text"], ...]``.
        Mutated in place: the last message's ``"content"`` becomes a list of
        ``image_url``/``text`` parts, as required by multimodal chat APIs.
    fmt : str, optional
        Any format Pillow understands (e.g. "PNG", "JPEG").
        Defaults to "PNG".

    Returns
    -------
    list[dict]
        The (mutated) chat with the images merged into the last message.
    """
    images = []
    for img in imgs:
        buf = BytesIO()
        img.save(buf, format=fmt)
        img_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
        images.append(
            {
                "type": "image_url",
                "image_url": {
                    # Bug fix: the MIME subtype must follow `fmt`; it was
                    # hard-coded to "png" even when e.g. JPEG was requested.
                    "url": f"data:image/{fmt.lower()};base64,{img_b64}",
                    "detail": "auto",
                },
            }
        )

    # With images, "content" must be a list of dicts with "type" and
    # "text"/"image_url". Currently we do not support few-shots, so only one
    # user message carries the images. Text content may also hold <image>
    # placeholders, which apparently is not necessary for the API class.
    if isinstance(chat[-1]["content"], list):
        chat[-1]["content"] = images + chat[-1]["content"]
    else:
        text_content = {"type": "text", "text": chat[-1]["content"]}
        chat[-1]["content"] = images + [text_content]
        # pop with a default so a message without "type" does not raise.
        chat[-1].pop("type", None)
    return chat
102+
103+
54104
class TemplateAPI(TemplateLM):
105+
MULTIMODAL = True
106+
55107
def __init__(
56108
self,
57109
model: str = None,
@@ -83,6 +135,7 @@ def __init__(
83135
eos_string: str = None,
84136
# timeout in seconds
85137
timeout: int = 300,
138+
max_images: int = 1,
86139
**kwargs,
87140
) -> None:
88141
super().__init__()
@@ -129,6 +182,7 @@ def __init__(
129182
self.verify_certificate = verify_certificate
130183
self._eos_string = eos_string
131184
self.timeout = int(timeout)
185+
self.max_images = int(max_images)
132186

133187
eval_logger.info(f"Using tokenizer {self.tokenizer_backend}")
134188
if self.tokenizer_backend is None:
@@ -265,7 +319,12 @@ def apply_chat_template(
265319
)
266320
else:
267321
# bit of a hack. We'll load back before sending to the API
268-
return JsonChatStr(json.dumps(chat_history, ensure_ascii=False))
322+
return JsonChatStr(
323+
json.dumps(
324+
[{**item, "type": "text"} for item in chat_history],
325+
ensure_ascii=False,
326+
)
327+
)
269328

270329
@cached_property
271330
def eot_token_id(self) -> Optional[int]:
@@ -578,7 +637,28 @@ def _collate_gen(_requests):
578637
return -len(_requests[0])
579638

580639
# Let the API deal with tokenization
581-
requests, all_gen_kwargs = zip(*(req.args for req in requests))
640+
if len(requests[0].args) > 2:
641+
assert self.tokenizer is None, (
642+
"tokenizer is not supported for multimodal requests yet!"
643+
)
644+
eval_logger.info(
645+
f"Using max_images {self.max_images}. Set in the model args."
646+
)
647+
requests, all_gen_kwargs, auxiliary_args = zip(
648+
*(req.args for req in requests)
649+
)
650+
requests = tuple(
651+
JsonChatStr(
652+
json.dumps(
653+
create_image_prompt(
654+
y["visual"][: self.max_images], json.loads(x.prompt)
655+
)
656+
)
657+
)
658+
for x, y in zip(requests, auxiliary_args)
659+
)
660+
else:
661+
requests, all_gen_kwargs = zip(*(req.args for req in requests))
582662
if self.tokenized_requests:
583663
encodings_list = self.tok_encode(
584664
requests, add_special_tokens=self.add_bos_token

0 commit comments

Comments
 (0)