Commit 11f0891

Implement tools and outputs for the LlamaCpp model

1 parent 789bda6
3 files changed: +83 -36 lines changed

outlines/models/llamacpp.py

Lines changed: 39 additions & 14 deletions
@@ -17,7 +17,9 @@
 from outlines.inputs import Chat
 from outlines.models.base import Model, ModelTypeAdapter
 from outlines.models.tokenizer import Tokenizer
+from outlines.outputs import Output, StreamingOutput
 from outlines.processors import OutlinesLogitsProcessor
+from outlines.tools import ToolDef
 
 if TYPE_CHECKING:
     from llama_cpp import Llama, LogitsProcessorList
@@ -196,7 +198,7 @@ def format_chat_input(self, model_input: Chat) -> list:
         ]
 
     def format_output_type(
-        self, output_type: Optional[OutlinesLogitsProcessor] = None,
+        self, output_type: Optional[OutlinesLogitsProcessor],
     ) -> "LogitsProcessorList":
         """Generate the logits processor argument to pass to the model.
 
@@ -215,6 +217,13 @@ def format_output_type(
 
         return LogitsProcessorList([output_type])
 
+    def format_tools(self, tools):
+        """Not available for LlamaCpp."""
+        if tools:
+            raise NotImplementedError(
+                "LlamaCpp does not support tools."
+            )
+
 
 class LlamaCpp(Model):
     """Thin wrapper around the `llama_cpp.Llama` model.
@@ -240,9 +249,10 @@ def __init__(self, model: "Llama"):
     def generate(
         self,
         model_input: Union[Chat, str],
-        output_type: Optional[OutlinesLogitsProcessor] = None,
+        output_type: Optional[OutlinesLogitsProcessor],
+        tools: Optional[List[ToolDef]],
         **inference_kwargs: Any,
-    ) -> str:
+    ) -> Output:
         """Generate text using `llama-cpp-python`.
 
         Parameters
@@ -252,6 +262,8 @@ def generate(
         output_type
             The logits processor the model will use to constrain the format of
             the generated text.
+        tools
+            The tools to use for the generation.
         **inference_kwargs
             Additional keyword arguments to pass to the `Llama.__call__`
             method of the `llama-cpp-python` library.
@@ -262,41 +274,46 @@ def generate(
             The text generated by the model.
 
         """
+        self.type_adapter.format_tools(tools)
         prompt = self.type_adapter.format_input(model_input)
+        logits_processor = self.type_adapter.format_output_type(output_type)
 
         if isinstance(prompt, str):
             completion = self.model(
                 prompt,
-                logits_processor=self.type_adapter.format_output_type(output_type),
+                logits_processor=logits_processor,
                 **inference_kwargs,
             )
             result = completion["choices"][0]["text"]
         elif isinstance(prompt, list):  # pragma: no cover
             completion = self.model.create_chat_completion(
                 prompt,
-                logits_processor=self.type_adapter.format_output_type(output_type),
+                logits_processor=logits_processor,
                 **inference_kwargs,
             )
             result = completion["choices"][0]["message"]["content"]
 
         self.model.reset()
 
-        return result
+        return Output(content=result)
 
     def generate_batch(
         self,
         model_input,
-        output_type = None,
+        output_type,
         **inference_kwargs,
     ):
-        raise NotImplementedError("LlamaCpp does not support batch generation.")
+        raise NotImplementedError(
+            "LlamaCpp does not support batch generation."
+        )
 
     def generate_stream(
         self,
         model_input: Union[Chat, str],
-        output_type: Optional[OutlinesLogitsProcessor] = None,
+        output_type: Optional[OutlinesLogitsProcessor],
+        tools: Optional[List[ToolDef]],
         **inference_kwargs: Any,
-    ) -> Iterator[str]:
+    ) -> Iterator[StreamingOutput]:
         """Stream text using `llama-cpp-python`.
 
         Parameters
@@ -306,6 +323,8 @@ def generate_stream(
        output_type
             The logits processor the model will use to constrain the format of
             the generated text.
+        tools
+            The tools to use for the generation.
         **inference_kwargs
             Additional keyword arguments to pass to the `Llama.__call__`
             method of the `llama-cpp-python` library.
@@ -316,27 +335,33 @@ def generate_stream(
             An iterator that yields the text generated by the model.
 
         """
+        self.type_adapter.format_tools(tools)
         prompt = self.type_adapter.format_input(model_input)
+        logits_processor = self.type_adapter.format_output_type(output_type)
 
         if isinstance(prompt, str):
             generator = self.model(
                 prompt,
-                logits_processor=self.type_adapter.format_output_type(output_type),
+                logits_processor=logits_processor,
                 stream=True,
                 **inference_kwargs,
             )
             for chunk in generator:
-                yield chunk["choices"][0]["text"]
+                yield StreamingOutput(
+                    content=chunk["choices"][0]["text"]
+                )
 
         elif isinstance(prompt, list):  # pragma: no cover
             generator = self.model.create_chat_completion(
                 prompt,
-                logits_processor=self.type_adapter.format_output_type(output_type),
+                logits_processor=logits_processor,
                 stream=True,
                 **inference_kwargs,
            )
             for chunk in generator:
-                yield chunk["choices"][0]["delta"].get("content", "")
+                yield StreamingOutput(
+                    content=chunk["choices"][0]["delta"].get("content", "")
+                )
 
 
 def from_llamacpp(model: "Llama"):
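
For orientation, here is a minimal sketch of the call pattern this commit introduces, mirroring the tests below. It assumes llama-cpp-python is installed; the model path is a placeholder, and `from_llamacpp` is imported from the `outlines.models.llamacpp` module shown above.

# A minimal sketch, not documented API: "./model.gguf" is a placeholder
# path to a local GGUF model.
from llama_cpp import Llama

from outlines.models.llamacpp import from_llamacpp
from outlines.outputs import Output, StreamingOutput

model = from_llamacpp(Llama("./model.gguf"))

# generate() now returns an Output object; the generated text is in .content.
result = model("Respond with one word. Not more.", None)
assert isinstance(result, Output)
print(result.content)

# generate_stream() now yields StreamingOutput chunks instead of raw strings.
for chunk in model.stream("Respond with one word. Not more.", None):
    assert isinstance(chunk, StreamingOutput)
    print(chunk.content, end="")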

tests/models/test_llamacpp.py

Lines changed: 31 additions & 22 deletions
@@ -12,6 +12,7 @@
     LlamaCppTypeAdapter,
     from_llamacpp
 )
+from outlines.outputs import Output, StreamingOutput
 from outlines.types.dsl import Regex, CFG
 
 
@@ -71,12 +72,12 @@ def ebnf_grammar():
 
 
 def test_llamacpp_simple(model):
-    result = model.generate("Respond with one word. Not more.", None)
-    assert isinstance(result, str)
+    result = model("Respond with one word. Not more.", None)
+    assert isinstance(result, Output)
 
 
 def test_llamacpp_chat(model):
-    result = model.generate(
+    result = model(
         Chat(
             messages=[
                 {"role": "system", "content": "You are a helpful assistant."},
@@ -85,23 +86,23 @@ def test_llamacpp_chat(model):
         ),
         max_tokens=10
     )
-    assert isinstance(result, str)
+    assert isinstance(result, Output)
 
 
 def test_llamacpp_regex(model):
     result = model("Respond with one word. Not more.", Regex(r"[0-9]"))
-    assert isinstance(result, str)
-    assert int(result)
-    assert len(result) == 1
+    assert isinstance(result, Output)
+    assert int(result.content)
+    assert len(result.content) == 1
 
 
 def test_llamacpp_json(model):
     class Foo(BaseModel):
         bar: str
 
     result = model("foo? Respond with one word.", Foo, max_tokens=100)
-    assert isinstance(result, str)
-    assert "bar" in json.loads(result)
+    assert isinstance(result, Output)
+    assert "bar" in json.loads(result.content)
 
 
 def test_llamacpp_choice(model):
@@ -110,12 +111,14 @@ class Foo(Enum):
         foor = "Foo"
 
     result = model("foo?", Foo)
-    assert result == "Foo" or result == "Bar"
+    assert isinstance(result, Output)
+    assert result.content == "Foo" or result.content == "Bar"
 
 
 def test_llamacpp_cfg(model, ebnf_grammar):
     response = model("Respond with one word. Not more.", CFG(ebnf_grammar))
-    assert response in ["yes", "no"]
+    assert isinstance(response, Output)
+    assert response.content in ["yes", "no"]
 
 
 def test_llamacpp_cfg_outlines_core(model, lark_grammar):
@@ -131,15 +134,16 @@ def test_llamacpp_cfg_outlines_core(model, lark_grammar):
 
 
 def test_llamacpp_text_stop(model):
-    result = model.generate("Write the letter a.", None, stop="a", max_tokens=100)
-    assert "a" not in result
+    result = model("Write the letter a.", None, stop="a", max_tokens=100)
+    assert isinstance(result, Output)
+    assert "a" not in result.content
 
 
 def test_llamacpp_stream_simple(model):
     generator = model.stream("Respond with one word. Not more.", None)
 
     for x in generator:
-        assert isinstance(x, str)
+        assert isinstance(x, StreamingOutput)
 
 
 def test_llamacpp_stream_chat(model):
@@ -153,14 +157,16 @@ def test_llamacpp_stream_chat(model):
         max_tokens=10
     )
     for x in generator:
-        assert isinstance(x, str)
+        assert isinstance(x, StreamingOutput)
 
 
 def test_llamacpp_stream_regex(model):
     generator = model.stream("Respond with one word. Not more.", Regex(r"[0-9]"))
 
     x = next(generator)
-    assert isinstance(x, str)
+    assert isinstance(x, StreamingOutput)
+    assert int(x.content)
+    assert len(x.content) == 1
 
 
 def test_llamacpp_stream_json(model):
@@ -170,15 +176,17 @@ class Foo(BaseModel):
     generator = model.stream("foo?", Foo)
 
     x = next(generator)
-    assert x == "{"
+    assert isinstance(x, StreamingOutput)
+    assert "{" in x.content
 
 
 def test_llamacpp_stream_cfg(model, ebnf_grammar):
     response = ""
     for chunk in model.stream(
         "Respond with one word. Not more.", CFG(ebnf_grammar)
     ):
-        response += chunk
+        assert isinstance(chunk, StreamingOutput)
+        response += chunk.content
     assert response in ["yes", "no"]
 
 
@@ -187,7 +195,7 @@ def test_llamacpp_stream_cfg_outlines_core(model, lark_grammar):
         NotImplementedError,
         match="Outlines Core does not support context-free grammar."
     ):
-        for chunk in model.stream(
+        for _ in model.stream(
             "Respond with one word. Not more.",
             CFG(lark_grammar),
             backend="outlines_core"
@@ -203,15 +211,16 @@ class Foo(Enum):
     generator = model.stream("foo?", Foo)
 
     x = next(generator)
-    assert x[0] in ("B", "F")
+    assert isinstance(x, StreamingOutput)
+    assert x.content[0] in ("B", "F")
 
 
 def test_llamacpp_stream_text_stop(model):
     generator = model.stream("Write the letter a.", None, stop="a", max_tokens=100)
 
     result = next(generator)
-    assert isinstance(result, str)
-    assert result != "a"
+    assert isinstance(result, StreamingOutput)
+    assert result.content != "a"
 
 
 def test_llamacpp_batch(model):

tests/models/test_llamacpp_type_adapter.py

Lines changed: 13 additions & 0 deletions
@@ -8,6 +8,7 @@
 from outlines.backends.outlines_core import OutlinesCoreLogitsProcessor
 from outlines.inputs import Chat, Image
 from outlines.models.llamacpp import LlamaCppTypeAdapter
+from outlines.tools import ToolDef
 
 
 @pytest.fixture
@@ -67,3 +68,15 @@ def test_llamacpp_type_adapter_format_output_type(adapter, logits_processor):
     assert isinstance(formatted, LogitsProcessorList)
     assert formatted[0].index == logits_processor.index
     assert formatted[0].tensor_library_name == logits_processor.tensor_library_name
+
+
+def test_llamacpp_type_adapter_tools(adapter):
+    with pytest.raises(
+        NotImplementedError,
+        match="LlamaCpp does not support tools."
+    ):
+        adapter.format_tools(
+            [ToolDef(name="test", description="test", parameters={})]
+        )
+
+    adapter.format_tools(None)
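
And the same guard seen from the model side, as a hedged sketch rather than a test from this commit: `generate` calls `self.type_adapter.format_tools(tools)` before formatting the prompt, so a non-empty tools list fails fast. Here `model` is a `LlamaCpp` instance as in the earlier sketch, and the positional `tools` argument follows the new `generate` signature.

# Hedged sketch: `model` is the LlamaCpp instance from the earlier example.
import pytest

from outlines.tools import ToolDef

tool = ToolDef(name="test", description="test", parameters={})

# generate() runs format_tools() before building the prompt, so this raises
# without ever touching the underlying Llama model.
with pytest.raises(NotImplementedError, match="LlamaCpp does not support tools."):
    model.generate("Hello", None, [tool])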
