
Commit f118c9f: "Draft"

1 parent fde5eea

File tree: 8 files changed, +535 / -13 lines


outlines_example.py

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
from pydantic_ai import Agent, NativeOutput
from pydantic_ai.models.outlines import OutlinesModel
from pydantic_ai.settings import ModelSettings
from pydantic import BaseModel


class Box(BaseModel):
    width: int
    height: int
    depth: int
    units: str


def transformers_example():
    print("---- start transformers_example ----")

    from transformers import AutoModelForCausalLM, AutoTokenizer

    hf_model = AutoModelForCausalLM.from_pretrained("erwanf/gpt2-mini")
    hf_tokenizer = AutoTokenizer.from_pretrained("erwanf/gpt2-mini")
    # gpt2-mini ships without a chat template, so set a minimal one for chat-formatted prompts.
    chat_template = '{% for message in messages %}{{ message.role }}: {{ message.content }}{% endfor %}'
    hf_tokenizer.chat_template = chat_template

    model = OutlinesModel.transformers(hf_model, hf_tokenizer, settings=ModelSettings(max_new_tokens=100))
    agent = Agent(model, output_type=NativeOutput([Box]))

    response = agent.run_sync('Give me the dimensions of a box')
    print("response.output: ", response.output)

    # Pass the previous messages back in to continue the conversation.
    response2 = agent.run_sync('Give me another box', message_history=response.all_messages())
    print("response2.output: ", response2.output)

    print("all_messages: ", response2.all_messages())

    print("---- end transformers_example ----")


def llama_cpp_example():
    print("---- start llama_cpp_example ----")

    from llama_cpp import Llama

    llama_model = Llama.from_pretrained(
        repo_id="TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
        filename="tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf",
        n_ctx=2048,  # 2K context window
    )

    model = OutlinesModel.llama_cpp(llama_model)
    agent = Agent(model, output_type=NativeOutput([Box]))

    response = agent.run_sync('Give me the dimensions of a box')
    print("response.output: ", response.output)

    response2 = agent.run_sync('Give me another box', message_history=response.all_messages())
    print("response2.output: ", response2.output)

    print("all_messages: ", response2.all_messages())

    print("---- end llama_cpp_example ----")


if __name__ == "__main__":
    # transformers_example()
    llama_cpp_example()
    # existing()
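
The two examples above cover only the local in-process backends. A minimal sketch of one of the server-backed constructors follows; it is an assumption-laden illustration, not part of this commit: it presumes a vLLM server exposing an OpenAI-compatible endpoint (which is what the `client` parameter of `OutlinesModel.vllm` appears to expect), and the base URL and model name are placeholders. It reuses `Agent`, `NativeOutput`, `OutlinesModel`, and `Box` from the example above.

def vllm_example():
    # Hypothetical sketch (not in this commit): exercising OutlinesModel.vllm.
    # Assumes a vLLM server is running locally with an OpenAI-compatible API;
    # the base URL and model name are placeholders.
    import openai

    client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="unused")
    model = OutlinesModel.vllm(client, "Qwen/Qwen2.5-0.5B-Instruct")
    agent = Agent(model, output_type=NativeOutput([Box]))

    response = agent.run_sync('Give me the dimensions of a box')
    print("response.output: ", response.output)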
pydantic_ai_slim/pydantic_ai/models/outlines.py

Lines changed: 266 additions & 0 deletions
@@ -0,0 +1,266 @@
from collections.abc import AsyncIterable, AsyncIterator
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Literal

from .. import UnexpectedModelBehavior, _utils
from .._run_context import RunContext
from ..messages import (
    ModelMessage,
    ModelResponse,
    ModelResponseStreamEvent,
    TextPart,
)
from ..profiles import ModelProfileSpec
from ..providers import Provider, infer_provider
from ..settings import ModelSettings
from . import (
    Model,
    ModelRequestParameters,
    StreamedResponse,
)

try:
    from outlines.inputs import Chat
    from outlines.models.base import AsyncModel as OutlinesAsyncBaseModel, Model as OutlinesBaseModel
    from outlines.models.llamacpp import from_llamacpp  # pyright: ignore[reportUnknownVariableType]
    from outlines.models.mlxlm import from_mlxlm  # pyright: ignore[reportUnknownVariableType]
    from outlines.models.sglang import from_sglang
    from outlines.models.tgi import from_tgi
    from outlines.models.transformers import from_transformers  # pyright: ignore[reportUnknownVariableType]
    from outlines.models.vllm import from_vllm
    from outlines.types.dsl import JsonSchema
except ImportError as _import_error:
    raise ImportError(
        'Please install `outlines` to use the Outlines model; '
        'you can use the `outlines` optional group: `pip install "pydantic-ai-slim[outlines]"`'
    ) from _import_error


@dataclass
class OutlinesStreamedResponse(StreamedResponse):
    """Implementation of `StreamedResponse` for Outlines models."""

    _model_name: str
    _response: AsyncIterable[str]
    _timestamp: datetime

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        # Outlines streams plain text chunks; route them all through a single text part.
        async for chunk in self._response:
            event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=chunk)
            if event is not None:  # pragma: no branch
                yield event

    @property
    def model_name(self) -> str:
        """Get the model name of the response."""
        return self._model_name

    @property
    def timestamp(self) -> datetime:
        """Get the timestamp of the response."""
        return self._timestamp


@dataclass(init=False)
class OutlinesModel(Model):
    """A model that relies on the Outlines library to run local, non-API-based models."""

    _system: str = field(default='outlines', repr=False)

    def __init__(
        self,
        model: OutlinesBaseModel | OutlinesAsyncBaseModel,
        model_name: str | None = None,
        *,
        provider: Literal['outlines'] | Provider[OutlinesBaseModel] = 'outlines',
        profile: ModelProfileSpec | None = None,
        settings: ModelSettings | None = None,
    ):
        """Initialize an Outlines model.

        Args:
            model: The underlying Outlines model to run.
            model_name: The name of the model run by the provider.
            provider: The provider to use for OutlinesModel. Can be either the string 'outlines' or an
                instance of `Provider[OutlinesBaseModel]`.
            profile: The model profile to use. Defaults to a profile picked by the provider.
            settings: Default model settings for this model instance.
        """
        self.model = model
        self._model_name = model_name

        if isinstance(provider, str):
            provider = infer_provider(provider)

        super().__init__(settings=settings, profile=profile or provider.model_profile)

    @classmethod
    def transformers(
        cls,
        hf_model: Any,
        hf_tokenizer: Any,
        *,
        provider: Literal['outlines'] | Provider[OutlinesBaseModel] = 'outlines',
        profile: ModelProfileSpec | None = None,
        settings: ModelSettings | None = None,
    ):
        """Create an Outlines model from a Hugging Face `transformers` model and tokenizer."""
        outlines_model: OutlinesBaseModel = from_transformers(hf_model, hf_tokenizer)
        return cls(outlines_model, None, provider=provider, profile=profile, settings=settings)

    @classmethod
    def llama_cpp(
        cls,
        llama_model: Any,
        *,
        provider: Literal['outlines'] | Provider[OutlinesBaseModel] = 'outlines',
        profile: ModelProfileSpec | None = None,
        settings: ModelSettings | None = None,
    ):
        """Create an Outlines model from a `llama_cpp.Llama` instance."""
        outlines_model: OutlinesBaseModel = from_llamacpp(llama_model)
        return cls(outlines_model, None, provider=provider, profile=profile, settings=settings)

    @classmethod
    def mlxlm(
        cls,
        mlx_model: Any,
        mlx_tokenizer: Any,
        *,
        provider: Literal['outlines'] | Provider[OutlinesBaseModel] = 'outlines',
        profile: ModelProfileSpec | None = None,
        settings: ModelSettings | None = None,
    ):
        """Create an Outlines model from an MLX-LM model and tokenizer."""
        outlines_model: OutlinesBaseModel = from_mlxlm(mlx_model, mlx_tokenizer)
        return cls(outlines_model, None, provider=provider, profile=profile, settings=settings)

    @classmethod
    def tgi(
        cls,
        client: Any,
        *,
        provider: Literal['outlines'] | Provider[OutlinesBaseModel] = 'outlines',
        profile: ModelProfileSpec | None = None,
        settings: ModelSettings | None = None,
    ):
        """Create an Outlines model from a Text Generation Inference client."""
        outlines_model: OutlinesBaseModel | OutlinesAsyncBaseModel = from_tgi(client)
        return cls(outlines_model, None, provider=provider, profile=profile, settings=settings)

    @classmethod
    def sglang(
        cls,
        client: Any,
        model_name: str,
        *,
        provider: Literal['outlines'] | Provider[OutlinesBaseModel] = 'outlines',
        profile: ModelProfileSpec | None = None,
        settings: ModelSettings | None = None,
    ):
        """Create an Outlines model from an SGLang client and model name."""
        outlines_model: OutlinesBaseModel | OutlinesAsyncBaseModel = from_sglang(client, model_name)
        return cls(outlines_model, None, provider=provider, profile=profile, settings=settings)

    @classmethod
    def vllm(
        cls,
        client: Any,
        model_name: str,
        *,
        provider: Literal['outlines'] | Provider[OutlinesBaseModel] = 'outlines',
        profile: ModelProfileSpec | None = None,
        settings: ModelSettings | None = None,
    ):
        """Create an Outlines model from a vLLM server client and model name."""
        outlines_model: OutlinesBaseModel | OutlinesAsyncBaseModel = from_vllm(client, model_name)
        return cls(outlines_model, None, provider=provider, profile=profile, settings=settings)

    @property
    def model_name(self) -> str:
        return self._model_name or ''

    @property
    def system(self) -> str:
        return self._system

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> ModelResponse:
        """Make a request to the model."""
        prompt = self._format_prompt(messages)
        # When native structured output is requested, constrain generation to the JSON schema.
        output_type = (
            JsonSchema(model_request_parameters.output_object.json_schema)
            if model_request_parameters.output_object
            else None
        )
        model_settings_dict = dict(model_settings) if model_settings else {}
        if isinstance(self.model, OutlinesAsyncBaseModel):
            response: str = await self.model(prompt, output_type, None, **model_settings_dict)
        else:
            response = self.model(prompt, output_type, None, **model_settings_dict)
        return self._process_response(response)

    @asynccontextmanager
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
        run_context: RunContext[Any] | None = None,
    ) -> AsyncIterator[StreamedResponse]:
        prompt = self._format_prompt(messages)
        output_type = (
            JsonSchema(model_request_parameters.output_object.json_schema)
            if model_request_parameters.output_object
            else None
        )
        model_settings_dict = dict(model_settings) if model_settings else {}
        if isinstance(self.model, OutlinesAsyncBaseModel):
            response = self.model.stream(prompt, output_type, None, **model_settings_dict)
        else:
            sync_stream = self.model.stream(prompt, output_type, None, **model_settings_dict)

            async def async_response() -> AsyncIterator[str]:
                # Wrap the synchronous Outlines stream so both branches yield an
                # async iterable of text chunks.
                for chunk in sync_stream:
                    yield chunk

            response = async_response()

        yield await self._process_streamed_response(response, model_request_parameters)

    def _format_prompt(self, messages: list[ModelMessage]) -> Chat:
        """Turn the model messages into an Outlines Chat instance."""
        # Only system/user/text parts are mapped; tool-related parts are not supported here.
        chat = Chat()
        for message in messages:
            if message.kind == 'request':
                for part in message.parts:
                    if part.part_kind == 'system-prompt':
                        chat.add_system_message(part.content)
                    elif part.part_kind == 'user-prompt':
                        chat.add_user_message(str(part.content))
            elif message.kind == 'response':
                for part in message.parts:
                    if part.part_kind == 'text':
                        chat.add_assistant_message(str(part.content))
        return chat

    def _process_response(self, response: str) -> ModelResponse:
        """Turn the Outlines text response into a Pydantic AI model response instance."""
        return ModelResponse(parts=[TextPart(content=response)])

    async def _process_streamed_response(
        self, response: AsyncIterable[str], model_request_parameters: ModelRequestParameters
    ) -> StreamedResponse:
        """Turn the Outlines text response into a Pydantic AI streamed response instance."""
        peekable_response = _utils.PeekableAsyncStream(response)
        first_chunk = await peekable_response.peek()
        if isinstance(first_chunk, _utils.Unset):
            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')  # pragma: no cover

        timestamp = datetime.now(tz=timezone.utc)
        return OutlinesStreamedResponse(
            model_request_parameters=model_request_parameters,
            _model_name=self.model_name,
            _response=peekable_response,
            _timestamp=timestamp,
        )
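
Because `request_stream` normalizes both synchronous and asynchronous Outlines streams into a single `AsyncIterable[str]`, streaming should work through the usual agent API. A minimal sketch, assuming an `agent` built as in outlines_example.py above:

import asyncio

async def stream_example():
    # Hypothetical sketch: text deltas arrive through OutlinesStreamedResponse.
    async with agent.run_stream('Give me the dimensions of a box') as result:
        async for text in result.stream_text():
            print(text)

asyncio.run(stream_example())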
pydantic_ai_slim/pydantic_ai/profiles/outlines.py

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
from . import ModelProfile


def outlines_model_profile(model_name: str | None = None) -> ModelProfile:
    """Get the model profile for an Outlines model."""
    return ModelProfile(
        supports_tools=False,
        supports_json_schema_output=True,
        supports_json_object_output=True,
        default_structured_output_mode='native',
    )
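
Since the profile sets `supports_tools=False` and `default_structured_output_mode='native'`, structured output from Outlines models flows through schema-constrained generation (the `JsonSchema` output type in `request` above) rather than tool calls, which is why the example file wraps its output type accordingly:

# Structured output must use the native JSON-schema path; tool calling is unsupported.
agent = Agent(model, output_type=NativeOutput([Box]))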

pydantic_ai_slim/pydantic_ai/providers/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -131,6 +131,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]: # noqa: C901
         from .github import GitHubProvider

         return GitHubProvider
+    elif provider == 'outlines':  # pragma: no cover
+        from .outlines import OutlinesProvider
+
+        return OutlinesProvider
     else:  # pragma: no cover
         raise ValueError(f'Unknown provider: {provider}')

pydantic_ai_slim/pydantic_ai/providers/outlines.py

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
from __future__ import annotations as _annotations

from typing import Any

from pydantic_ai.profiles import ModelProfile
from pydantic_ai.profiles.outlines import outlines_model_profile
from pydantic_ai.providers import Provider


class OutlinesProvider(Provider[Any]):
    """Provider for models run locally through the Outlines library."""

    @property
    def name(self) -> str:
        """The provider name."""
        return 'outlines'

    @property
    def base_url(self) -> str:
        """The base URL for the provider API."""
        # Outlines runs models in-process or against user-supplied clients,
        # so there is no provider-level API base URL.
        raise NotImplementedError()

    @property
    def client(self) -> Any:
        """The client for the provider."""
        # There is likewise no provider-level API client.
        raise NotImplementedError()

    def model_profile(self, model_name: str) -> ModelProfile | None:
        """The model profile for the named model, if available."""
        return outlines_model_profile(model_name)
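
Combined with the registry change in providers/__init__.py above, the provider can now be resolved from its string name. A small sketch (the assertion is illustrative):

from pydantic_ai.providers import infer_provider

provider = infer_provider('outlines')
assert provider.name == 'outlines'
# The profile drives structured-output behavior; see profiles/outlines.py above.
profile = provider.model_profile('some-model')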

pydantic_ai_slim/pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -95,6 +95,7 @@ ag-ui = ["ag-ui-protocol>=0.1.8", "starlette>=0.45.3"]
 retries = ["tenacity>=8.2.3"]
 # Temporal
 temporal = ["temporalio==1.15.0"]
+outlines = ["outlines>=0.0.1"]

 [tool.hatch.metadata]
 allow-direct-references = true
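
With this extra in place, the install hint baked into the ImportError in models/outlines.py works as written: `pip install "pydantic-ai-slim[outlines]"`. Note that the `>=0.0.1` constraint accepts effectively any released version of `outlines`.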
