
Commit 629cd9b

feat: enable VLMs (#126)
* Started VLM support for Ollama:
  - new ImageCBlock
  - `images` argument for `m.instruct` and `Instruction`
  - `get_images_from_component(c)` method to extract images from components
  - formatter handles images now
* VLM support for the OpenAI backend:
  - new formatting from Mellea Message to OpenAI message
  - [patch] tool-calling patch to work with multiple OpenAI-compatible inference engines
* ImageCBlock --> ImageBlock; valid PNG base64 testing; better `get_images_from_component`; images added to TemplateRepresentation and used there for Message construction
* `m.instruct` now also takes a list of PIL images
* `m.chat` takes images now
* Fixed OpenAI tool args
* LiteLLM uses OpenAI formatting for VLMs
* Better pretty-printing for Message images
* Examples for using vision models with different backends
* Changed formatter cases; fixed a test failure
1 parent 61d7f0e commit 629cd9b
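
At a glance, the change lets a session take images directly on `instruct` and `chat`. A minimal sketch, condensed from the example files added in this commit (the model id and image path are just the placeholders those examples use, and it assumes both PIL images and `ImageBlock`s are accepted by either call):

from PIL import Image

from mellea import LinearContext, start_session
from mellea.stdlib.base import ImageBlock

# Start an Ollama-backed session with a vision-capable model (placeholder model id).
m = start_session(model_id="granite3.2-vision", ctx=LinearContext())

# PIL images can be passed directly ...
img = Image.open("pointing_up.jpg")
res = m.instruct("Is the subject in the image smiling?", images=[img])
print(str(res))

# ... or converted to the new ImageBlock (base64 PNG) first.
block = ImageBlock.from_pil_image(img)
res = m.chat("How many eyes can you identify in the image?", images=[block])
print(str(res.content))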

File tree

13 files changed: 326 additions, 14 deletions

Binary file (36.2 KB); preview not rendered.
Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
"""Examples of using vision models with LiteLLM backend."""

import os

import litellm
from PIL import Image

from mellea import MelleaSession, start_session
from mellea.backends.litellm import LiteLLMBackend
from mellea.backends.openai import OpenAIBackend
from mellea.stdlib.base import ImageBlock

# use LiteLLM to talk to Ollama or anthropic or.....
m = MelleaSession(LiteLLMBackend("ollama/granite3.2-vision"))
# m = MelleaSession(LiteLLMBackend("ollama/llava"))
# m = MelleaSession(LiteLLMBackend("anthropic/claude-3-haiku-20240307"))

test_pil = Image.open("pointing_up.jpg")

# check if model is able to do text chat
ch = m.chat("What's 1+1?")
print(str(ch.content))

# test with PIL image
res = m.instruct(
    "Is there a person on the image? Is the subject in the image smiling?",
    images=[test_pil],
)
print(str(res))
# print(m.last_prompt())

# with PIL image and using m.chat
res = m.chat("How many eyes can you identify in the image? Explain.", images=[test_pil])
print(str(res.content))

# and now without images again...
res = m.instruct("How many eyes can you identify in the image?", images=[])
print(str(res))
Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
"""Example of using Ollama with vision models with linear context."""

from PIL import Image

from mellea import LinearContext, start_session
from mellea.stdlib.base import ImageBlock

m = start_session(model_id="granite3.2-vision", ctx=LinearContext())
# m = start_session(model_id="llava", ctx=LinearContext())

# load image
test_img = Image.open("pointing_up.jpg")

# ask a question about the image
res = m.instruct("Is the subject in the image smiling?", images=[test_img])
print(f"Result:{res!s}")

# This instruction should refer to the first image.
res2 = m.instruct("How many eyes can you identify in the image? Explain.")
print(f"Result:{res2!s}")
Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
"""Examples using vision models with OpenAI backend."""

import os

from PIL import Image

from mellea import MelleaSession
from mellea.backends.openai import OpenAIBackend
from mellea.stdlib.base import ImageBlock

# # using anthropic AI model ...
# anth_key = os.environ.get("ANTHROPIC_API_KEY")
# m = MelleaSession(OpenAIBackend(model_id="claude-3-haiku-20240307",
#                                 api_key=anth_key,  # Your Anthropic API key
#                                 base_url="https://api.anthropic.com/v1/"  # Anthropic's API endpoint
#                                 ))

# using LM Studio model locally
m = MelleaSession(
    OpenAIBackend(model_id="qwen/qwen2.5-vl-7b", base_url="http://127.0.0.1:1234/v1")
)

# load PIL image and convert to mellea ImageBlock
test_pil = Image.open("pointing_up.jpg")
test_img = ImageBlock.from_pil_image(test_pil)

# check if model is able to do text chat
ch = m.chat("What's 1+1?")
print(str(ch.content))

# now test with MELLEA image
res = m.instruct(
    "Is there a person on the image? Is the subject in the image smiling?",
    images=[test_img],
)
print(str(res))
# print(m.last_prompt())

# and now with PIL image and using m.chat
res = m.chat("How many eyes can you identify in the image? Explain.", images=[test_pil])
print(str(res.content))

# and now without images again...
res = m.instruct("How many eyes can you identify in the image?", images=[])
print(str(res))

mellea/backends/formatter.py

Lines changed: 8 additions & 0 deletions
@@ -71,6 +71,14 @@ def _to_msg(c: Component | CBlock) -> Message:
         match c:
             case Message():
                 return c
+            case Component():
+                images = None
+                tr = c.format_for_llm()
+                if isinstance(tr, TemplateRepresentation):
+                    images = tr.images
+
+                # components can have images
+                return Message(role=role, content=self.print(c), images=images)
             case _:
                 return Message(role=role, content=self.print(c))


mellea/backends/litellm.py

Lines changed: 11 additions & 1 deletion
@@ -12,6 +12,7 @@
 import mellea.backends.model_ids as model_ids
 from mellea.backends import BaseModelSubclass
 from mellea.backends.formatter import Formatter, FormatterBackend, TemplateFormatter
+from mellea.backends.openai import OpenAIBackend
 from mellea.backends.tools import (
     add_tools_from_context_actions,
     add_tools_from_model_options,
@@ -213,18 +214,27 @@ def _generate_from_chat_context_standard(
         )
         # Convert our linearized context into a sequence of chat messages. Template formatters have a standard way of doing this.
         messages: list[Message] = self.formatter.to_chat_messages(linearized_context)
+
         # Add the final message.
         match action:
             case ALoraRequirement():
                 raise Exception("The LiteLLM backend does not support activated LoRAs.")
             case _:
                 messages.extend(self.formatter.to_chat_messages([action]))

+        # TODO: the supports_vision function is not reliably predicting if models support vision. E.g., ollama/llava is not a vision model?
+        # if any(m.images is not None for m in messages):
+        #     # check if model can handle images
+        #     assert litellm.supports_vision(
+        #         model=self.model_id), f"Model {self.model_id} does not support vision. Please use a different model."
+
         conversation: list[dict] = []
         system_prompt = model_opts.get(ModelOption.SYSTEM_PROMPT, "")
         if system_prompt != "":
             conversation.append({"role": "system", "content": system_prompt})
-        conversation.extend([{"role": m.role, "content": m.content} for m in messages])
+        conversation.extend(
+            [OpenAIBackend.message_to_openai_message(m) for m in messages]
+        )

         if format is not None:
             response_format = {

mellea/backends/ollama.py

Lines changed: 6 additions & 1 deletion
@@ -287,7 +287,12 @@ def generate_from_chat_context(
         if system_prompt != "":
             conversation.append({"role": "system", "content": system_prompt})

-        conversation.extend([{"role": m.role, "content": m.content} for m in messages])
+        conversation.extend(
+            [
+                {"role": m.role, "content": m.content, "images": m.images}
+                for m in messages
+            ]
+        )

         # Append tool call information if applicable.
         tools: dict[str, Callable] = dict()

mellea/backends/openai.py

Lines changed: 38 additions & 5 deletions
@@ -350,6 +350,40 @@ def _generate_from_chat_context_alora(
             ),
         )

+    @staticmethod
+    def message_to_openai_message(msg: Message):
+        if msg.images is not None:
+            img_list = [
+                {
+                    "type": "image_url",
+                    "image_url": {"url": f"data:image/png;base64,{img}"},
+                }
+                for img in msg.images
+            ]
+
+            return {
+                "role": msg.role,
+                "content": [{"type": "text", "text": msg.content}, *img_list],
+            }
+        else:
+            return {"role": msg.role, "content": msg.content}
+        # Target format:
+        # {
+        #     "role": "user",
+        #     "content": [
+        #         {
+        #             "type": "text",
+        #             "text": "What's in this picture?"
+        #         },
+        #         {
+        #             "type": "image_url",
+        #             "image_url": {
+        #                 "url": "data:image/jpeg;base64,<base64_string>"
+        #             }
+        #         }
+        #     ]
+        # }
+
     def _generate_from_chat_context_standard(
         self,
         action: Component | CBlock,
@@ -384,7 +418,7 @@ def _generate_from_chat_context_standard(
         system_prompt = model_opts.get(ModelOption.SYSTEM_PROMPT, "")
         if system_prompt != "":
             conversation.append({"role": "system", "content": system_prompt})
-        conversation.extend([{"role": m.role, "content": m.content} for m in messages])
+        conversation.extend([self.message_to_openai_message(m) for m in messages])

         if format is not None:
             response_format = {
@@ -420,15 +454,14 @@ def _generate_from_chat_context_standard(
             thinking = "medium"

         formatted_tools = convert_tools_to_json(tools)
+        use_tools = len(formatted_tools) > 0
+
         chat_response: ChatCompletion = self._client.chat.completions.create(
             model=self._hf_model_id,
             messages=conversation,  # type: ignore
             reasoning_effort=thinking,  # type: ignore
             response_format=response_format,  # type: ignore
-            tool_choice=(
-                "auto" if formatted_tools and len(formatted_tools) > 0 else "none"
-            ),
-            tools=formatted_tools,  # type: ignore
+            tools=formatted_tools if use_tools else None,  # type: ignore
             # parallel_tool_calls=False, # We only support calling one tool per turn. But we do the choosing on our side so we leave this False.
             **self._make_backend_specific_and_remove(
                 model_opts, is_chat_context=ctx.is_chat_context
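
For reference, a small sketch of what `message_to_openai_message` produces; it assumes `Message` can be imported from `mellea.stdlib.chat` (the import path is not shown in this diff) and that `role` is a plain string, and it builds a 1x1 PNG in memory so the `ImageBlock` payload is valid base64:

from PIL import Image as PILImage

from mellea.backends.openai import OpenAIBackend
from mellea.stdlib.base import ImageBlock
from mellea.stdlib.chat import Message  # assumed import path; not shown in this diff

# A tiny in-memory PNG gives a valid base64 payload for the ImageBlock.
img = ImageBlock.from_pil_image(PILImage.new("RGB", (1, 1)))

msg = Message(role="user", content="What's in this picture?", images=[img])
print(OpenAIBackend.message_to_openai_message(msg))
# -> {"role": "user", "content": [{"type": "text", "text": "..."},
#                                 {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}]}

text_only = Message(role="user", content="What's 1+1?")
print(OpenAIBackend.message_to_openai_message(text_only))
# -> {"role": "user", "content": "What's 1+1?"}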

mellea/stdlib/base.py

Lines changed: 89 additions & 0 deletions
@@ -3,12 +3,17 @@
 from __future__ import annotations

 import abc
+import base64
+import binascii
 import datetime
 from collections.abc import Callable, Iterable, Mapping
 from copy import deepcopy
 from dataclasses import dataclass
+from io import BytesIO
 from typing import Any, Protocol, runtime_checkable

+from PIL import Image as PILImage
+
 from mellea.helpers.fancy_logger import FancyLogger

@@ -43,6 +48,70 @@ def __repr__(self):
         return f"CBlock({self.value}, {self._meta.__repr__()})"


+class ImageBlock:
+    """A `ImageBlock` represents an image (as base64 PNG)."""
+
+    def __init__(self, value: str, meta: dict[str, Any] | None = None):
+        """Initializes the ImageBlock with a base64 PNG string representation and some metadata."""
+        assert self.is_valid_base64_png(value), (
+            "Invalid base64 string representation of image."
+        )
+        self._value = value
+        self._meta = {} if meta is None else meta
+
+    @staticmethod
+    def is_valid_base64_png(s: str) -> bool:
+        """Checks if a string is a valid base64 string [AIA PAI Nc Hin R v1.0]."""
+        try:
+            # Check if the string has a data URI prefix and remove it.
+            if "data:" in s and "base64," in s:
+                s = s.split("base64,")[1]
+
+            # Add padding if necessary
+            s = s.strip()
+            mod4 = len(s) % 4
+            if mod4 > 0:
+                s = s + "=" * (4 - mod4)
+
+            # Attempt to decode the Base64 string
+            decoded_data = base64.b64decode(s, validate=True)
+
+            # The official PNG signature is 8 bytes long.
+            png_signature = b"\x89PNG\r\n\x1a\n"
+
+            if decoded_data.startswith(png_signature):
+                return True
+            else:
+                return False
+
+            return True
+        except (binascii.Error, ValueError):
+            return False
+
+    @staticmethod
+    def pil_to_base64(image: PILImage.Image) -> str:
+        """Converts a PIL image to a base64 string representation."""
+        img_io = BytesIO()
+        image.save(img_io, "PNG")
+        return base64.b64encode(img_io.getvalue()).decode("utf-8")
+
+    @classmethod
+    def from_pil_image(
+        cls, image: PILImage.Image, meta: dict[str, Any] | None = None
+    ) -> ImageBlock:
+        """Converts a PIL image to a base64 string representation."""
+        image_base64 = cls.pil_to_base64(image)
+        return cls(image_base64, meta)
+
+    def __str__(self):
+        """Stringifies the block."""
+        return self._value
+
+    def __repr__(self):
+        """Provides a python-parsable representation of the block (usually)."""
+        return f"ImageBlock({self._value}, {self._meta.__repr__()})"
+
+
 @runtime_checkable
 class Component(Protocol):
     """A `Component` is a composite data structure that is intended to be represented to an LLM."""
@@ -59,6 +128,25 @@ def format_for_llm(self) -> TemplateRepresentation | str:
         raise NotImplementedError("format_for_llm isn't implemented by default")


+def get_images_from_component(c: Component) -> None | list[ImageBlock]:
+    """Gets images from a `Component` if they are present and a non-empty list, otherwise returns None."""
+    if hasattr(c, "images"):
+        imgs = c.images
+        if imgs is not None:
+            assert isinstance(imgs, list), "images field must be a list."
+            assert all(isinstance(im, ImageBlock) for im in imgs), (
+                "all elements of images list must be ImageBlocks."
+            )
+            if len(imgs) == 0:
+                return None
+            else:
+                return imgs
+        else:
+            return None
+    else:
+        return None
+
+
 class ModelOutputThunk(CBlock):
     """A `ModelOutputThunk` is a special type of `CBlock` that we know came from a model's output. It is possible to instantiate one without the output being computed yet."""

@@ -452,6 +540,7 @@ class TemplateRepresentation:
     fields: list[Any] | None = None
     template: str | None = None
     template_order: list[str] | None = None
+    images: list[ImageBlock] | None = None


 @dataclass
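
A minimal sketch of how the new pieces in `base.py` fit together; `CaptionedPhoto` is a hypothetical component defined only for this illustration (any object with a `format_for_llm` method and an `images` attribute holding `ImageBlock`s behaves the same way):

from dataclasses import dataclass

from PIL import Image as PILImage

from mellea.stdlib.base import ImageBlock, get_images_from_component


@dataclass
class CaptionedPhoto:
    """Hypothetical component used only for this sketch."""

    caption: str
    images: list[ImageBlock] | None = None

    def format_for_llm(self) -> str:
        # Returning a plain string is allowed by the Component protocol.
        return self.caption


# Build a valid base64-PNG ImageBlock from an in-memory 1x1 image.
block = ImageBlock.from_pil_image(PILImage.new("RGB", (1, 1)))
assert ImageBlock.is_valid_base64_png(str(block))

photo = CaptionedPhoto(caption="A test pattern.", images=[block])

# Returns the list of ImageBlocks, or None when the list is missing or empty.
print(get_images_from_component(photo))                        # -> [ImageBlock(...)]
print(get_images_from_component(CaptionedPhoto("No image.")))  # -> None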
