     Any,
     AsyncIterator,
     Iterator,
+    List,
     Optional,
     Union,
 )
 
-from outlines.models.base import AsyncModel,Model, ModelTypeAdapter
+from outlines.models.base import AsyncModel, Model, ModelTypeAdapter
+from outlines.outputs import Output, StreamingOutput
+from outlines.tools import ToolDef
 from outlines.types.dsl import python_types_to_terms, to_regex, JsonSchema, CFG
 
 if TYPE_CHECKING:
@@ -47,7 +50,7 @@ def format_input(self, model_input):
     def format_str_input(self, model_input: str) -> str:
         return model_input
 
-    def format_output_type(self, output_type: Optional[Any] = None) -> dict:
+    def format_output_type(self, output_type: Optional[Any]) -> dict:
         """Generate the structured output argument to pass to the client.
 
         Argument
@@ -84,6 +87,13 @@ def format_output_type(self, output_type: Optional[Any] = None) -> dict:
             }
         }
 
+    def format_tools(self, tools):
+        """Not available for TGI."""
+        if tools:
+            raise NotImplementedError(
+                "Tools are not available for TGI."
+            )
+
 
 class TGI(Model):
     """Thin wrapper around a `huggingface_hub.InferenceClient` client used to
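
The new `format_tools` hook makes the lack of tool support explicit: any truthy `tools` value is rejected before a request is ever built. A behavioral sketch, assuming the adapter class in this hunk is named `TGITypeAdapter` and that `my_tool_def` stands in for any `ToolDef` instance (both names are assumptions, not shown in the diff):

```python
adapter = TGITypeAdapter()  # hypothetical name; the class definition is outside this hunk

adapter.format_tools(None)  # no-op: falsy values pass through silently
adapter.format_tools([])    # no-op: an empty tool list is also falsy

try:
    adapter.format_tools([my_tool_def])  # any actual ToolDef triggers the guard
except NotImplementedError as exc:
    print(exc)  # "Tools are not available for TGI."
```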
@@ -109,9 +119,10 @@ def __init__(self, client):
     def generate(
         self,
         model_input: str,
-        output_type: Optional[Any] = None,
+        output_type: Optional[Any],
+        tools: Optional[List[ToolDef]],
         **inference_kwargs: Any,
-    ) -> str:
+    ) -> Output:
         """Generate text using TGI.
 
         Parameters
@@ -122,37 +133,44 @@ def generate(
             The desired format of the response generated by the model. All
             output types except `CFG` are supported provided your server uses
             a backend that supports them.
+        tools
+            The tools to use for the generation.
         inference_kwargs
             Additional keyword arguments to pass to the client.
 
         Returns
         -------
-        str
+        Output
             The text generated by the model.
 
         """
+        self.type_adapter.format_tools(tools)
         client_args = self._build_client_args(
             model_input,
             output_type,
             **inference_kwargs,
         )
 
-        return self.client.text_generation(**client_args)
+        response = self.client.text_generation(**client_args)
+
+        return Output(content=response)
 
     def generate_batch(
         self,
         model_input,
-        output_type=None,
+        output_type,
+        tools,
         **inference_kwargs,
     ):
         raise NotImplementedError("TGI does not support batch inference.")
 
     def generate_stream(
         self,
         model_input: str,
-        output_type: Optional[Any] = None,
+        output_type: Optional[Any],
+        tools: Optional[List[ToolDef]],
         **inference_kwargs: Any,
-    ) -> Iterator[str]:
+    ) -> Iterator[StreamingOutput]:
         """Stream text using TGI.
 
         Parameters
@@ -163,15 +181,18 @@ def generate_stream(
             The desired format of the response generated by the model. All
             output types except `CFG` are supported provided your server uses
             a backend that supports them.
+        tools
+            The tools to use for the generation.
         inference_kwargs
             Additional keyword arguments to pass to the client.
 
         Returns
         -------
-        Iterator[str]
+        Iterator[StreamingOutput]
             An iterator that yields the text generated by the model.
 
         """
+        self.type_adapter.format_tools(tools)
         client_args = self._build_client_args(
             model_input, output_type, **inference_kwargs,
         )
@@ -181,12 +202,12 @@
         )
 
         for chunk in stream:  # pragma: no cover
-            yield chunk
+            yield StreamingOutput(content=chunk)
 
     def _build_client_args(
         self,
         model_input: str,
-        output_type: Optional[Any] = None,
+        output_type: Optional[Any],
         **inference_kwargs: Any,
     ) -> dict:
         """Build the arguments to pass to the TGI client."""
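
With these changes, the synchronous wrapper returns `Output` objects and yields `StreamingOutput` chunks instead of bare strings, and `output_type` and `tools` become required arguments. A minimal usage sketch, assuming a TGI server at a placeholder endpoint (`http://localhost:8080` is illustrative, not from the diff) and the module path `outlines.models.tgi`:

```python
from huggingface_hub import InferenceClient
from outlines.models.tgi import TGI  # module path assumed from context

client = InferenceClient("http://localhost:8080")  # placeholder TGI endpoint
model = TGI(client)

# tools must be None (or empty) for TGI; anything else raises NotImplementedError.
result = model.generate("Write a haiku.", None, None, max_new_tokens=64)
print(result.content)  # an Output object now, no longer a bare str

for chunk in model.generate_stream("Count to three.", None, None):
    print(chunk.content, end="")  # StreamingOutput chunks instead of raw strings
```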
@@ -226,9 +247,10 @@ def __init__(self, client):
     async def generate(
         self,
         model_input: str,
-        output_type: Optional[Any] = None,
+        output_type: Optional[Any],
+        tools: Optional[List[ToolDef]],
         **inference_kwargs: Any,
-    ) -> str:
+    ) -> Output:
         """Generate text using TGI.
 
         Parameters
@@ -239,37 +261,42 @@ async def generate(
             The desired format of the response generated by the model. All
             output types except `CFG` are supported provided your server uses
             a backend that supports them.
+        tools
+            The tools to use for the generation.
         inference_kwargs
             Additional keyword arguments to pass to the client.
 
         Returns
         -------
-        str
+        Output
             The text generated by the model.
 
         """
+        self.type_adapter.format_tools(tools)
         client_args = self._build_client_args(
             model_input, output_type, **inference_kwargs,
         )
 
         response = await self.client.text_generation(**client_args)
 
-        return response
+        return Output(content=response)
 
     async def generate_batch(
         self,
         model_input,
-        output_type=None,
+        output_type,
+        tools,
         **inference_kwargs,
     ):
         raise NotImplementedError("TGI does not support batch inference.")
 
     async def generate_stream(  # type: ignore
         self,
         model_input: str,
-        output_type: Optional[Any] = None,
+        output_type: Optional[Any],
+        tools: Optional[List[ToolDef]],
         **inference_kwargs: Any,
-    ) -> AsyncIterator[str]:
+    ) -> AsyncIterator[StreamingOutput]:
         """Stream text using TGI.
 
         Parameters
@@ -280,15 +307,18 @@ async def generate_stream(  # type: ignore
             The desired format of the response generated by the model. All
             output types except `CFG` are supported provided your server uses
             a backend that supports them.
+        tools
+            The tools to use for the generation.
         inference_kwargs
             Additional keyword arguments to pass to the client.
 
         Returns
         -------
-        AsyncIterator[str]
+        AsyncIterator[StreamingOutput]
             An async iterator that yields the text generated by the model.
 
         """
+        self.type_adapter.format_tools(tools)
         client_args = self._build_client_args(
             model_input, output_type, **inference_kwargs,
         )
@@ -298,12 +328,12 @@
         )
 
         async for chunk in stream:  # pragma: no cover
-            yield chunk
+            yield StreamingOutput(content=chunk)
 
     def _build_client_args(
         self,
         model_input: str,
-        output_type: Optional[Any] = None,
+        output_type: Optional[Any],
         **inference_kwargs: Any,
     ) -> dict:
         """Build the arguments to pass to the TGI client."""
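
The async wrapper mirrors the synchronous API. A sketch of the same flow under the same assumptions, with the class name `AsyncTGI` inferred from the file's structure rather than shown in the diff:

```python
import asyncio

from huggingface_hub import AsyncInferenceClient
from outlines.models.tgi import AsyncTGI  # class name assumed, not shown in the diff

async def main():
    client = AsyncInferenceClient("http://localhost:8080")  # placeholder endpoint
    model = AsyncTGI(client)

    result = await model.generate("Write a haiku.", None, None)
    print(result.content)  # Output, not a bare str

    async for chunk in model.generate_stream("Count to three.", None, None):
        print(chunk.content, end="")  # StreamingOutput chunks

asyncio.run(main())
```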