Commit c33fe4b

Adding subclasses of ModelSettings to support specialized model requests (#766)
1 parent 8a93fed commit c33fe4b

File tree

10 files changed: +203 -40 lines changed

docs/agents.md

Lines changed: 22 additions & 0 deletions

````diff
@@ -204,6 +204,28 @@ print(result_sync.data)
 #> Rome
 ```
 
+### Model specific settings
+
+<!-- TODO: replace this with the gemini safety settings example once added via https://github.com/pydantic/pydantic-ai/issues/373 -->
+
+If you wish to further customize model behavior, you can use a subclass of [`ModelSettings`][pydantic_ai.settings.ModelSettings], like [`AnthropicModelSettings`][pydantic_ai.models.anthropic.AnthropicModelSettings], associated with your model of choice.
+
+For example:
+
+```py
+from pydantic_ai import Agent
+from pydantic_ai.models.anthropic import AnthropicModelSettings
+
+agent = Agent('anthropic:claude-3-5-sonnet-latest')
+
+result_sync = agent.run_sync(
+    'What is the capital of Italy?',
+    model_settings=AnthropicModelSettings(anthropic_metadata={'user_id': 'my_user_id'}),
+)
+print(result_sync.data)
+#> Rome
+```
+
 ## Runs vs. Conversations
 
 An agent **run** might represent an entire conversation — there's no limit to how many messages can be exchanged in a single run. However, a **conversation** might also be composed of multiple runs, especially if you need to maintain state between separate interactions or API calls.
````
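Since `ModelSettings` and its subclasses appear to be `total=False` `TypedDict`s (the diffs below read them with `model_settings.get(...)` and construct them with keyword arguments), base keys and subclass-specific keys should be combinable in one settings object. A minimal sketch, assuming `temperature` remains a valid base-`ModelSettings` key:

```py
from pydantic_ai import Agent
from pydantic_ai.models.anthropic import AnthropicModelSettings

agent = Agent('anthropic:claude-3-5-sonnet-latest')

result_sync = agent.run_sync(
    'What is the capital of Italy?',
    # base keys (temperature, max_tokens, ...) and the subclass-specific
    # anthropic_metadata key presumably mix in a single settings object
    model_settings=AnthropicModelSettings(
        temperature=0.0,
        anthropic_metadata={'user_id': 'my_user_id'},
    ),
)
print(result_sync.data)
```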

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 16 additions & 7 deletions

```diff
@@ -41,6 +41,7 @@
 from anthropic.types import (
     Message as AnthropicMessage,
     MessageParam,
+    MetadataParam,
     RawContentBlockDeltaEvent,
     RawContentBlockStartEvent,
     RawContentBlockStopEvent,
@@ -79,6 +80,15 @@
 """
 
 
+class AnthropicModelSettings(ModelSettings):
+    """Settings used for an Anthropic model request."""
+
+    anthropic_metadata: MetadataParam
+    """An object describing metadata about the request.
+
+    Contains `user_id`, an external identifier for the user who is associated with the request."""
+
+
 @dataclass(init=False)
 class AnthropicModel(Model):
     """A model that uses the Anthropic API.
@@ -167,35 +177,33 @@ class AnthropicAgentModel(AgentModel):
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> tuple[ModelResponse, usage.Usage]:
-        response = await self._messages_create(messages, False, model_settings)
+        response = await self._messages_create(messages, False, cast(AnthropicModelSettings, model_settings or {}))
         return self._process_response(response), _map_usage(response)
 
     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> AsyncIterator[StreamedResponse]:
-        response = await self._messages_create(messages, True, model_settings)
+        response = await self._messages_create(messages, True, cast(AnthropicModelSettings, model_settings or {}))
         async with response:
             yield await self._process_streamed_response(response)
 
     @overload
     async def _messages_create(
-        self, messages: list[ModelMessage], stream: Literal[True], model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], stream: Literal[True], model_settings: AnthropicModelSettings
     ) -> AsyncStream[RawMessageStreamEvent]:
         pass
 
     @overload
     async def _messages_create(
-        self, messages: list[ModelMessage], stream: Literal[False], model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], stream: Literal[False], model_settings: AnthropicModelSettings
    ) -> AnthropicMessage:
         pass
 
     async def _messages_create(
-        self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], stream: bool, model_settings: AnthropicModelSettings
     ) -> AnthropicMessage | AsyncStream[RawMessageStreamEvent]:
         # standalone function to make it easier to override
-        model_settings = model_settings or {}
-
         tool_choice: ToolChoiceParam | None
 
         if not self.tools:
@@ -222,6 +230,7 @@ async def _messages_create(
             temperature=model_settings.get('temperature', NOT_GIVEN),
             top_p=model_settings.get('top_p', NOT_GIVEN),
             timeout=model_settings.get('timeout', NOT_GIVEN),
+            metadata=model_settings.get('anthropic_metadata', NOT_GIVEN),
         )
 
     def _process_response(self, response: AnthropicMessage) -> ModelResponse:
```
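The `cast(AnthropicModelSettings, model_settings or {})` pattern recurs in every provider below: `cast` is a typing-only no-op, and `or {}` replaces the `model_settings = model_settings or {}` lines this commit deletes. A self-contained sketch of why this is safe for a `total=False` `TypedDict` (the class definitions here are stand-ins, not the real pydantic-ai ones):

```py
from typing import TypedDict, cast

class ModelSettings(TypedDict, total=False):  # stand-in for pydantic_ai.settings.ModelSettings
    temperature: float

class AnthropicModelSettings(ModelSettings, total=False):  # stand-in
    anthropic_metadata: dict[str, str]

maybe_settings: ModelSettings | None = None

# cast() does nothing at runtime; `or {}` supplies the empty-dict default
# that the deleted `model_settings = model_settings or {}` lines provided
settings = cast(AnthropicModelSettings, maybe_settings or {})
assert settings.get('anthropic_metadata', 'NOT_GIVEN') == 'NOT_GIVEN'
```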

pydantic_ai_slim/pydantic_ai/models/cohere.py

Lines changed: 12 additions & 4 deletions

```diff
@@ -3,7 +3,7 @@
 from collections.abc import Iterable
 from dataclasses import dataclass, field
 from itertools import chain
-from typing import Literal, TypeAlias, Union
+from typing import Literal, TypeAlias, Union, cast
 
 from cohere import TextAssistantMessageContentItem
 from typing_extensions import assert_never
@@ -71,6 +71,12 @@
 ]
 
 
+class CohereModelSettings(ModelSettings):
+    """Settings used for a Cohere model request."""
+
+    # This class is a placeholder for any future cohere-specific settings
+
+
 @dataclass(init=False)
 class CohereModel(Model):
     """A model that uses the Cohere API.
@@ -153,23 +159,25 @@ class CohereAgentModel(AgentModel):
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> tuple[ModelResponse, result.Usage]:
-        response = await self._chat(messages, model_settings)
+        response = await self._chat(messages, cast(CohereModelSettings, model_settings or {}))
         return self._process_response(response), _map_usage(response)
 
     async def _chat(
         self,
         messages: list[ModelMessage],
-        model_settings: ModelSettings | None,
+        model_settings: CohereModelSettings,
     ) -> ChatResponse:
         cohere_messages = list(chain(*(self._map_message(m) for m in messages)))
-        model_settings = model_settings or {}
         return await self.client.chat(
             model=self.model_name,
             messages=cohere_messages,
             tools=self.tools or OMIT,
             max_tokens=model_settings.get('max_tokens', OMIT),
             temperature=model_settings.get('temperature', OMIT),
             p=model_settings.get('top_p', OMIT),
+            seed=model_settings.get('seed', OMIT),
+            presence_penalty=model_settings.get('presence_penalty', OMIT),
+            frequency_penalty=model_settings.get('frequency_penalty', OMIT),
         )
 
     def _process_response(self, response: ChatResponse) -> ModelResponse:
```
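A hedged usage sketch for the newly forwarded options. `seed`, `presence_penalty`, and `frequency_penalty` are read from the base settings dict, so their definitions presumably live in the `pydantic_ai/settings.py` part of this commit (one of the ten changed files not shown here); the model name is illustrative:

```py
from pydantic_ai import Agent
from pydantic_ai.models.cohere import CohereModelSettings

agent = Agent('cohere:command-r-plus')  # illustrative model name

result = agent.run_sync(
    'What is the capital of Italy?',
    # each key is optional; _chat() substitutes OMIT for anything unset
    model_settings=CohereModelSettings(
        seed=42,
        presence_penalty=0.5,
        frequency_penalty=0.2,
    ),
)
print(result.data)
```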

pydantic_ai_slim/pydantic_ai/models/gemini.py

Lines changed: 19 additions & 5 deletions

```diff
@@ -7,7 +7,7 @@
 from copy import deepcopy
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Annotated, Any, Literal, Protocol, Union
+from typing import Annotated, Any, Literal, Protocol, Union, cast
 from uuid import uuid4
 
 import pydantic
@@ -48,6 +48,12 @@
 """
 
 
+class GeminiModelSettings(ModelSettings):
+    """Settings used for a Gemini model request."""
+
+    # This class is a placeholder for any future gemini-specific settings
+
+
 @dataclass(init=False)
 class GeminiModel(Model):
     """A model that uses Gemini via `generativelanguage.googleapis.com` API.
@@ -171,20 +177,22 @@ def __init__(
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> tuple[ModelResponse, usage.Usage]:
-        async with self._make_request(messages, False, model_settings) as http_response:
+        async with self._make_request(
+            messages, False, cast(GeminiModelSettings, model_settings or {})
+        ) as http_response:
             response = _gemini_response_ta.validate_json(await http_response.aread())
         return self._process_response(response), _metadata_as_usage(response)
 
     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> AsyncIterator[StreamedResponse]:
-        async with self._make_request(messages, True, model_settings) as http_response:
+        async with self._make_request(messages, True, cast(GeminiModelSettings, model_settings or {})) as http_response:
             yield await self._process_streamed_response(http_response)
 
     @asynccontextmanager
     async def _make_request(
-        self, messages: list[ModelMessage], streamed: bool, model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], streamed: bool, model_settings: GeminiModelSettings
     ) -> AsyncIterator[HTTPResponse]:
         sys_prompt_parts, contents = self._message_to_gemini_content(messages)
 
@@ -204,6 +212,10 @@ async def _make_request(
             generation_config['temperature'] = temperature
         if (top_p := model_settings.get('top_p')) is not None:
             generation_config['top_p'] = top_p
+        if (presence_penalty := model_settings.get('presence_penalty')) is not None:
+            generation_config['presence_penalty'] = presence_penalty
+        if (frequency_penalty := model_settings.get('frequency_penalty')) is not None:
+            generation_config['frequency_penalty'] = frequency_penalty
         if generation_config:
             request_data['generation_config'] = generation_config
 
@@ -222,7 +234,7 @@ async def _make_request(
             url,
             content=request_json,
             headers=headers,
-            timeout=(model_settings or {}).get('timeout', USE_CLIENT_DEFAULT),
+            timeout=model_settings.get('timeout', USE_CLIENT_DEFAULT),
         ) as r:
             if r.status_code != 200:
                 await r.aread()
@@ -398,6 +410,8 @@ class _GeminiGenerationConfig(TypedDict, total=False):
     max_output_tokens: int
     temperature: float
     top_p: float
+    presence_penalty: float
+    frequency_penalty: float
 
 
 class _GeminiContent(TypedDict):
```
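The walrus-operator guards above copy a setting into `generation_config` only when it was explicitly provided, keeping the request body minimal. A standalone sketch of the pattern (using a stand-in `TypedDict`, not the real `_GeminiGenerationConfig`):

```py
from typing import TypedDict

class _GenerationConfig(TypedDict, total=False):  # stand-in for _GeminiGenerationConfig
    presence_penalty: float
    frequency_penalty: float

def build_generation_config(model_settings: dict[str, float]) -> _GenerationConfig:
    config: _GenerationConfig = {}
    # a key is added only when the caller actually set it, so an empty
    # settings dict produces an empty (and therefore omitted) config
    if (presence_penalty := model_settings.get('presence_penalty')) is not None:
        config['presence_penalty'] = presence_penalty
    if (frequency_penalty := model_settings.get('frequency_penalty')) is not None:
        config['frequency_penalty'] = frequency_penalty
    return config

assert build_generation_config({}) == {}
assert build_generation_config({'presence_penalty': 0.5}) == {'presence_penalty': 0.5}
```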

pydantic_ai_slim/pydantic_ai/models/groq.py

Lines changed: 16 additions & 8 deletions

```diff
@@ -5,7 +5,7 @@
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
 from itertools import chain
-from typing import Literal, overload
+from typing import Literal, cast, overload
 
 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
@@ -68,6 +68,12 @@
 """
 
 
+class GroqModelSettings(ModelSettings):
+    """Settings used for a Groq model request."""
+
+    # This class is a placeholder for any future groq-specific settings
+
+
 @dataclass(init=False)
 class GroqModel(Model):
     """A model that uses the Groq API.
@@ -155,31 +161,31 @@ class GroqAgentModel(AgentModel):
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> tuple[ModelResponse, usage.Usage]:
-        response = await self._completions_create(messages, False, model_settings)
+        response = await self._completions_create(messages, False, cast(GroqModelSettings, model_settings or {}))
         return self._process_response(response), _map_usage(response)
 
     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> AsyncIterator[StreamedResponse]:
-        response = await self._completions_create(messages, True, model_settings)
+        response = await self._completions_create(messages, True, cast(GroqModelSettings, model_settings or {}))
         async with response:
             yield await self._process_streamed_response(response)
 
     @overload
     async def _completions_create(
-        self, messages: list[ModelMessage], stream: Literal[True], model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], stream: Literal[True], model_settings: GroqModelSettings
     ) -> AsyncStream[ChatCompletionChunk]:
         pass
 
     @overload
     async def _completions_create(
-        self, messages: list[ModelMessage], stream: Literal[False], model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], stream: Literal[False], model_settings: GroqModelSettings
     ) -> chat.ChatCompletion:
         pass
 
     async def _completions_create(
-        self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], stream: bool, model_settings: GroqModelSettings
     ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
         # standalone function to make it easier to override
         if not self.tools:
@@ -191,8 +197,6 @@ async def _completions_create(
 
         groq_messages = list(chain(*(self._map_message(m) for m in messages)))
 
-        model_settings = model_settings or {}
-
         return await self.client.chat.completions.create(
             model=str(self.model_name),
             messages=groq_messages,
@@ -205,6 +209,10 @@ async def _completions_create(
             temperature=model_settings.get('temperature', NOT_GIVEN),
             top_p=model_settings.get('top_p', NOT_GIVEN),
             timeout=model_settings.get('timeout', NOT_GIVEN),
+            seed=model_settings.get('seed', NOT_GIVEN),
+            presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
+            frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
+            logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
         )
 
     def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
```
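A hedged usage sketch for the Groq passthroughs; like the Cohere ones, `seed`, the penalties, and `logit_bias` come from the base settings dict (defined in the unshown `settings.py` portion of this commit), with `NOT_GIVEN` substituted for omitted keys. The model name and the `logit_bias` token id are illustrative:

```py
from pydantic_ai import Agent
from pydantic_ai.models.groq import GroqModelSettings

agent = Agent('groq:llama3-70b-8192')  # illustrative model name

result = agent.run_sync(
    'What is the capital of Italy?',
    model_settings=GroqModelSettings(
        seed=42,
        presence_penalty=0.5,
        logit_bias={'1234': -100},  # illustrative token id -> bias mapping
    ),
)
print(result.data)
```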

pydantic_ai_slim/pydantic_ai/models/mistral.py

Lines changed: 14 additions & 7 deletions

```diff
@@ -6,7 +6,7 @@
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
 from itertools import chain
-from typing import Any, Callable, Literal, Union
+from typing import Any, Callable, Literal, Union, cast
 
 import pydantic_core
 from httpx import AsyncClient as AsyncHTTPClient, Timeout
@@ -85,6 +85,12 @@
 """
 
 
+class MistralModelSettings(ModelSettings):
+    """Settings used for a Mistral model request."""
+
+    # This class is a placeholder for any future mistral-specific settings
+
+
 @dataclass(init=False)
 class MistralModel(Model):
     """A model that uses Mistral.
@@ -159,23 +165,22 @@ async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> tuple[ModelResponse, Usage]:
         """Make a non-streaming request to the model from Pydantic AI call."""
-        response = await self._completions_create(messages, model_settings)
+        response = await self._completions_create(messages, cast(MistralModelSettings, model_settings or {}))
         return self._process_response(response), _map_usage(response)
 
     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> AsyncIterator[StreamedResponse]:
         """Make a streaming request to the model from Pydantic AI call."""
-        response = await self._stream_completions_create(messages, model_settings)
+        response = await self._stream_completions_create(messages, cast(MistralModelSettings, model_settings or {}))
         async with response:
             yield await self._process_streamed_response(self.result_tools, response)
 
     async def _completions_create(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], model_settings: MistralModelSettings
     ) -> MistralChatCompletionResponse:
         """Make a non-streaming request to the model."""
-        model_settings = model_settings or {}
         response = await self.client.chat.complete_async(
             model=str(self.model_name),
             messages=list(chain(*(self._map_message(m) for m in messages))),
@@ -187,19 +192,19 @@ async def _completions_create(
             temperature=model_settings.get('temperature', UNSET),
             top_p=model_settings.get('top_p', 1),
             timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
+            random_seed=model_settings.get('seed', UNSET),
         )
         assert response, 'A unexpected empty response from Mistral.'
         return response
 
     async def _stream_completions_create(
         self,
         messages: list[ModelMessage],
-        model_settings: ModelSettings | None,
+        model_settings: MistralModelSettings,
     ) -> MistralEventStreamAsync[MistralCompletionEvent]:
         """Create a streaming completion request to the Mistral model."""
         response: MistralEventStreamAsync[MistralCompletionEvent] | None
         mistral_messages = list(chain(*(self._map_message(m) for m in messages)))
-        model_settings = model_settings or {}
 
         if self.result_tools and self.function_tools or self.function_tools:
             # Function Calling
@@ -213,6 +218,8 @@ async def _stream_completions_create(
                 top_p=model_settings.get('top_p', 1),
                 max_tokens=model_settings.get('max_tokens', UNSET),
                 timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
+                presence_penalty=model_settings.get('presence_penalty'),
+                frequency_penalty=model_settings.get('frequency_penalty'),
             )
 
         elif self.result_tools:
```
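Two details in this diff worth noting: the generic `seed` setting is renamed on the wire to Mistral's `random_seed` parameter, and `presence_penalty`/`frequency_penalty` are wired only into the streaming path here, not into `_completions_create`. A hedged usage sketch (illustrative model name):

```py
from pydantic_ai import Agent
from pydantic_ai.models.mistral import MistralModelSettings

agent = Agent('mistral:mistral-large-latest')  # illustrative model name

result = agent.run_sync(
    'What is the capital of Italy?',
    # forwarded to the Mistral SDK as random_seed=42
    model_settings=MistralModelSettings(seed=42),
)
print(result.data)
```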
