Skip to content

Commit 7d27c42

Browse files
sydney-runkle, hyperlint-ai[bot], dmontagu, samuelcolvin
authored
Basic ModelSettings logic (#227)
Co-authored-by: hyperlint-ai[bot] <154288675+hyperlint-ai[bot]@users.noreply.github.com> Co-authored-by: David Montague <[email protected]> Co-authored-by: Samuel Colvin <[email protected]>
1 parent 70105cb commit 7d27c42

File tree

15 files changed

+296
-54
lines changed

15 files changed

+296
-54
lines changed

docs/agents.md

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,31 @@ You can also pass messages from previous runs to continue a conversation or prov
101101
nest_asyncio.apply()
102102
```
103103

104+
### Additional Configuration
105+
106+
PydanticAI offers a [`settings.ModelSettings`][pydantic_ai.settings.ModelSettings] structure to help you fine-tune your requests.
107+
This structure allows you to configure common parameters that influence the model's behavior, such as `temperature`, `max_tokens`,
108+
`timeout`, and more.
109+
110+
There are two ways to apply these settings:
111+
1. Passing to `run{_sync,_stream}` functions via the `model_settings` argument. This allows for fine-tuning on a per-request basis.
112+
2. Setting during [`Agent`][pydantic_ai.agent.Agent] initialization via the `model_settings` argument. These settings will be applied by default to all subsequent run calls using said agent. However, `model_settings` provided during a specific run call will override the agent's default settings.
113+
114+
For example, if you'd like to set the `temperature` setting to `0.0` to ensure less random behavior,
115+
you can do the following:
116+
117+
```py
118+
from pydantic_ai import Agent
119+
120+
agent = Agent('openai:gpt-4o')
121+
122+
result_sync = agent.run_sync(
123+
'What is the capital of Italy?', model_settings={'temperature': 0.0}
124+
)
125+
print(result_sync.data)
126+
#> Rome
127+
```
128+
104129
## Runs vs. Conversations
105130

106131
An agent **run** might represent an entire conversation — there's no limit to how many messages can be exchanged in a single run. However, a **conversation** might also be composed of multiple runs, especially if you need to maintain state between separate interactions or API calls.

docs/api/settings.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# `pydantic_ai.settings`
2+
3+
::: pydantic_ai.settings
4+
options:
5+
inherited_members: true
6+
members:
7+
- ModelSettings

mkdocs.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,9 @@ nav:
3838
- api/result.md
3939
- api/messages.md
4040
- api/exceptions.md
41-
- api/models/anthropic.md
41+
- api/settings.md
4242
- api/models/base.md
43+
- api/models/anthropic.md
4344
- api/models/openai.md
4445
- api/models/ollama.md
4546
- api/models/gemini.md

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
result,
2323
)
2424
from .result import ResultData
25+
from .settings import ModelSettings, merge_model_settings
2526
from .tools import (
2627
AgentDeps,
2728
RunContext,
@@ -81,6 +82,13 @@ class Agent(Generic[AgentDeps, ResultData]):
8182
end_strategy: EndStrategy
8283
"""Strategy for handling tool calls when a final result is found."""
8384

85+
model_settings: ModelSettings | None = None
86+
"""Optional model request settings to use for this agents's runs, by default.
87+
88+
Note, if `model_settings` is provided by `run`, `run_sync`, or `run_stream`, those settings will
89+
be merged with this value, with the runtime argument taking priority.
90+
"""
91+
8492
last_run_messages: list[_messages.Message] | None = None
8593
"""The messages from the last run, useful when a run raised an exception.
8694
@@ -108,6 +116,7 @@ def __init__(
108116
system_prompt: str | Sequence[str] = (),
109117
deps_type: type[AgentDeps] = NoneType,
110118
name: str | None = None,
119+
model_settings: ModelSettings | None = None,
111120
retries: int = 1,
112121
result_tool_name: str = 'final_result',
113122
result_tool_description: str | None = None,
@@ -130,6 +139,7 @@ def __init__(
130139
or add a type hint `: Agent[None, <return type>]`.
131140
name: The name of the agent, used for logging. If `None`, we try to infer the agent name from the call frame
132141
when the agent is first run.
142+
model_settings: Optional model request settings to use for this agent's runs, by default.
133143
retries: The default number of retries to allow before raising an error.
134144
result_tool_name: The name of the tool to use for the final result.
135145
result_tool_description: The description of the final result tool.
@@ -151,6 +161,7 @@ def __init__(
151161

152162
self.end_strategy = end_strategy
153163
self.name = name
164+
self.model_settings = model_settings
154165
self._result_schema = _result.ResultSchema[result_type].build(
155166
result_type, result_tool_name, result_tool_description
156167
)
@@ -178,6 +189,7 @@ async def run(
178189
message_history: list[_messages.Message] | None = None,
179190
model: models.Model | models.KnownModelName | None = None,
180191
deps: AgentDeps = None,
192+
model_settings: ModelSettings | None = None,
181193
infer_name: bool = True,
182194
) -> result.RunResult[ResultData]:
183195
"""Run the agent with a user prompt in async mode.
@@ -199,6 +211,7 @@ async def run(
199211
model: Optional model to use for this run, required if `model` was not set when creating the agent.
200212
deps: Optional dependencies to use for this run.
201213
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
214+
model_settings: Optional settings to use for this model's request.
202215
203216
Returns:
204217
The result of the run.
@@ -225,14 +238,16 @@ async def run(
225238

226239
cost = result.Cost()
227240

241+
model_settings = merge_model_settings(self.model_settings, model_settings)
242+
228243
run_step = 0
229244
while True:
230245
run_step += 1
231246
with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
232247
agent_model = await self._prepare_model(model_used, deps)
233248

234249
with _logfire.span('model request', run_step=run_step) as model_req_span:
235-
model_response, request_cost = await agent_model.request(messages)
250+
model_response, request_cost = await agent_model.request(messages, model_settings)
236251
model_req_span.set_attribute('response', model_response)
237252
model_req_span.set_attribute('cost', request_cost)
238253
model_req_span.message = f'model request -> {model_response.role}'
@@ -267,6 +282,7 @@ def run_sync(
267282
message_history: list[_messages.Message] | None = None,
268283
model: models.Model | models.KnownModelName | None = None,
269284
deps: AgentDeps = None,
285+
model_settings: ModelSettings | None = None,
270286
infer_name: bool = True,
271287
) -> result.RunResult[ResultData]:
272288
"""Run the agent with a user prompt synchronously.
@@ -291,6 +307,7 @@ async def main():
291307
model: Optional model to use for this run, required if `model` was not set when creating the agent.
292308
deps: Optional dependencies to use for this run.
293309
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
310+
model_settings: Optional settings to use for this model's request.
294311
295312
Returns:
296313
The result of the run.
@@ -299,7 +316,14 @@ async def main():
299316
self._infer_name(inspect.currentframe())
300317
loop = asyncio.get_event_loop()
301318
return loop.run_until_complete(
302-
self.run(user_prompt, message_history=message_history, model=model, deps=deps, infer_name=False)
319+
self.run(
320+
user_prompt,
321+
message_history=message_history,
322+
model=model,
323+
deps=deps,
324+
infer_name=False,
325+
model_settings=model_settings,
326+
)
303327
)
304328

305329
@asynccontextmanager
@@ -310,6 +334,7 @@ async def run_stream(
310334
message_history: list[_messages.Message] | None = None,
311335
model: models.Model | models.KnownModelName | None = None,
312336
deps: AgentDeps = None,
337+
model_settings: ModelSettings | None = None,
313338
infer_name: bool = True,
314339
) -> AsyncIterator[result.StreamedRunResult[AgentDeps, ResultData]]:
315340
"""Run the agent with a user prompt in async mode, returning a streamed response.
@@ -332,6 +357,7 @@ async def main():
332357
model: Optional model to use for this run, required if `model` was not set when creating the agent.
333358
deps: Optional dependencies to use for this run.
334359
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
360+
model_settings: Optional settings to use for this model's request.
335361
336362
Returns:
337363
The result of the run.
@@ -359,6 +385,7 @@ async def main():
359385
tool.current_retry = 0
360386

361387
cost = result.Cost()
388+
model_settings = merge_model_settings(self.model_settings, model_settings)
362389

363390
run_step = 0
364391
while True:
@@ -368,7 +395,7 @@ async def main():
368395
agent_model = await self._prepare_model(model_used, deps)
369396

370397
with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
371-
async with agent_model.request_stream(messages) as model_response:
398+
async with agent_model.request_stream(messages, model_settings) as model_response:
372399
model_req_span.set_attribute('response_type', model_response.__class__.__name__)
373400
# We want to end the "model request" span here, but we can't exit the context manager
374401
# in the traditional way

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from ..exceptions import UserError
1919
from ..messages import Message, ModelAnyResponse, ModelStructuredResponse
20+
from ..settings import ModelSettings
2021

2122
if TYPE_CHECKING:
2223
from ..result import Cost
@@ -113,12 +114,16 @@ class AgentModel(ABC):
113114
"""Model configured for each step of an Agent run."""
114115

115116
@abstractmethod
116-
async def request(self, messages: list[Message]) -> tuple[ModelAnyResponse, Cost]:
117+
async def request(
118+
self, messages: list[Message], model_settings: ModelSettings | None
119+
) -> tuple[ModelAnyResponse, Cost]:
117120
"""Make a request to the model."""
118121
raise NotImplementedError()
119122

120123
@asynccontextmanager
121-
async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherStreamedResponse]:
124+
async def request_stream(
125+
self, messages: list[Message], model_settings: ModelSettings | None
126+
) -> AsyncIterator[EitherStreamedResponse]:
122127
"""Make a request to the model and return a streaming response."""
123128
raise NotImplementedError(f'Streamed requests not supported by this {self.__class__.__name__}')
124129
# yield is required to make this a generator for type checking

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
ModelTextResponse,
2020
ToolCall,
2121
)
22+
from ..settings import ModelSettings
2223
from ..tools import ToolDefinition
2324
from . import (
2425
AgentModel,
@@ -151,28 +152,34 @@ class AnthropicAgentModel(AgentModel):
151152
allow_text_result: bool
152153
tools: list[ToolParam]
153154

154-
async def request(self, messages: list[Message]) -> tuple[ModelAnyResponse, result.Cost]:
155-
response = await self._messages_create(messages, False)
155+
async def request(
156+
self, messages: list[Message], model_settings: ModelSettings | None
157+
) -> tuple[ModelAnyResponse, result.Cost]:
158+
response = await self._messages_create(messages, False, model_settings)
156159
return self._process_response(response), _map_cost(response)
157160

158161
@asynccontextmanager
159-
async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherStreamedResponse]:
160-
response = await self._messages_create(messages, True)
162+
async def request_stream(
163+
self, messages: list[Message], model_settings: ModelSettings | None
164+
) -> AsyncIterator[EitherStreamedResponse]:
165+
response = await self._messages_create(messages, True, model_settings)
161166
async with response:
162167
yield await self._process_streamed_response(response)
163168

164169
@overload
165170
async def _messages_create(
166-
self, messages: list[Message], stream: Literal[True]
171+
self, messages: list[Message], stream: Literal[True], model_settings: ModelSettings | None
167172
) -> AsyncStream[RawMessageStreamEvent]:
168173
pass
169174

170175
@overload
171-
async def _messages_create(self, messages: list[Message], stream: Literal[False]) -> AnthropicMessage:
176+
async def _messages_create(
177+
self, messages: list[Message], stream: Literal[False], model_settings: ModelSettings | None
178+
) -> AnthropicMessage:
172179
pass
173180

174181
async def _messages_create(
175-
self, messages: list[Message], stream: bool
182+
self, messages: list[Message], stream: bool, model_settings: ModelSettings | None
176183
) -> AnthropicMessage | AsyncStream[RawMessageStreamEvent]:
177184
# standalone function to make it easier to override
178185
if not self.tools:
@@ -191,15 +198,19 @@ async def _messages_create(
191198
else:
192199
anthropic_messages.append(self._map_message(m))
193200

201+
model_settings = model_settings or {}
202+
194203
return await self.client.messages.create(
195-
max_tokens=1024,
204+
max_tokens=model_settings.get('max_tokens', 1024),
196205
system=system_prompt or NOT_GIVEN,
197206
messages=anthropic_messages,
198207
model=self.model_name,
199-
temperature=0.0,
200208
tools=self.tools or NOT_GIVEN,
201209
tool_choice=tool_choice or NOT_GIVEN,
202210
stream=stream,
211+
temperature=model_settings.get('temperature', NOT_GIVEN),
212+
top_p=model_settings.get('top_p', NOT_GIVEN),
213+
timeout=model_settings.get('timeout', NOT_GIVEN),
203214
)
204215

205216
@staticmethod

pydantic_ai_slim/pydantic_ai/models/function.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import re
55
from collections.abc import AsyncIterator, Awaitable, Iterable
66
from contextlib import asynccontextmanager
7-
from dataclasses import dataclass, field
7+
from dataclasses import dataclass, field, replace
88
from datetime import datetime
99
from itertools import chain
1010
from typing import Callable, Union, cast
@@ -14,6 +14,7 @@
1414

1515
from .. import _utils, result
1616
from ..messages import ArgsJson, Message, ModelAnyResponse, ModelStructuredResponse, ToolCall
17+
from ..settings import ModelSettings
1718
from ..tools import ToolDefinition
1819
from . import AgentModel, EitherStreamedResponse, Model, StreamStructuredResponse, StreamTextResponse
1920

@@ -59,7 +60,7 @@ async def agent_model(
5960
result_tools: list[ToolDefinition],
6061
) -> AgentModel:
6162
return FunctionAgentModel(
62-
self.function, self.stream_function, AgentInfo(function_tools, allow_text_result, result_tools)
63+
self.function, self.stream_function, AgentInfo(function_tools, allow_text_result, result_tools, None)
6364
)
6465

6566
def name(self) -> str:
@@ -88,6 +89,8 @@ class AgentInfo:
8889
"""Whether a plain text result is allowed."""
8990
result_tools: list[ToolDefinition]
9091
"""The tools that can called as the final result of the run."""
92+
model_settings: ModelSettings | None
93+
"""The model settings passed to the run call."""
9194

9295

9396
@dataclass
@@ -127,18 +130,24 @@ class FunctionAgentModel(AgentModel):
127130
stream_function: StreamFunctionDef | None
128131
agent_info: AgentInfo
129132

130-
async def request(self, messages: list[Message]) -> tuple[ModelAnyResponse, result.Cost]:
133+
async def request(
134+
self, messages: list[Message], model_settings: ModelSettings | None
135+
) -> tuple[ModelAnyResponse, result.Cost]:
136+
agent_info = replace(self.agent_info, model_settings=model_settings)
137+
131138
assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
132139
if inspect.iscoroutinefunction(self.function):
133-
response = await self.function(messages, self.agent_info)
140+
response = await self.function(messages, agent_info)
134141
else:
135-
response_ = await _utils.run_in_executor(self.function, messages, self.agent_info)
142+
response_ = await _utils.run_in_executor(self.function, messages, agent_info)
136143
response = cast(ModelAnyResponse, response_)
137144
# TODO is `messages` right here? Should it just be new messages?
138145
return response, _estimate_cost(chain(messages, [response]))
139146

140147
@asynccontextmanager
141-
async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherStreamedResponse]:
148+
async def request_stream(
149+
self, messages: list[Message], model_settings: ModelSettings | None
150+
) -> AsyncIterator[EitherStreamedResponse]:
142151
assert (
143152
self.stream_function is not None
144153
), 'FunctionModel must receive a `stream_function` to support streamed requests'

0 commit comments

Comments
 (0)