Commit 42e4505 — "Add multimodal support" (parent commit: 31e4bb9)

File tree: 3 files changed, +172 additions, −18 deletions

outlines/models/lmstudio.py

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
cast,
1212
)
1313

14-
from outlines.inputs import Chat
14+
from outlines.inputs import Chat, Image
1515
from outlines.models.base import AsyncModel, Model, ModelTypeAdapter
1616
from outlines.types import CFG, JsonSchema, Regex
1717

@@ -22,10 +22,20 @@
2222

2323

2424
class LMStudioTypeAdapter(ModelTypeAdapter):
25-
"""Type adapter for the `LMStudio` model.
25+
"""Type adapter for the `LMStudio` model."""
2626

27-
TODO: Add multimodal (Image) support.
28-
"""
27+
def _prepare_lmstudio_image(self, image: Image):
28+
"""Convert Outlines Image to LMStudio image handle.
29+
30+
LMStudio's SDK only accepts file paths, raw bytes, or binary IO objects.
31+
Unlike Ollama which accepts base64 directly, we must decode from base64.
32+
"""
33+
import base64
34+
35+
import lmstudio as lms
36+
37+
image_bytes = base64.b64decode(image.image_str)
38+
return lms.prepare_image(image_bytes)
2939

3040
@singledispatchmethod
3141
def format_input(self, model_input):
@@ -44,17 +54,33 @@ def format_input(self, model_input):
4454
"""
4555
raise TypeError(
4656
f"The input type {type(model_input)} is not available with "
47-
"LMStudio. The only available types are `str` and `Chat`."
57+
"LMStudio. The only available types are `str`, `list` and `Chat`."
4858
)
4959

5060
@format_input.register(str)
5161
def format_str_model_input(self, model_input: str) -> str:
5262
"""Pass through string input directly to LMStudio."""
5363
return model_input
5464

65+
@format_input.register(list)
66+
def format_list_model_input(self, model_input: list) -> "LMStudioChat":
67+
"""Handle list input containing prompt and images."""
68+
from lmstudio import Chat as LMSChat
69+
70+
prompt = model_input[0]
71+
images = model_input[1:]
72+
73+
if not all(isinstance(img, Image) for img in images):
74+
raise ValueError("All assets provided must be of type Image")
75+
76+
chat = LMSChat()
77+
image_handles = [self._prepare_lmstudio_image(img) for img in images]
78+
chat.add_user_message(prompt, images=image_handles)
79+
return chat
80+
5581
@format_input.register(Chat)
5682
def format_chat_model_input(self, model_input: Chat) -> "LMStudioChat":
57-
"""Convert Outlines Chat to LMStudio Chat."""
83+
"""Convert Outlines Chat to LMStudio Chat with image support."""
5884
from lmstudio import Chat as LMSChat
5985

6086
system_prompt = None
@@ -71,7 +97,15 @@ def format_chat_model_input(self, model_input: Chat) -> "LMStudioChat":
7197
content = message["content"]
7298

7399
if role == "user":
74-
chat.add_user_message(content)
100+
if isinstance(content, str):
101+
chat.add_user_message(content)
102+
elif isinstance(content, list):
103+
prompt = content[0]
104+
images = content[1:]
105+
if not all(isinstance(img, Image) for img in images):
106+
raise ValueError("All assets provided must be of type Image")
107+
image_handles = [self._prepare_lmstudio_image(img) for img in images]
108+
chat.add_user_message(prompt, images=image_handles)
75109
elif role == "assistant":
76110
chat.add_assistant_response(content)
77111

@@ -82,9 +116,6 @@ def format_output_type(
82116
) -> Optional[dict]:
83117
"""Format the output type to pass to the model.
84118
85-
TODO: `int`, `float` and other Python types could be supported via
86-
JSON Schema.
87-
88119
Parameters
89120
----------
90121
output_type
@@ -144,7 +175,7 @@ def __init__(self, client: "Client", model_name: Optional[str] = None):
144175

145176
def generate(
146177
self,
147-
model_input: Chat | str,
178+
model_input: Chat | str | list,
148179
output_type: Optional[Any] = None,
149180
**kwargs: Any,
150181
) -> str:
@@ -194,7 +225,7 @@ def generate_batch(
194225

195226
def generate_stream(
196227
self,
197-
model_input: Chat | str,
228+
model_input: Chat | str | list,
198229
output_type: Optional[Any] = None,
199230
**kwargs: Any,
200231
) -> Iterator[str]:
@@ -262,7 +293,7 @@ def __init__(
262293

263294
async def generate(
264295
self,
265-
model_input: Chat | str,
296+
model_input: Chat | str | list,
266297
output_type: Optional[Any] = None,
267298
**kwargs: Any,
268299
) -> str:
@@ -316,7 +347,7 @@ async def generate_batch(
316347

317348
async def generate_stream( # type: ignore
318349
self,
319-
model_input: Chat | str,
350+
model_input: Chat | str | list,
320351
output_type: Optional[Any] = None,
321352
**kwargs: Any,
322353
) -> AsyncIterator[str]:

tests/models/test_lmstudio.py

Lines changed: 84 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
1+
import io
12
import json
23
from enum import Enum
34
from typing import Annotated
45

56
import lmstudio as lms
67
import pytest
8+
from PIL import Image as PILImage
79
from pydantic import BaseModel, Field
810

911
import outlines
10-
from outlines.inputs import Chat
12+
from outlines.inputs import Chat, Image, Video
1113
from outlines.models import AsyncLMStudio, LMStudio
1214

1315
MODEL_NAME = "qwen2.5-coder-1.5b-instruct-mlx"
@@ -37,6 +39,21 @@ def async_model_no_model_name():
3739
return AsyncLMStudio(client)
3840

3941

42+
@pytest.fixture(scope="session")
43+
def image():
44+
width, height = 1, 1
45+
white_background = (255, 255, 255)
46+
image = PILImage.new("RGB", (width, height), white_background)
47+
48+
# Save to an in-memory bytes buffer and read as png
49+
buffer = io.BytesIO()
50+
image.save(buffer, format="PNG")
51+
buffer.seek(0)
52+
image = PILImage.open(buffer)
53+
54+
return image
55+
56+
4057
def test_lmstudio_init_from_client():
4158
client = lms.get_default_client()
4259

@@ -87,6 +104,34 @@ def test_lmstudio_call(model):
87104
assert isinstance(result, str)
88105

89106

107+
@pytest.mark.api_call
108+
def test_lmstudio_simple_vision(image, model):
109+
# This is not using a vision model, so it's not able to describe
110+
# the image, but we're still checking the model input syntax
111+
result = model.generate(
112+
["What does this logo represent?", Image(image)],
113+
model=MODEL_NAME,
114+
)
115+
assert isinstance(result, str)
116+
117+
118+
@pytest.mark.api_call
119+
def test_lmstudio_chat_with_image(image, model):
120+
result = model.generate(
121+
Chat(
122+
[
123+
{"role": "system", "content": "You are a helpful assistant."},
124+
{"role": "user", "content": [
125+
"What does this logo represent?",
126+
Image(image)
127+
]},
128+
]
129+
),
130+
model=MODEL_NAME,
131+
)
132+
assert isinstance(result, str)
133+
134+
90135
@pytest.mark.api_call
91136
def test_lmstudio_chat(model):
92137
chat = Chat(messages=[
@@ -118,10 +163,13 @@ class Foo(Enum):
118163

119164

120165
@pytest.mark.api_call
121-
def test_lmstudio_wrong_input_type(model):
166+
def test_lmstudio_wrong_input_type(model, image):
122167
with pytest.raises(TypeError, match="is not available"):
123168
model.generate({"foo?": "bar?"}, None)
124169

170+
with pytest.raises(ValueError, match="All assets provided must be of type Image"):
171+
model.generate(["foo?", Image(image), Video("")], None)
172+
125173

126174
@pytest.mark.api_call
127175
def test_lmstudio_stream(model):
@@ -198,6 +246,36 @@ async def test_lmstudio_async_call(async_model):
198246
assert isinstance(result, str)
199247

200248

249+
@pytest.mark.api_call
250+
@pytest.mark.asyncio
251+
async def test_lmstudio_async_simple_vision(image, async_model):
252+
# This is not using a vision model, so it's not able to describe
253+
# the image, but we're still checking the model input syntax
254+
result = await async_model.generate(
255+
["What does this logo represent?", Image(image)],
256+
model=MODEL_NAME,
257+
)
258+
assert isinstance(result, str)
259+
260+
261+
@pytest.mark.api_call
262+
@pytest.mark.asyncio
263+
async def test_lmstudio_async_chat_with_image(image, async_model):
264+
result = await async_model.generate(
265+
Chat(
266+
[
267+
{"role": "system", "content": "You are a helpful assistant."},
268+
{"role": "user", "content": [
269+
"What does this logo represent?",
270+
Image(image)
271+
]},
272+
]
273+
),
274+
model=MODEL_NAME,
275+
)
276+
assert isinstance(result, str)
277+
278+
201279
@pytest.mark.api_call
202280
@pytest.mark.asyncio
203281
async def test_lmstudio_async_chat(async_model):
@@ -233,10 +311,13 @@ class Foo(Enum):
233311

234312
@pytest.mark.api_call
235313
@pytest.mark.asyncio
236-
async def test_lmstudio_async_wrong_input_type(async_model):
314+
async def test_lmstudio_async_wrong_input_type(async_model, image):
237315
with pytest.raises(TypeError, match="is not available"):
238316
await async_model.generate({"foo?": "bar?"}, None)
239317

318+
with pytest.raises(ValueError, match="All assets provided must be of type Image"):
319+
await async_model.generate(["foo?", Image(image), Video("")], None)
320+
240321

241322
@pytest.mark.api_call
242323
@pytest.mark.asyncio

tests/models/test_lmstudio_type_adapter.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1+
import io
12
import json
23
import sys
34
from dataclasses import dataclass
45

56
import pytest
67
from genson import SchemaBuilder
8+
from PIL import Image as PILImage
79
from pydantic import BaseModel
810

9-
from outlines.inputs import Chat
11+
from outlines.inputs import Chat, Image
1012
from outlines.models.lmstudio import LMStudioTypeAdapter
1113
from outlines.types import cfg, json_schema, regex
1214

@@ -34,13 +36,37 @@ def adapter():
3436
return LMStudioTypeAdapter()
3537

3638

39+
@pytest.fixture
40+
def image():
41+
width, height = 1, 1
42+
white_background = (255, 255, 255)
43+
image = PILImage.new("RGB", (width, height), white_background)
44+
45+
# Save to an in-memory bytes buffer and read as png
46+
buffer = io.BytesIO()
47+
image.save(buffer, format="PNG")
48+
buffer.seek(0)
49+
image = PILImage.open(buffer)
50+
51+
return image
52+
53+
3754
def test_lmstudio_type_adapter_input_text(adapter):
3855
text_input = "prompt"
3956
result = adapter.format_input(text_input)
4057
assert isinstance(result, str)
4158
assert result == text_input
4259

4360

61+
def test_lmstudio_type_adapter_input_vision(adapter, image):
62+
import lmstudio as lms
63+
64+
image_input = Image(image)
65+
text_input = "prompt"
66+
result = adapter.format_input([text_input, image_input])
67+
assert isinstance(result, lms.Chat)
68+
69+
4470
def test_lmstudio_type_adapter_input_chat(adapter):
4571
chat_input = Chat(messages=[
4672
{"role": "system", "content": "You are a helpful assistant."},
@@ -66,6 +92,22 @@ def test_lmstudio_type_adapter_input_chat_no_system(adapter):
6692
assert isinstance(result, lms.Chat)
6793

6894

95+
def test_lmstudio_type_adapter_input_chat_with_image(adapter, image):
96+
import lmstudio as lms
97+
98+
image_input = Image(image)
99+
chat_input = Chat(messages=[
100+
{"role": "system", "content": "You are a helpful assistant."},
101+
{"role": "user", "content": [
102+
"What is in this image?",
103+
image_input,
104+
]},
105+
{"role": "assistant", "content": "response"},
106+
])
107+
result = adapter.format_input(chat_input)
108+
assert isinstance(result, lms.Chat)
109+
110+
69111
def test_lmstudio_type_adapter_input_invalid(adapter):
70112
prompt = {"foo": "bar"}
71113
with pytest.raises(TypeError, match="The input type"):

Commit comments: 0