 from outlines.inputs import Chat
 from outlines.models.base import Model, ModelTypeAdapter
 from outlines.models.transformers import TransformerTokenizer
+from outlines.outputs import Output, StreamingOutput
 from outlines.processors import OutlinesLogitsProcessor
+from outlines.tools import ToolDef
 
 if TYPE_CHECKING:
     import mlx.nn as nn
@@ -37,7 +39,7 @@ def format_input(self, model_input):
 
         """
         raise NotImplementedError(
-            f"The input type {input} is not available with mlx-lm. "
+            f"The input type {model_input} is not available with mlx-lm. "
             "The available types are `str` and `Chat`."
         )
 
@@ -63,7 +65,7 @@ def format_chat_input(self, model_input: Chat) -> str:
         )
 
     def format_output_type(
-        self, output_type: Optional[OutlinesLogitsProcessor] = None,
+        self, output_type: Optional[OutlinesLogitsProcessor],
     ) -> Optional[List[OutlinesLogitsProcessor]]:
         """Generate the logits processor argument to pass to the model.
 
@@ -83,6 +85,14 @@ def format_output_type(
         return [output_type]
 
 
+    def format_tools(self, tools):
+        """Not available for MLXLM."""
+        if tools:
+            raise NotImplementedError(
+                "MLXLM does not support tools."
+            )
+
+
 class MLXLM(Model):
     """Thin wrapper around an `mlx_lm` model.
 
@@ -118,9 +128,10 @@ def __init__(
     def generate(
         self,
         model_input: str,
-        output_type: Optional[OutlinesLogitsProcessor] = None,
+        output_type: Optional[OutlinesLogitsProcessor],
+        tools: Optional[List[ToolDef]],
         **kwargs,
-    ) -> str:
+    ) -> Output:
         """Generate text using `mlx-lm`.
 
         Parameters
@@ -130,29 +141,36 @@ def generate(
         output_type
             The logits processor the model will use to constrain the format of
             the generated text.
+        tools
+            The tools to use for the generation.
         kwargs
             Additional keyword arguments to pass to the `mlx-lm` library.
 
         Returns
         -------
-        str
+        Output
             The text generated by the model.
 
         """
         from mlx_lm import generate
 
-        return generate(
+        self.type_adapter.format_tools(tools)
+
+        result = generate(
             self.model,
             self.mlx_tokenizer,
             self.type_adapter.format_input(model_input),
             logits_processors=self.type_adapter.format_output_type(output_type),
             **kwargs,
         )
 
+        return Output(content=result.text)
+
     def generate_batch(
         self,
         model_input,
-        output_type=None,
+        output_type,
+        tools,
         **kwargs,
     ):
         raise NotImplementedError(
@@ -162,9 +180,10 @@ def generate_batch(
     def generate_stream(
         self,
         model_input: str,
-        output_type: Optional[OutlinesLogitsProcessor] = None,
+        output_type: Optional[OutlinesLogitsProcessor],
+        tools: Optional[List[ToolDef]],
         **kwargs,
-    ) -> Iterator[str]:
+    ) -> Iterator[StreamingOutput]:
         """Stream text using `mlx-lm`.
 
         Parameters
@@ -174,25 +193,29 @@ def generate_stream(
         output_type
             The logits processor the model will use to constrain the format of
             the generated text.
+        tools
+            The tools to use for the generation.
         kwargs
             Additional keyword arguments to pass to the `mlx-lm` library.
 
         Returns
         -------
-        Iterator[str]
+        Iterator[StreamingOutput]
             An iterator that yields the text generated by the model.
 
         """
         from mlx_lm import stream_generate
 
+        self.type_adapter.format_tools(tools)
+
         for gen_response in stream_generate(
             self.model,
             self.mlx_tokenizer,
             self.type_adapter.format_input(model_input),
             logits_processors=self.type_adapter.format_output_type(output_type),
             **kwargs,
         ):
-            yield gen_response.text
+            yield StreamingOutput(content=gen_response.text)
 
 
 def from_mlxlm(model: "nn.Module", tokenizer: "PreTrainedTokenizer") -> MLXLM:
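
Below is a minimal usage sketch of the changed interface, not part of the diff itself; it assumes a top-level outlines.from_mlxlm export and mlx_lm's load() helper, and the model id and sampling kwargs are placeholders.

    # Sketch only: model id and kwargs are illustrative assumptions.
    from mlx_lm import load

    import outlines

    mlx_model, mlx_tokenizer = load("mlx-community/Llama-3.2-1B-Instruct-4bit")  # placeholder id
    model = outlines.from_mlxlm(mlx_model, mlx_tokenizer)

    # output_type and tools no longer default to None, so pass them explicitly;
    # a non-empty tools list now raises NotImplementedError for MLXLM.
    output = model.generate("Write a haiku about autumn.", None, None, max_tokens=64)
    print(output.content)  # generate() now returns an Output instead of a bare str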