Skip to content

Commit 7f748f1

Browse files
committed
Implement tools and outputs for the MLXLM model
1 parent 11f0891 commit 7f748f1

File tree

3 files changed

+62
-23
lines changed

3 files changed

+62
-23
lines changed

outlines/models/mlxlm.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
from outlines.inputs import Chat
77
from outlines.models.base import Model, ModelTypeAdapter
88
from outlines.models.transformers import TransformerTokenizer
9+
from outlines.outputs import Output, StreamingOutput
910
from outlines.processors import OutlinesLogitsProcessor
11+
from outlines.tools import ToolDef
1012

1113
if TYPE_CHECKING:
1214
import mlx.nn as nn
@@ -37,7 +39,7 @@ def format_input(self, model_input):
3739
3840
"""
3941
raise NotImplementedError(
40-
f"The input type {input} is not available with mlx-lm. "
42+
f"The input type {model_input} is not available with mlx-lm. "
4143
"The available types are `str` and `Chat`."
4244
)
4345

@@ -63,7 +65,7 @@ def format_chat_input(self, model_input: Chat) -> str:
6365
)
6466

6567
def format_output_type(
66-
self, output_type: Optional[OutlinesLogitsProcessor] = None,
68+
self, output_type: Optional[OutlinesLogitsProcessor],
6769
) -> Optional[List[OutlinesLogitsProcessor]]:
6870
"""Generate the logits processor argument to pass to the model.
6971
@@ -83,6 +85,14 @@ def format_output_type(
8385
return [output_type]
8486

8587

88+
def format_tools(self, tools):
    """Reject tool definitions, which are not available for MLXLM.

    Parameters
    ----------
    tools
        The tools requested for the generation. Any falsy value (`None`
        or an empty list) is accepted and ignored.

    Raises
    ------
    NotImplementedError
        If `tools` is truthy, i.e. at least one tool was provided.

    """
    # Only a non-empty tool list is an error; callers that pass `None`
    # (no tools requested) must keep working unchanged.
    if tools:
        raise NotImplementedError("MLXLM does not support tools.")
94+
95+
8696
class MLXLM(Model):
8797
"""Thin wrapper around an `mlx_lm` model.
8898
@@ -118,9 +128,10 @@ def __init__(
118128
def generate(
119129
self,
120130
model_input: str,
121-
output_type: Optional[OutlinesLogitsProcessor] = None,
131+
output_type: Optional[OutlinesLogitsProcessor],
132+
tools: Optional[List[ToolDef]],
122133
**kwargs,
123-
) -> str:
134+
) -> Output:
124135
"""Generate text using `mlx-lm`.
125136
126137
Parameters
@@ -130,29 +141,36 @@ def generate(
130141
output_type
131142
The logits processor the model will use to constrain the format of
132143
the generated text.
144+
tools
145+
The tools to use for the generation.
133146
kwargs
134147
Additional keyword arguments to pass to the `mlx-lm` library.
135148
136149
Returns
137150
-------
138-
str
151+
Output
139152
The text generated by the model.
140153
141154
"""
142155
from mlx_lm import generate
143156

144-
return generate(
157+
self.type_adapter.format_tools(tools)
158+
159+
result = generate(
145160
self.model,
146161
self.mlx_tokenizer,
147162
self.type_adapter.format_input(model_input),
148163
logits_processors=self.type_adapter.format_output_type(output_type),
149164
**kwargs,
150165
)
151166

167+
return Output(content=result.text)
168+
152169
def generate_batch(
153170
self,
154171
model_input,
155-
output_type = None,
172+
output_type,
173+
tools,
156174
**kwargs,
157175
):
158176
raise NotImplementedError(
@@ -162,9 +180,10 @@ def generate_batch(
162180
def generate_stream(
163181
self,
164182
model_input: str,
165-
output_type: Optional[OutlinesLogitsProcessor] = None,
183+
output_type: Optional[OutlinesLogitsProcessor],
184+
tools: Optional[List[ToolDef]],
166185
**kwargs,
167-
) -> Iterator[str]:
186+
) -> Iterator[StreamingOutput]:
168187
"""Stream text using `mlx-lm`.
169188
170189
Parameters
@@ -174,25 +193,29 @@ def generate_stream(
174193
output_type
175194
The logits processor the model will use to constrain the format of
176195
the generated text.
196+
tools
197+
The tools to use for the generation.
177198
kwargs
178199
Additional keyword arguments to pass to the `mlx-lm` library.
179200
180201
Returns
181202
-------
182-
Iterator[str]
203+
Iterator[StreamingOutput]
183204
An iterator that yields the text generated by the model.
184205
185206
"""
186207
from mlx_lm import stream_generate
187208

209+
self.type_adapter.format_tools(tools)
210+
188211
for gen_response in stream_generate(
189212
self.model,
190213
self.mlx_tokenizer,
191214
self.type_adapter.format_input(model_input),
192215
logits_processors=self.type_adapter.format_output_type(output_type),
193216
**kwargs,
194217
):
195-
yield gen_response.text
218+
yield StreamingOutput(content=gen_response.text)
196219

197220

198221
def from_mlxlm(model: "nn.Module", tokenizer: "PreTrainedTokenizer") -> MLXLM:

tests/models/test_mlxlm.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from enum import Enum
44
from typing import Generator
55

6+
from pydantic import BaseModel
7+
68
import outlines
79
from outlines.types import Regex
810
from outlines.models.mlxlm import (
@@ -11,7 +13,7 @@
1113
from_mlxlm
1214
)
1315
from outlines.models.transformers import TransformerTokenizer
14-
from pydantic import BaseModel
16+
from outlines.outputs import Output, StreamingOutput
1517

1618
try:
1719
import mlx_lm
@@ -55,14 +57,14 @@ def test_mlxlm_tokenizer(model):
5557

5658
@pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon")
5759
def test_mlxlm_simple(model):
58-
result = model.generate("Respond with one word. Not more.", None)
59-
assert isinstance(result, str)
60+
result = model("Respond with one word. Not more.", None)
61+
assert isinstance(result, Output)
6062

6163

6264
@pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon")
6365
def test_mlxlm_call(model):
6466
result = model("Respond with one word. Not more.")
65-
assert isinstance(result, str)
67+
assert isinstance(result, Output)
6668

6769

6870
@pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon")
@@ -80,15 +82,15 @@ def test_mlxlm_invalid_inference_kwargs(model):
8082
@pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon")
8183
def test_mlxlm_inference_kwargs(model):
8284
result = model("Write a short story about a cat.", max_tokens=2)
83-
assert isinstance(result, str)
84-
assert len(result) < 20
85+
assert isinstance(result, Output)
86+
assert len(result.content) < 20
8587

8688

8789
@pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon")
8890
def test_mlxlm_regex(model):
8991
result = model("Give a number between 0 and 9.", Regex(r"[0-9]"))
90-
assert isinstance(result, str)
91-
assert re.match(r"[0-9]", result)
92+
assert isinstance(result, Output)
93+
assert re.match(r"[0-9]", result.content)
9294

9395

9496
@pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon")
@@ -97,7 +99,7 @@ class Character(BaseModel):
9799
name: str
98100

99101
result = model("Create a character with a name.", Character)
100-
assert "name" in result
102+
assert "name" in result.content
101103

102104

103105
@pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon")
@@ -107,7 +109,7 @@ class Foo(Enum):
107109
dog = "dog"
108110

109111
result = model("Cat or dog?", Foo)
110-
assert result in ["cat", "dog"]
112+
assert result.content in ["cat", "dog"]
111113

112114

113115
@pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon")
@@ -116,7 +118,7 @@ def test_mlxlm_stream_text_stop(model):
116118
"Respond with one word. Not more.", None, max_tokens=100
117119
)
118120
assert isinstance(generator, Generator)
119-
assert isinstance(next(generator), str)
121+
assert isinstance(next(generator), StreamingOutput)
120122

121123

122124
@pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon")

tests/models/test_mlxlm_type_adapter.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1-
import pytest
21
import io
2+
import pytest
33

44
from outlines_core import Index, Vocabulary
55
from PIL import Image as PILImage
66

77
from outlines.backends.outlines_core import OutlinesCoreLogitsProcessor
88
from outlines.inputs import Chat, Image
99
from outlines.models.mlxlm import MLXLMTypeAdapter
10+
from outlines.tools import ToolDef
1011

1112
try:
1213
import mlx_lm
@@ -82,3 +83,16 @@ def test_mlxlm_type_adapter_format_output_type(adapter, logits_processor):
8283
assert isinstance(formatted, list)
8384
assert len(formatted) == 1
8485
assert isinstance(formatted[0], OutlinesCoreLogitsProcessor)
86+
87+
88+
@pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon")
def test_mlxlm_type_adapter_tools(adapter):
    """Check that the adapter rejects tools and accepts their absence."""
    tool = ToolDef(name="test", description="test", parameters={})

    # `match` is applied as a regular expression, so the trailing period is
    # escaped to require the literal message rather than any character.
    with pytest.raises(
        NotImplementedError,
        match=r"MLXLM does not support tools\.",
    ):
        adapter.format_tools([tool])

    # Falsy values mean "no tools requested" and must not raise.
    adapter.format_tools(None)
    adapter.format_tools([])

0 commit comments

Comments
 (0)