Commit 9955e1e

Implement tools and outputs for the TGI model

1 parent 39a474d · commit 9955e1e

3 files changed: +79 −35 lines changed

outlines/models/tgi.py

Lines changed: 52 additions & 22 deletions
@@ -7,11 +7,14 @@
     Any,
     AsyncIterator,
     Iterator,
+    List,
     Optional,
     Union,
 )

-from outlines.models.base import AsyncModel,Model, ModelTypeAdapter
+from outlines.models.base import AsyncModel, Model, ModelTypeAdapter
+from outlines.outputs import Output, StreamingOutput
+from outlines.tools import ToolDef
 from outlines.types.dsl import python_types_to_terms, to_regex, JsonSchema, CFG

 if TYPE_CHECKING:
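
The `Output` and `StreamingOutput` containers imported from `outlines.outputs` are not defined in this diff. A minimal sketch of what they could look like, assuming plain dataclasses with a single `content` field (the field name is taken from the `Output(content=...)` calls below; everything else is an assumption):

# Hypothetical sketch; the real definitions live in outlines/outputs.py
# and are not shown in this commit.
from dataclasses import dataclass


@dataclass
class Output:
    """A complete model response; the generated text is in `content`."""
    content: str


@dataclass
class StreamingOutput:
    """One chunk of a streamed response."""
    content: str
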
@@ -47,7 +50,7 @@ def format_input(self, model_input):
     def format_str_input(self, model_input: str) -> str:
         return model_input

-    def format_output_type(self, output_type: Optional[Any] = None) -> dict:
+    def format_output_type(self, output_type: Optional[Any]) -> dict:
         """Generate the structured output argument to pass to the client.

         Argument
@@ -84,6 +87,13 @@ def format_output_type(self, output_type: Optional[Any] = None) -> dict:
             }
         }

+    def format_tools(self, tools):
+        """Not available for TGI."""
+        if tools:
+            raise NotImplementedError(
+                "Tools are not available for TGI."
+            )
+

 class TGI(Model):
     """Thin wrapper around a `huggingface_hub.InferenceClient` client used to
@@ -109,9 +119,10 @@ def __init__(self, client):
     def generate(
         self,
         model_input: str,
-        output_type: Optional[Any] = None,
+        output_type: Optional[Any],
+        tools: Optional[List[ToolDef]],
         **inference_kwargs: Any,
-    ) -> str:
+    ) -> Output:
         """Generate text using TGI.

         Parameters
@@ -122,37 +133,44 @@ def generate(
             The desired format of the response generated by the model. All
             output types except `CFG` are supported provided your server uses
             a backend that supports them.
+        tools
+            The tools to use for the generation.
         inference_kwargs
             Additional keyword arguments to pass to the client.

         Returns
         -------
-        str
+        Output
             The text generated by the model.

         """
+        self.type_adapter.format_tools(tools)
         client_args = self._build_client_args(
             model_input,
             output_type,
             **inference_kwargs,
         )

-        return self.client.text_generation(**client_args)
+        response = self.client.text_generation(**client_args)
+
+        return Output(content=response)

     def generate_batch(
         self,
         model_input,
-        output_type = None,
+        output_type,
+        tools,
         **inference_kwargs,
     ):
         raise NotImplementedError("TGI does not support batch inference.")

     def generate_stream(
         self,
         model_input: str,
-        output_type: Optional[Any] = None,
+        output_type: Optional[Any],
+        tools: Optional[List[ToolDef]],
         **inference_kwargs: Any,
-    ) -> Iterator[str]:
+    ) -> Iterator[StreamingOutput]:
         """Stream text using TGI.

@@ -163,15 +181,18 @@ def generate_stream(
             The desired format of the response generated by the model. All
             output types except `CFG` are supported provided your server uses
             a backend that supports them.
+        tools
+            The tools to use for the generation.
         inference_kwargs
             Additional keyword arguments to pass to the client.

         Returns
         -------
-        Iterator[str]
+        Iterator[StreamingOutput]
             An iterator that yields the text generated by the model.

         """
+        self.type_adapter.format_tools(tools)
         client_args = self._build_client_args(
             model_input, output_type, **inference_kwargs,
         )
@@ -181,12 +202,12 @@ def generate_stream(
         )

         for chunk in stream:  # pragma: no cover
-            yield chunk
+            yield StreamingOutput(content=chunk)

     def _build_client_args(
         self,
         model_input: str,
-        output_type: Optional[Any] = None,
+        output_type: Optional[Any],
         **inference_kwargs: Any,
     ) -> dict:
         """Build the arguments to pass to the TGI client."""
@@ -226,9 +247,10 @@ def __init__(self, client):
     async def generate(
         self,
         model_input: str,
-        output_type: Optional[Any] = None,
+        output_type: Optional[Any],
+        tools: Optional[List[ToolDef]],
         **inference_kwargs: Any,
-    ) -> str:
+    ) -> Output:
         """Generate text using TGI.

         Parameters
@@ -239,37 +261,42 @@ async def generate(
             The desired format of the response generated by the model. All
             output types except `CFG` are supported provided your server uses
             a backend that supports them.
+        tools
+            The tools to use for the generation.
         inference_kwargs
             Additional keyword arguments to pass to the client.

         Returns
         -------
-        str
+        Output
             The text generated by the model.

         """
+        self.type_adapter.format_tools(tools)
         client_args = self._build_client_args(
             model_input, output_type, **inference_kwargs,
         )

         response = await self.client.text_generation(**client_args)

-        return response
+        return Output(content=response)

     async def generate_batch(
         self,
         model_input,
-        output_type = None,
+        output_type,
+        tools,
         **inference_kwargs,
     ):
         raise NotImplementedError("TGI does not support batch inference.")

     async def generate_stream(  # type: ignore
         self,
         model_input: str,
-        output_type: Optional[Any] = None,
+        output_type: Optional[Any],
+        tools: Optional[List[ToolDef]],
         **inference_kwargs: Any,
-    ) -> AsyncIterator[str]:
+    ) -> AsyncIterator[StreamingOutput]:
         """Stream text using TGI.

         Parameters
@@ -280,15 +307,18 @@ async def generate_stream(  # type: ignore
             The desired format of the response generated by the model. All
             output types except `CFG` are supported provided your server uses
             a backend that supports them.
+        tools
+            The tools to use for the generation.
         inference_kwargs
             Additional keyword arguments to pass to the client.

         Returns
         -------
-        AsyncIterator[str]
+        AsyncIterator[StreamingOutput]
             An async iterator that yields the text generated by the model.

         """
+        self.type_adapter.format_tools(tools)
         client_args = self._build_client_args(
             model_input, output_type, **inference_kwargs,
         )
@@ -298,12 +328,12 @@ async def generate_stream(  # type: ignore
         )

         async for chunk in stream:  # pragma: no cover
-            yield chunk
+            yield StreamingOutput(content=chunk)

     def _build_client_args(
         self,
         model_input: str,
-        output_type: Optional[Any] = None,
+        output_type: Optional[Any],
         **inference_kwargs: Any,
     ) -> dict:
         """Build the arguments to pass to the TGI client."""

tests/models/test_tgi.py

Lines changed: 14 additions & 13 deletions
@@ -7,6 +7,7 @@
 from huggingface_hub import InferenceClient, AsyncInferenceClient

 from outlines.models.tgi import TGI, AsyncTGI, from_tgi
+from outlines.outputs import Output, StreamingOutput
 from outlines.types.dsl import CFG, Regex, JsonSchema
 from tests.test_utils.mock_tgi_client import MockTGIInferenceClient, MockAsyncTGIInferenceClient

@@ -42,7 +43,7 @@
             'max_new_tokens': 10,
             'stream': True
         },
-        ["foo", "bar"]
+        [Output(content="foo"), Output(content="bar")]
     ),
     (
         {
@@ -108,7 +109,7 @@ def test_tgi_init():

 def test_tgi_sync_simple_call(sync_model):
     result = sync_model("Respond with a single word.", max_new_tokens=10)
-    assert isinstance(result, str)
+    assert isinstance(result, Output)


 def test_tgi_sync_streaming(sync_model):
@@ -117,7 +118,7 @@ def test_tgi_sync_streaming(sync_model):
         max_new_tokens=10,
     )
     assert isinstance(result, Generator)
-    assert isinstance(next(result), str)
+    assert isinstance(next(result), StreamingOutput)


 def test_tgi_sync_batch(sync_model):
@@ -130,14 +131,14 @@
 def test_tgi_sync_json(sync_model):
     json_string = '{"type": "object", "properties": {"bar": {"type": "string"}}, "required": ["bar"]}'
     result = sync_model("foo?", JsonSchema(json_string), max_new_tokens=10)
-    assert isinstance(result, str)
-    assert "bar" in result
+    assert isinstance(result, Output)
+    assert "bar" in result.content


 def test_tgi_sync_regex(sync_model):
     result = sync_model("foo?", Regex(r"[0-9]{3}"), max_new_tokens=10)
-    assert isinstance(result, str)
-    assert re.match(r"[0-9]{3}", result)
+    assert isinstance(result, Output)
+    assert re.match(r"[0-9]{3}", result.content)


 def test_tgi_sync_cfg(sync_model):
@@ -151,15 +152,15 @@
 @pytest.mark.asyncio
 async def test_tgi_async_simple_call(async_model):
     result = await async_model("Respond with a single word.", max_new_tokens=10)
-    assert isinstance(result, str)
+    assert isinstance(result, Output)


 @pytest.mark.asyncio
 async def test_tgi_async_streaming(async_model):
     result = async_model.stream("Respond with a single word.", max_new_tokens=10)
     assert isinstance(result, AsyncGenerator)
     async for chunk in result:
-        assert isinstance(chunk, str)
+        assert isinstance(chunk, StreamingOutput)
         break  # Just check the first chunk

@@ -175,15 +176,15 @@ async def test_tgi_async_batch(async_model):
 async def test_tgi_async_json(async_model):
     json_string = '{"type": "object", "properties": {"bar": {"type": "string"}}, "required": ["bar"]}'
     result = await async_model("foo?", JsonSchema(json_string), max_new_tokens=10)
-    assert isinstance(result, str)
-    assert "bar" in result
+    assert isinstance(result, Output)
+    assert "bar" in result.content


 @pytest.mark.asyncio
 async def test_tgi_async_regex(async_model):
     result = await async_model("foo?", Regex(r"[0-9]{3}"), max_new_tokens=10)
-    assert isinstance(result, str)
-    assert re.match(r"[0-9]{3}", result)
+    assert isinstance(result, Output)
+    assert re.match(r"[0-9]{3}", result.content)


 @pytest.mark.asyncio
Lines changed: 13 additions & 0 deletions
@@ -2,6 +2,7 @@
 import pytest

 from outlines.models.tgi import TGITypeAdapter
+from outlines.tools import ToolDef
 from outlines.types import CFG, JsonSchema

@@ -86,3 +87,15 @@ def test_tgi_type_adapter_output_type_invalid(
         match="TGI does not support CFG-based structured outputs.",
     ):
         type_adapter.format_output_type(cfg_instance)
+
+
+def test_tgi_type_adapter_tools(type_adapter):
+    with pytest.raises(
+        NotImplementedError,
+        match="Tools are not available for TGI.",
+    ):
+        type_adapter.format_tools(
+            [ToolDef(name="test", description="test", parameters={})]
+        )
+
+    type_adapter.format_tools(None)
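
`ToolDef` itself is not defined in this commit; the test only shows it accepting `name`, `description`, and `parameters`. A plausible minimal shape consistent with that call, purely as an assumption:

# Hypothetical sketch of ToolDef inferred from the test above; the real
# definition lives in outlines/tools.py and is not part of this diff.
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class ToolDef:
    name: str
    description: str
    parameters: Dict[str, Any] = field(default_factory=dict)
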
