Skip to content

Commit eb1e443

Browse files
committed
Add tests for the Outlines model
1 parent 45beb87 commit eb1e443

File tree

5 files changed

+354
-4
lines changed

5 files changed

+354
-4
lines changed

tests/conftest.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,18 @@ def model(
526526
'Qwen/Qwen2.5-72B-Instruct',
527527
provider=HuggingFaceProvider(provider_name='nebius', api_key=huggingface_api_key),
528528
)
529+
elif request.param == 'outlines':
530+
from outlines.models.transformers import from_transformers
531+
from transformers import AutoModelForCausalLM, AutoTokenizer
532+
533+
from pydantic_ai.models.outlines import OutlinesModel
534+
535+
return OutlinesModel(
536+
from_transformers(
537+
AutoModelForCausalLM.from_pretrained('erwanf/gpt2-mini'),
538+
AutoTokenizer.from_pretrained('erwanf/gpt2-mini'),
539+
)
540+
)
529541
else:
530542
raise ValueError(f'Unknown model: {request.param}')
531543
except ImportError:

tests/models/test_outlines.py

Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
1+
# We only test with the transformers model to limit the number of dependencies
2+
3+
import json
4+
from functools import partial
5+
from typing import TYPE_CHECKING, Any
6+
7+
import pytest
8+
from pydantic import BaseModel
9+
10+
from pydantic_ai import Agent, ModelRetry
11+
from pydantic_ai.builtin_tools import WebSearchTool
12+
from pydantic_ai.exceptions import UserError
13+
from pydantic_ai.messages import (
14+
ImageUrl,
15+
ModelMessage,
16+
ModelRequest,
17+
ModelResponse,
18+
RetryPromptPart,
19+
SystemPromptPart,
20+
TextPart,
21+
ThinkingPart,
22+
ToolCallPart,
23+
ToolReturnPart,
24+
UserPromptPart,
25+
)
26+
from pydantic_ai.output import ToolOutput
27+
from pydantic_ai.profiles import ModelProfile
28+
29+
from ..conftest import try_import
30+
31+
with try_import() as imports_successful:
32+
from outlines.models.transformers import Transformers, from_transformers
33+
from transformers import AutoModelForCausalLM, AutoTokenizer
34+
35+
from pydantic_ai.models.outlines import (
36+
OutlinesModel,
37+
)
38+
from pydantic_ai.providers.outlines import OutlinesProvider
39+
40+
if TYPE_CHECKING:
41+
from outlines.models.transformers import Transformers
42+
from transformers import AutoModelForCausalLM, AutoTokenizer
43+
44+
pytestmark = [
45+
pytest.mark.skipif(not imports_successful(), reason='outlines not installed'),
46+
pytest.mark.anyio,
47+
]
48+
49+
50+
TRANSFORMERS_MODEL_NAME = 'erwanf/gpt2-mini'
51+
52+
53+
@pytest.fixture
54+
def outlines_model() -> "Transformers":
55+
hf_model = AutoModelForCausalLM.from_pretrained(TRANSFORMERS_MODEL_NAME) # type: ignore[no-untyped-call]
56+
hf_tokenizer = AutoTokenizer.from_pretrained(TRANSFORMERS_MODEL_NAME) # type: ignore[no-untyped-call]
57+
chat_template = '{% for message in messages %}{{ message.role }}: {{ message.content }}{% endfor %}'
58+
hf_tokenizer.chat_template = chat_template
59+
return from_transformers(model=hf_model, tokenizer_or_processor=hf_tokenizer) # type: ignore[no-untyped-call]
60+
61+
62+
def test_init(outlines_model: Any) -> None:
63+
m = OutlinesModel(outlines_model, provider=OutlinesProvider())
64+
assert isinstance(m.model, Transformers)
65+
assert m.model_name == 'outlines-model'
66+
assert m.system == 'outlines'
67+
assert m.settings is None
68+
assert m.profile == ModelProfile(
69+
supports_tools=False,
70+
supports_json_schema_output=True,
71+
supports_json_object_output=True,
72+
default_structured_output_mode='native',
73+
thinking_tags=('<think>', '</think>'),
74+
ignore_streamed_leading_whitespace=False,
75+
)
76+
77+
78+
def test_from_transformers():
79+
m = OutlinesModel.from_transformers(
80+
AutoModelForCausalLM.from_pretrained(TRANSFORMERS_MODEL_NAME), # type: ignore[no-untyped-call]
81+
AutoTokenizer.from_pretrained(TRANSFORMERS_MODEL_NAME), # type: ignore[no-untyped-call]
82+
)
83+
assert isinstance(m.model, Transformers)
84+
assert m.model_name == 'outlines-model'
85+
assert m.system == 'outlines'
86+
assert m.settings is None
87+
assert m.profile == ModelProfile(
88+
supports_tools=False,
89+
supports_json_schema_output=True,
90+
supports_json_object_output=True,
91+
default_structured_output_mode='native',
92+
thinking_tags=('<think>', '</think>'),
93+
ignore_streamed_leading_whitespace=False,
94+
)
95+
96+
97+
async def test_request_async(outlines_model: "Transformers") -> None:
98+
m = OutlinesModel(outlines_model)
99+
agent = Agent(m)
100+
result = await agent.run('What is the capital of France?')
101+
assert len(result.output) > 0
102+
result = await agent.run('What is the capital of Germany?', message_history=result.all_messages())
103+
assert len(result.output) > 0
104+
all_messages = result.all_messages()
105+
assert len(all_messages) == 4
106+
107+
assert isinstance(all_messages[0], ModelRequest)
108+
assert len(all_messages[0].parts) == 1
109+
assert isinstance(all_messages[0].parts[0], UserPromptPart)
110+
assert all_messages[0].parts[0].content == 'What is the capital of France?'
111+
112+
assert isinstance(all_messages[1], ModelResponse)
113+
assert len(all_messages[1].parts) == 1
114+
assert isinstance(all_messages[1].parts[0], TextPart)
115+
assert isinstance(all_messages[1].parts[0].content, str)
116+
117+
assert isinstance(all_messages[2], ModelRequest)
118+
assert len(all_messages[2].parts) == 1
119+
assert isinstance(all_messages[2].parts[0], UserPromptPart)
120+
assert all_messages[2].parts[0].content == 'What is the capital of Germany?'
121+
122+
assert isinstance(all_messages[3], ModelResponse)
123+
assert len(all_messages[3].parts) == 1
124+
assert isinstance(all_messages[3].parts[0], TextPart)
125+
assert isinstance(all_messages[3].parts[0].content, str)
126+
127+
128+
def test_request_sync(outlines_model: "Transformers") -> None:
129+
m = OutlinesModel(outlines_model)
130+
agent = Agent(m)
131+
result = agent.run_sync('What is the capital of France?')
132+
assert len(result.output) > 0
133+
all_messages = result.all_messages()
134+
135+
assert len(all_messages) == 2
136+
137+
assert isinstance(all_messages[0], ModelRequest)
138+
assert len(all_messages[0].parts) == 1
139+
assert isinstance(all_messages[0].parts[0], UserPromptPart)
140+
assert all_messages[0].parts[0].content == 'What is the capital of France?'
141+
142+
assert isinstance(all_messages[1], ModelResponse)
143+
assert len(all_messages[1].parts) == 1
144+
assert isinstance(all_messages[1].parts[0], TextPart)
145+
assert isinstance(all_messages[1].parts[0].content, str)
146+
147+
148+
async def test_request_streaming(outlines_model: "Transformers") -> None:
149+
# The transformers model does not support streaming,
150+
# so we patch generate_stream to call generate() and yield its output in slices of 10.
151+
def patched_generate_stream(self_ref: "Transformers", *args: Any, **kwargs: Any) -> Any:
152+
response = self_ref.generate(*args, **kwargs) # type: ignore[no-untyped-call]
153+
for i in range(0, len(response), 10):
154+
chunk = response[i : i + 10]
155+
yield chunk
156+
157+
outlines_model.generate_stream = partial(patched_generate_stream, outlines_model)
158+
m = OutlinesModel(outlines_model)
159+
agent = Agent(m)
160+
async with agent.run_stream('What is the capital of the UK?') as response:
161+
async for text in response.stream_text():
162+
assert isinstance(text, str)
163+
assert len(text) > 0
164+
165+
166+
def test_tool_definition(outlines_model: "Transformers") -> None:
167+
m = OutlinesModel(outlines_model)
168+
169+
# built-in tools (e.g. WebSearchTool) are rejected
170+
agent = Agent(m, builtin_tools=[WebSearchTool()])
171+
with pytest.raises(UserError, match='Outlines does not support function tools and builtin tools yet.'):
172+
agent.run_sync('Hello')
173+
174+
# function tools registered via @agent.tool_plain are rejected
175+
agent = Agent(m)
176+
177+
@agent.tool_plain
178+
async def get_location(loc_name: str) -> str: # pragma: no cover
179+
if loc_name == 'London':
180+
return json.dumps({'lat': 51, 'lng': 0})
181+
else:
182+
raise ModelRetry('Wrong location, please try again')
183+
184+
with pytest.raises(UserError, match='Outlines does not support function tools and builtin tools yet.'):
185+
agent.run_sync('Hello')
186+
187+
# output tools
188+
class MyOutput(BaseModel):
189+
name: str
190+
191+
agent = Agent(m, output_type=ToolOutput(MyOutput, name='my_output_tool'))
192+
with pytest.raises(UserError, match='Output tools are not supported by the model.'):
193+
agent.run_sync('Hello')
194+
195+
196+
def test_output_type(outlines_model: "Transformers") -> None:
197+
class Box(BaseModel):
198+
width: int
199+
height: int
200+
depth: int
201+
units: int
202+
203+
m = OutlinesModel(outlines_model)
204+
agent = Agent(m, output_type=Box)
205+
result = agent.run_sync('Give me the dimensions of a box', model_settings={'max_new_tokens': 100}) # type: ignore[typeddict-item]
206+
assert isinstance(result.output, Box)
207+
208+
209+
def test_model_settings(outlines_model: "Transformers") -> None:
210+
# set at model level + max_new_tokens
211+
model = OutlinesModel(outlines_model, settings={'max_new_tokens': 1}) # type: ignore[typeddict-item]
212+
agent = Agent(model)
213+
result = agent.run_sync('How are you doing?')
214+
assert len(result.output) < 10
215+
216+
# set at agent level + stop_sequences
217+
model = OutlinesModel(outlines_model)
218+
agent = Agent(
219+
model,
220+
model_settings={ # type: ignore[typeddict-item]
221+
'stop_sequences': ['Paris'],
222+
'max_new_tokens': 200,
223+
'extra_body': {'tokenizer': outlines_model.hf_tokenizer},
224+
},
225+
)
226+
result = agent.run_sync('Write a story about Paris')
227+
assert result.output.endswith('Paris')
228+
229+
# set at run level + args of ModelSettings that are not supported
230+
model = OutlinesModel(outlines_model)
231+
agent = Agent(model)
232+
with pytest.warns(UserWarning, match='The transformers model does not support'):
233+
result = agent.run_sync(
234+
'Hello',
235+
model_settings={
236+
'timeout': 1,
237+
'parallel_tool_calls': True,
238+
'seed': 123,
239+
'extra_headers': {'Authorization': 'Bearer 123'},
240+
},
241+
)
242+
assert isinstance(result.output, str)
243+
244+
# presence_penalty and frequency_penalty
245+
with pytest.warns(UserWarning, match='The transformers model has a single argument `repetition_penalty`'):
246+
result = agent.run_sync('Hello', model_settings={'presence_penalty': 0.7, 'frequency_penalty': 0.3})
247+
assert isinstance(result.output, str)
248+
249+
# logit_bias
250+
with pytest.warns(UserWarning, match='The transformers model expects the keys of the `logits_bias`'):
251+
result = agent.run_sync('Hello', model_settings={'logit_bias': {'20,21': 0.5, '22': 0.3, 'a': 0.2}}) # type: ignore[typeddict-item]
252+
assert isinstance(result.output, str)
253+
254+
255+
def test_input_format(outlines_model: "Transformers") -> None:
256+
m = OutlinesModel(outlines_model)
257+
agent = Agent(m)
258+
259+
# all accepted message types
260+
message_history: list[ModelMessage] = [
261+
ModelRequest(
262+
parts=[
263+
SystemPromptPart(content='You are a helpful assistance'),
264+
UserPromptPart(content='Hello'),
265+
UserPromptPart(content=['Foo', 'Bar']),
266+
RetryPromptPart(content='Failure'),
267+
]
268+
),
269+
ModelResponse(
270+
parts=[
271+
ThinkingPart('Thinking...'), # ignored by the model
272+
TextPart('Hello there!'),
273+
]
274+
),
275+
]
276+
agent.run_sync('How are you doing?', message_history=message_history)
277+
278+
# unsupported: multi-modal user prompts
279+
message_history: list[ModelMessage] = [
280+
ModelRequest(
281+
parts=[UserPromptPart(content=['Describe the image', ImageUrl(url='https://example.com/image.png')])]
282+
)
283+
]
284+
with pytest.raises(UserError, match='Outlines does not support multi-modal user prompts yet.'):
285+
agent.run_sync('How are you doing?', message_history=message_history)
286+
287+
# unsupported: tool calls
288+
message_history: list[ModelMessage] = [
289+
ModelResponse(parts=[ToolCallPart(tool_call_id='1', tool_name='get_location')]),
290+
ModelRequest(parts=[ToolReturnPart(tool_name='get_location', content='London', tool_call_id='1')]),
291+
]
292+
with pytest.raises(UserError, match='Tool calls are not supported for Outlines models yet.'):
293+
agent.run_sync('How are you doing?', message_history=message_history)
294+
295+
# unsupported: tool returns
296+
message_history: list[ModelMessage] = [
297+
ModelRequest(parts=[ToolReturnPart(tool_name='get_location', content='London', tool_call_id='1')])
298+
]
299+
with pytest.raises(UserError, match='Tool calls are not supported for Outlines models yet.'):
300+
agent.run_sync('How are you doing?', message_history=message_history)

tests/providers/test_outlines.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import pytest
2+
3+
from pydantic_ai.profiles import ModelProfile
4+
from pydantic_ai.providers.outlines import OutlinesProvider
5+
6+
7+
def test_outlines_provider() -> None:
8+
provider = OutlinesProvider()
9+
assert provider.name == 'outlines'
10+
11+
with pytest.raises(
12+
NotImplementedError,
13+
match=(
14+
'The Outlines provider does not have a set base URL as it functions '
15+
+ 'with a set of different underlying models.'
16+
),
17+
):
18+
provider.base_url
19+
20+
with pytest.raises(
21+
NotImplementedError,
22+
match=(
23+
'The Outlines provider does not have a set client as it functions '
24+
+ 'with a set of different underlying models.'
25+
),
26+
):
27+
provider.client
28+
29+
assert provider.model_profile('outlines-model') == ModelProfile(
30+
supports_tools=False,
31+
supports_json_schema_output=True,
32+
supports_json_object_output=True,
33+
default_structured_output_mode='native',
34+
thinking_tags=('<think>', '</think>'),
35+
ignore_streamed_leading_whitespace=False,
36+
)

tests/providers/test_provider_names.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from pydantic_ai.providers.ollama import OllamaProvider
3131
from pydantic_ai.providers.openai import OpenAIProvider
3232
from pydantic_ai.providers.openrouter import OpenRouterProvider
33+
from pydantic_ai.providers.outlines import OutlinesProvider
3334
from pydantic_ai.providers.together import TogetherProvider
3435
from pydantic_ai.providers.vercel import VercelProvider
3536

@@ -52,6 +53,7 @@
5253
('heroku', HerokuProvider, 'HEROKU_INFERENCE_KEY'),
5354
('github', GitHubProvider, 'GITHUB_API_KEY'),
5455
('ollama', OllamaProvider, 'OLLAMA_BASE_URL'),
56+
('outlines', OutlinesProvider, None),
5557
]
5658

5759
if not imports_successful():

0 commit comments

Comments
 (0)