# We only test with the transformers model to limit the number of dependencies

import json
from functools import partial
from typing import TYPE_CHECKING, Any

import pytest
from pydantic import BaseModel

from pydantic_ai import Agent, ModelRetry
from pydantic_ai.builtin_tools import WebSearchTool
from pydantic_ai.exceptions import UserError
from pydantic_ai.messages import (
    ImageUrl,
    ModelMessage,
    ModelRequest,
    ModelResponse,
    RetryPromptPart,
    SystemPromptPart,
    TextPart,
    ThinkingPart,
    ToolCallPart,
    ToolReturnPart,
    UserPromptPart,
)
from pydantic_ai.output import ToolOutput
from pydantic_ai.profiles import ModelProfile

from ..conftest import try_import

with try_import() as imports_successful:
    from outlines.models.transformers import Transformers, from_transformers
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from pydantic_ai.models.outlines import (
        OutlinesModel,
    )
    from pydantic_ai.providers.outlines import OutlinesProvider

if TYPE_CHECKING:
    from outlines.models.transformers import Transformers
    from transformers import AutoModelForCausalLM, AutoTokenizer

pytestmark = [
    pytest.mark.skipif(not imports_successful(), reason='outlines not installed'),
    pytest.mark.anyio,
]


TRANSFORMERS_MODEL_NAME = 'erwanf/gpt2-mini'


@pytest.fixture
def outlines_model() -> 'Transformers':
    hf_model = AutoModelForCausalLM.from_pretrained(TRANSFORMERS_MODEL_NAME)  # type: ignore[no-untyped-call]
    hf_tokenizer = AutoTokenizer.from_pretrained(TRANSFORMERS_MODEL_NAME)  # type: ignore[no-untyped-call]
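    # GPT-2 checkpoints ship without a chat template, so set a minimal one for chat-formatted prompts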
    chat_template = '{% for message in messages %}{{ message.role }}: {{ message.content }}{% endfor %}'
    hf_tokenizer.chat_template = chat_template
    return from_transformers(model=hf_model, tokenizer_or_processor=hf_tokenizer)  # type: ignore[no-untyped-call]


def test_init(outlines_model: 'Transformers') -> None:
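    # A wrapped Outlines model reports generic metadata and a profile with tool support disabled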
    m = OutlinesModel(outlines_model, provider=OutlinesProvider())
    assert isinstance(m.model, Transformers)
    assert m.model_name == 'outlines-model'
    assert m.system == 'outlines'
    assert m.settings is None
    assert m.profile == ModelProfile(
        supports_tools=False,
        supports_json_schema_output=True,
        supports_json_object_output=True,
        default_structured_output_mode='native',
        thinking_tags=('<think>', '</think>'),
        ignore_streamed_leading_whitespace=False,
    )


def test_from_transformers() -> None:
    m = OutlinesModel.from_transformers(
        AutoModelForCausalLM.from_pretrained(TRANSFORMERS_MODEL_NAME),  # type: ignore[no-untyped-call]
        AutoTokenizer.from_pretrained(TRANSFORMERS_MODEL_NAME),  # type: ignore[no-untyped-call]
    )
    assert isinstance(m.model, Transformers)
    assert m.model_name == 'outlines-model'
    assert m.system == 'outlines'
    assert m.settings is None
    assert m.profile == ModelProfile(
        supports_tools=False,
        supports_json_schema_output=True,
        supports_json_object_output=True,
        default_structured_output_mode='native',
        thinking_tags=('<think>', '</think>'),
        ignore_streamed_leading_whitespace=False,
    )


async def test_request_async(outlines_model: 'Transformers') -> None:
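    # Two sequential runs: the second passes the first run's messages back as history,
    # so the final history holds four messages (two requests, two responses)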
    m = OutlinesModel(outlines_model)
    agent = Agent(m)
    result = await agent.run('What is the capital of France?')
    assert len(result.output) > 0
    result = await agent.run('What is the capital of Germany?', message_history=result.all_messages())
    assert len(result.output) > 0
    all_messages = result.all_messages()
    assert len(all_messages) == 4

    assert isinstance(all_messages[0], ModelRequest)
    assert len(all_messages[0].parts) == 1
    assert isinstance(all_messages[0].parts[0], UserPromptPart)
    assert all_messages[0].parts[0].content == 'What is the capital of France?'

    assert isinstance(all_messages[1], ModelResponse)
    assert len(all_messages[1].parts) == 1
    assert isinstance(all_messages[1].parts[0], TextPart)
    assert isinstance(all_messages[1].parts[0].content, str)

    assert isinstance(all_messages[2], ModelRequest)
    assert len(all_messages[2].parts) == 1
    assert isinstance(all_messages[2].parts[0], UserPromptPart)
    assert all_messages[2].parts[0].content == 'What is the capital of Germany?'

    assert isinstance(all_messages[3], ModelResponse)
    assert len(all_messages[3].parts) == 1
    assert isinstance(all_messages[3].parts[0], TextPart)
    assert isinstance(all_messages[3].parts[0].content, str)


def test_request_sync(outlines_model: 'Transformers') -> None:
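    # A single synchronous run should leave one request/response pair in the history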
    m = OutlinesModel(outlines_model)
    agent = Agent(m)
    result = agent.run_sync('What is the capital of France?')
    assert len(result.output) > 0
    all_messages = result.all_messages()

    assert len(all_messages) == 2

    assert isinstance(all_messages[0], ModelRequest)
    assert len(all_messages[0].parts) == 1
    assert isinstance(all_messages[0].parts[0], UserPromptPart)
    assert all_messages[0].parts[0].content == 'What is the capital of France?'

    assert isinstance(all_messages[1], ModelResponse)
    assert len(all_messages[1].parts) == 1
    assert isinstance(all_messages[1].parts[0], TextPart)
    assert isinstance(all_messages[1].parts[0].content, str)


async def test_request_streaming(outlines_model: 'Transformers') -> None:
    # The transformers model does not support streaming,
    # so we need to mock the generate_stream method.
    def patched_generate_stream(self_ref: 'Transformers', *args: Any, **kwargs: Any) -> Any:
        response = self_ref.generate(*args, **kwargs)  # type: ignore[no-untyped-call]
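        # Split the full response into 10-character chunks to emulate streamed output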
        for i in range(0, len(response), 10):
            chunk = response[i : i + 10]
            yield chunk

    outlines_model.generate_stream = partial(patched_generate_stream, outlines_model)
    m = OutlinesModel(outlines_model)
    agent = Agent(m)
    async with agent.run_stream('What is the capital of the UK?') as response:
        async for text in response.stream_text():
            assert isinstance(text, str)
            assert len(text) > 0


def test_tool_definition(outlines_model: 'Transformers') -> None:
    m = OutlinesModel(outlines_model)

    # builtin tools
    agent = Agent(m, builtin_tools=[WebSearchTool()])
    with pytest.raises(UserError, match='Outlines does not support function tools and builtin tools yet.'):
        agent.run_sync('Hello')

    # function tools
    agent = Agent(m)

    @agent.tool_plain
    async def get_location(loc_name: str) -> str:  # pragma: no cover
        if loc_name == 'London':
            return json.dumps({'lat': 51, 'lng': 0})
        else:
            raise ModelRetry('Wrong location, please try again')

    with pytest.raises(UserError, match='Outlines does not support function tools and builtin tools yet.'):
        agent.run_sync('Hello')

    # output tools
    class MyOutput(BaseModel):
        name: str

    agent = Agent(m, output_type=ToolOutput(MyOutput, name='my_output_tool'))
    with pytest.raises(UserError, match='Output tools are not supported by the model.'):
        agent.run_sync('Hello')


def test_output_type(outlines_model: 'Transformers') -> None:
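    # With output_type=Box, the run should return valid JSON parsed into a Box instance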
    class Box(BaseModel):
        width: int
        height: int
        depth: int
        units: int

    m = OutlinesModel(outlines_model)
    agent = Agent(m, output_type=Box)
    result = agent.run_sync('Give me the dimensions of a box', model_settings={'max_new_tokens': 100})  # type: ignore[typeddict-item]
    assert isinstance(result.output, Box)


def test_model_settings(outlines_model: 'Transformers') -> None:
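    # Exercise settings at the model, agent, and run levels; unsupported settings should warn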
    # set at model level + max_new_tokens
    model = OutlinesModel(outlines_model, settings={'max_new_tokens': 1})  # type: ignore[typeddict-item]
    agent = Agent(model)
    result = agent.run_sync('How are you doing?')
    assert len(result.output) < 10

    # set at agent level + stop_sequences
    model = OutlinesModel(outlines_model)
    agent = Agent(
        model,
        model_settings={  # type: ignore[typeddict-item]
            'stop_sequences': ['Paris'],
            'max_new_tokens': 200,
            'extra_body': {'tokenizer': outlines_model.hf_tokenizer},
        },
    )
    result = agent.run_sync('Write a story about Paris')
    assert result.output.endswith('Paris')

    # set at run level + args of ModelSettings that are not supported
    model = OutlinesModel(outlines_model)
    agent = Agent(model)
    with pytest.warns(UserWarning, match='The transformers model does not support'):
        result = agent.run_sync(
            'Hello',
            model_settings={
                'timeout': 1,
                'parallel_tool_calls': True,
                'seed': 123,
                'extra_headers': {'Authorization': 'Bearer 123'},
            },
        )
    assert isinstance(result.output, str)

    # presence_penalty and frequency_penalty
    with pytest.warns(UserWarning, match='The transformers model has a single argument `repetition_penalty`'):
        result = agent.run_sync('Hello', model_settings={'presence_penalty': 0.7, 'frequency_penalty': 0.3})
    assert isinstance(result.output, str)

    # logit_bias
    with pytest.warns(UserWarning, match='The transformers model expects the keys of the `logits_bias`'):
        result = agent.run_sync('Hello', model_settings={'logit_bias': {'20,21': 0.5, '22': 0.3, 'a': 0.2}})  # type: ignore[typeddict-item]
    assert isinstance(result.output, str)


def test_input_format(outlines_model: 'Transformers') -> None:
    m = OutlinesModel(outlines_model)
    agent = Agent(m)

    # all accepted message types
    message_history: list[ModelMessage] = [
        ModelRequest(
            parts=[
                SystemPromptPart(content='You are a helpful assistant'),
                UserPromptPart(content='Hello'),
                UserPromptPart(content=['Foo', 'Bar']),
                RetryPromptPart(content='Failure'),
            ]
        ),
        ModelResponse(
            parts=[
                ThinkingPart('Thinking...'),  # ignored by the model
                TextPart('Hello there!'),
            ]
        ),
    ]
    agent.run_sync('How are you doing?', message_history=message_history)

    # unsupported: multi-modal user prompts
    message_history = [
        ModelRequest(
            parts=[UserPromptPart(content=['Describe the image', ImageUrl(url='https://example.com/image.png')])]
        )
    ]
    with pytest.raises(UserError, match='Outlines does not support multi-modal user prompts yet.'):
        agent.run_sync('How are you doing?', message_history=message_history)

    # unsupported: tool calls
    message_history = [
        ModelResponse(parts=[ToolCallPart(tool_call_id='1', tool_name='get_location')]),
        ModelRequest(parts=[ToolReturnPart(tool_name='get_location', content='London', tool_call_id='1')]),
    ]
    with pytest.raises(UserError, match='Tool calls are not supported for Outlines models yet.'):
        agent.run_sync('How are you doing?', message_history=message_history)

    # unsupported: tool returns
    message_history = [
        ModelRequest(parts=[ToolReturnPart(tool_name='get_location', content='London', tool_call_id='1')])
    ]
    with pytest.raises(UserError, match='Tool calls are not supported for Outlines models yet.'):
        agent.run_sync('How are you doing?', message_history=message_history)