Commit 35a097d

Remove the AgentModel class (#800)

1 parent 962e2ab

21 files changed: +643 −683 lines

docs/api/models/base.md

Lines changed: 1 addition & 1 deletion
@@ -4,8 +4,8 @@
     options:
       members:
         - KnownModelName
+        - ModelRequestParameters
         - Model
-        - AgentModel
         - AbstractToolDefinition
         - StreamedResponse
         - ALLOW_MODEL_REQUESTS

docs/api/models/test.md

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@ async def test_my_agent():
     with my_agent.override(model=m):
         result = await my_agent.run('Testing my agent...')
         assert result.data == 'success (no tool calls)'
-        assert m.agent_model_function_tools == []
+        assert m.last_model_request_parameters.function_tools == []
 ```
 
 See [Unit testing with `TestModel`](../../testing-evals.md#unit-testing-with-testmodel) for detailed documentation.

docs/api/models/vertexai.md

Lines changed: 2 additions & 2 deletions
@@ -2,8 +2,8 @@
 
 Custom interface to the `*-aiplatform.googleapis.com` API for Gemini models.
 
-This model uses [`GeminiAgentModel`][pydantic_ai.models.gemini.GeminiAgentModel] with just the URL and auth method
-changed from [`GeminiModel`][pydantic_ai.models.gemini.GeminiModel], it relies on the VertexAI
+This model inherits from [`GeminiModel`][pydantic_ai.models.gemini.GeminiModel] with just the URL and auth method
+changed; it relies on the VertexAI
 [`generateContent`](https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.endpoints/generateContent)
 and
 [`streamGenerateContent`](https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.endpoints/streamGenerateContent)
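
For orientation, a minimal usage sketch of this interface, assuming the documented `VertexAIModel` entry point; the model name is illustrative and auth falls back to application default credentials:

```python
from pydantic_ai import Agent
from pydantic_ai.models.vertexai import VertexAIModel

model = VertexAIModel('gemini-1.5-flash')  # illustrative model name
agent = Agent(model)

result = agent.run_sync('Tell me a joke.')
print(result.data)
```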

docs/models.md

Lines changed: 1 addition & 2 deletions
@@ -628,9 +628,8 @@ agent = Agent(model)
 
 To implement support for models not already supported, you will need to subclass the [`Model`][pydantic_ai.models.Model] abstract base class.
 
-This in turn will require you to implement the following other abstract base classes:
+For streaming, you'll also need to implement the following abstract base class:
 
-* [`AgentModel`][pydantic_ai.models.AgentModel]
 * [`StreamedResponse`][pydantic_ai.models.StreamedResponse]
 
 The best place to start is to review the source code for existing implementations, e.g. [`OpenAIModel`](https://github.com/pydantic/pydantic-ai/blob/main/pydantic_ai_slim/pydantic_ai/models/openai.py).
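
In the spirit of that advice, here is a minimal sketch of the post-change interface this commit introduces. `EchoModel` and its canned reply are hypothetical, and the import paths are assumed from the diffs below; only `name` and `request` are required, since `request_stream` stays optional:

```python
from __future__ import annotations

from dataclasses import dataclass

from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
from pydantic_ai.models import Model, ModelRequestParameters
from pydantic_ai.settings import ModelSettings
from pydantic_ai.usage import Usage


@dataclass
class EchoModel(Model):
    """Hypothetical model that always answers with a canned string."""

    reply: str = 'hello from EchoModel'

    def name(self) -> str:
        return 'echo'

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> tuple[ModelResponse, Usage]:
        # Tool definitions now arrive per request via model_request_parameters,
        # instead of being baked into a separate AgentModel ahead of time.
        return ModelResponse(parts=[TextPart(content=self.reply)]), Usage()
```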

docs/tools.md

Lines changed: 3 additions & 3 deletions
@@ -292,7 +292,7 @@ The return type of tool can be anything which Pydantic can serialize to JSON as
 
 If a tool has a single parameter that can be represented as an object in JSON schema (e.g. dataclass, TypedDict, pydantic model), the schema for the tool is simplified to be just that object.
 
-Here's an example, we use [`TestModel.agent_model_function_tools`][pydantic_ai.models.test.TestModel.agent_model_function_tools] to inspect the tool schema that would be passed to the model.
+Here's an example where we use [`TestModel.last_model_request_parameters`][pydantic_ai.models.test.TestModel.last_model_request_parameters] to inspect the tool schema that would be passed to the model.
 
 ```python {title="single_parameter_tool.py"}
 from pydantic import BaseModel
@@ -320,7 +320,7 @@ test_model = TestModel()
 result = agent.run_sync('hello', model=test_model)
 print(result.data)
 #> {"foobar":"x=0 y='a' z=3.14"}
-print(test_model.agent_model_function_tools)
+print(test_model.last_model_request_parameters.function_tools)
 """
 [
     ToolDefinition(
@@ -425,7 +425,7 @@ agent = Agent(test_model, tools=[greet_tool], deps_type=Literal['human', 'machin
 result = agent.run_sync('testing...', deps='human')
 print(result.data)
 #> {"greet":"hello a"}
-print(test_model.agent_model_function_tools)
+print(test_model.last_model_request_parameters.function_tools)
 """
 [
     ToolDefinition(

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 12 additions & 8 deletions
@@ -204,9 +204,9 @@ async def run(
         return StreamModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx))
 
 
-async def _prepare_model(
+async def _prepare_request_parameters(
     ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
-) -> models.AgentModel:
+) -> models.ModelRequestParameters:
     """Build tools and create an agent model."""
     function_tool_defs: list[ToolDefinition] = []
 
@@ -220,7 +220,7 @@ async def add_tool(tool: Tool[DepsT]) -> None:
     await asyncio.gather(*map(add_tool, ctx.deps.function_tools.values()))
 
     result_schema = ctx.deps.result_schema
-    return await run_context.model.agent_model(
+    return models.ModelRequestParameters(
         function_tools=function_tool_defs,
         allow_text_result=_allow_text_result(result_schema),
        result_tools=result_schema.tool_defs() if result_schema is not None else [],
@@ -245,13 +245,15 @@ async def run(
         # Increment run_step
         ctx.state.run_step += 1
 
-        with _logfire.span('preparing model and tools {run_step=}', run_step=ctx.state.run_step):
-            agent_model = await _prepare_model(ctx)
+        with _logfire.span('preparing model request params {run_step=}', run_step=ctx.state.run_step):
+            model_request_parameters = await _prepare_request_parameters(ctx)
 
         # Actually make the model request
         model_settings = merge_model_settings(ctx.deps.model_settings, None)
         with _logfire.span('model request') as span:
-            model_response, request_usage = await agent_model.request(ctx.state.message_history, model_settings)
+            model_response, request_usage = await ctx.deps.model.request(
+                ctx.state.message_history, model_settings, model_request_parameters
+            )
             span.set_attribute('response', model_response)
             span.set_attribute('usage', request_usage)
 
@@ -405,12 +407,14 @@ async def run_to_result(
         ctx.state.run_step += 1
 
         with _logfire.span('preparing model and tools {run_step=}', run_step=ctx.state.run_step):
-            agent_model = await _prepare_model(ctx)
+            model_request_parameters = await _prepare_request_parameters(ctx)
 
         # Actually make the model request
         model_settings = merge_model_settings(ctx.deps.model_settings, None)
         with _logfire.span('model request {run_step=}', run_step=ctx.state.run_step) as model_req_span:
-            async with agent_model.request_stream(ctx.state.message_history, model_settings) as streamed_response:
+            async with ctx.deps.model.request_stream(
+                ctx.state.message_history, model_settings, model_request_parameters
+            ) as streamed_response:
                 ctx.state.usage.requests += 1
                 model_req_span.set_attribute('response_type', streamed_response.__class__.__name__)
                 # We want to end the "model request" span here, but we can't exit the context manager
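
The net effect of these graph changes: request parameters are rebuilt on every run step and handed straight to the shared `Model`, with no `AgentModel` in between. A condensed, hypothetical restatement of that flow (names simplified; not the actual graph code):

```python
from pydantic_ai.models import Model, ModelRequestParameters


async def make_request_step(model: Model, message_history, model_settings, tool_defs):
    params = ModelRequestParameters(
        function_tools=tool_defs,  # tools prepared for this step
        allow_text_result=True,    # whether a plain-text result is permitted
        result_tools=[],           # result-tool definitions, if any
    )
    # Parameters travel with each request call rather than with a pre-built object.
    return await model.request(message_history, model_settings, params)
```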

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 18 additions & 29 deletions
@@ -161,49 +161,38 @@
 """
 
 
+@dataclass
+class ModelRequestParameters:
+    """Configuration for an agent's request to a model, specifically related to tools and result handling."""
+
+    function_tools: list[ToolDefinition]
+    allow_text_result: bool
+    result_tools: list[ToolDefinition]
+
+
 class Model(ABC):
     """Abstract class for a model."""
 
-    @abstractmethod
-    async def agent_model(
-        self,
-        *,
-        function_tools: list[ToolDefinition],
-        allow_text_result: bool,
-        result_tools: list[ToolDefinition],
-    ) -> AgentModel:
-        """Create an agent model, this is called for each step of an agent run.
-
-        This is async in case slow/async config checks need to be performed that can't be done in `__init__`.
-
-        Args:
-            function_tools: The tools available to the agent.
-            allow_text_result: Whether a plain text final response/result is permitted.
-            result_tools: Tool definitions for the final result tool(s), if any.
-
-        Returns:
-            An agent model.
-        """
-        raise NotImplementedError()
-
     @abstractmethod
     def name(self) -> str:
         raise NotImplementedError()
 
-
-class AgentModel(ABC):
-    """Model configured for each step of an Agent run."""
-
     @abstractmethod
     async def request(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
     ) -> tuple[ModelResponse, Usage]:
         """Make a request to the model."""
         raise NotImplementedError()
 
     @asynccontextmanager
     async def request_stream(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
    ) -> AsyncIterator[StreamedResponse]:
         """Make a request to the model and return a streaming response."""
         # This method is not required, but you need to implement it if you want to support streamed responses
@@ -274,7 +263,7 @@ def check_allow_model_requests() -> None:
     """Check if model requests are allowed.
 
     If you're defining your own models that have costs or latency associated with their use, you should call this in
-    [`Model.agent_model`][pydantic_ai.models.Model.agent_model].
+    [`Model.request`][pydantic_ai.models.Model.request] and [`Model.request_stream`][pydantic_ai.models.Model.request_stream].
 
     Raises:
         RuntimeError: If model requests are not allowed.
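
A short sketch of the gate that `check_allow_model_requests()` now enforces per request; flipping the module-level `ALLOW_MODEL_REQUESTS` flag (listed in the base.md members above) is how a test suite would block accidental real requests:

```python
import pydantic_ai.models

pydantic_ai.models.ALLOW_MODEL_REQUESTS = False  # e.g. set at the top of a test suite

try:
    pydantic_ai.models.check_allow_model_requests()
except RuntimeError as exc:
    # Per the docstring above, the check raises RuntimeError when blocked.
    print(f'blocked: {exc}')
```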

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 59 additions & 57 deletions
@@ -28,8 +28,8 @@
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
-    AgentModel,
     Model,
+    ModelRequestParameters,
     StreamedResponse,
     cached_async_http_client,
     check_allow_model_requests,
@@ -134,81 +134,70 @@ def __init__(
         else:
             self.client = AsyncAnthropic(api_key=api_key, http_client=cached_async_http_client())
 
-    async def agent_model(
-        self,
-        *,
-        function_tools: list[ToolDefinition],
-        allow_text_result: bool,
-        result_tools: list[ToolDefinition],
-    ) -> AgentModel:
-        check_allow_model_requests()
-        tools = [self._map_tool_definition(r) for r in function_tools]
-        if result_tools:
-            tools += [self._map_tool_definition(r) for r in result_tools]
-        return AnthropicAgentModel(
-            self.client,
-            self.model_name,
-            allow_text_result,
-            tools,
-        )
-
     def name(self) -> str:
         return f'anthropic:{self.model_name}'
 
-    @staticmethod
-    def _map_tool_definition(f: ToolDefinition) -> ToolParam:
-        return {
-            'name': f.name,
-            'description': f.description,
-            'input_schema': f.parameters_json_schema,
-        }
-
-
-@dataclass
-class AnthropicAgentModel(AgentModel):
-    """Implementation of `AgentModel` for Anthropic models."""
-
-    client: AsyncAnthropic
-    model_name: AnthropicModelName
-    allow_text_result: bool
-    tools: list[ToolParam]
-
     async def request(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
     ) -> tuple[ModelResponse, usage.Usage]:
-        response = await self._messages_create(messages, False, cast(AnthropicModelSettings, model_settings or {}))
+        check_allow_model_requests()
+        response = await self._messages_create(
+            messages, False, cast(AnthropicModelSettings, model_settings or {}), model_request_parameters
+        )
         return self._process_response(response), _map_usage(response)
 
     @asynccontextmanager
     async def request_stream(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
    ) -> AsyncIterator[StreamedResponse]:
-        response = await self._messages_create(messages, True, cast(AnthropicModelSettings, model_settings or {}))
+        check_allow_model_requests()
+        response = await self._messages_create(
+            messages, True, cast(AnthropicModelSettings, model_settings or {}), model_request_parameters
+        )
         async with response:
             yield await self._process_streamed_response(response)
 
     @overload
     async def _messages_create(
-        self, messages: list[ModelMessage], stream: Literal[True], model_settings: AnthropicModelSettings
+        self,
+        messages: list[ModelMessage],
+        stream: Literal[True],
+        model_settings: AnthropicModelSettings,
+        model_request_parameters: ModelRequestParameters,
    ) -> AsyncStream[RawMessageStreamEvent]:
         pass
 
     @overload
     async def _messages_create(
-        self, messages: list[ModelMessage], stream: Literal[False], model_settings: AnthropicModelSettings
+        self,
+        messages: list[ModelMessage],
+        stream: Literal[False],
+        model_settings: AnthropicModelSettings,
+        model_request_parameters: ModelRequestParameters,
    ) -> AnthropicMessage:
         pass
 
     async def _messages_create(
-        self, messages: list[ModelMessage], stream: bool, model_settings: AnthropicModelSettings
+        self,
+        messages: list[ModelMessage],
+        stream: bool,
+        model_settings: AnthropicModelSettings,
+        model_request_parameters: ModelRequestParameters,
    ) -> AnthropicMessage | AsyncStream[RawMessageStreamEvent]:
         # standalone function to make it easier to override
+        tools = self._get_tools(model_request_parameters)
         tool_choice: ToolChoiceParam | None
 
-        if not self.tools:
+        if not tools:
             tool_choice = None
         else:
-            if not self.allow_text_result:
+            if not model_request_parameters.allow_text_result:
                 tool_choice = {'type': 'any'}
             else:
                 tool_choice = {'type': 'auto'}
@@ -223,7 +212,7 @@ async def _messages_create(
             system=system_prompt or NOT_GIVEN,
             messages=anthropic_messages,
             model=self.model_name,
-            tools=self.tools or NOT_GIVEN,
+            tools=tools or NOT_GIVEN,
             tool_choice=tool_choice or NOT_GIVEN,
             stream=stream,
             temperature=model_settings.get('temperature', NOT_GIVEN),
@@ -260,8 +249,13 @@ async def _process_streamed_response(self, response: AsyncStream[RawMessageStrea
         timestamp = datetime.now(tz=timezone.utc)
         return AnthropicStreamedResponse(_model_name=self.model_name, _response=peekable_response, _timestamp=timestamp)
 
-    @staticmethod
-    def _map_message(messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
+    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolParam]:
+        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
+        if model_request_parameters.result_tools:
+            tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
+        return tools
+
+    def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
         """Just maps a `pydantic_ai.Message` to a `anthropic.types.MessageParam`."""
         system_prompt: str = ''
         anthropic_messages: list[MessageParam] = []
@@ -310,20 +304,28 @@ def _map_message(messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]
                     content.append(TextBlockParam(text=item.content, type='text'))
                 else:
                     assert isinstance(item, ToolCallPart)
-                    content.append(_map_tool_call(item))
+                    content.append(self._map_tool_call(item))
             anthropic_messages.append(MessageParam(role='assistant', content=content))
         else:
             assert_never(m)
         return system_prompt, anthropic_messages
 
+    @staticmethod
+    def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam:
+        return ToolUseBlockParam(
+            id=_guard_tool_call_id(t=t, model_source='Anthropic'),
+            type='tool_use',
+            name=t.tool_name,
+            input=t.args_as_dict(),
+        )
 
-def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam:
-    return ToolUseBlockParam(
-        id=_guard_tool_call_id(t=t, model_source='Anthropic'),
-        type='tool_use',
-        name=t.tool_name,
-        input=t.args_as_dict(),
-    )
+    @staticmethod
+    def _map_tool_definition(f: ToolDefinition) -> ToolParam:
+        return {
+            'name': f.name,
+            'description': f.description,
+            'input_schema': f.parameters_json_schema,
+        }
 
 
 def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> usage.Usage:
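
The `# standalone function to make it easier to override` comment above is the reason `_messages_create` stays an instance method. A hedged sketch of such an override; the subclass and its logging are hypothetical, not part of the library:

```python
from pydantic_ai.models.anthropic import AnthropicModel


class LoggingAnthropicModel(AnthropicModel):
    """Hypothetical subclass that logs each outgoing request, then defers to the parent."""

    async def _messages_create(self, messages, stream, model_settings, model_request_parameters):
        print(f'{self.model_name}: sending {len(messages)} messages (stream={stream})')
        return await super()._messages_create(messages, stream, model_settings, model_request_parameters)
```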
