Skip to content

Commit bda36ee

Browse files
authored
Update agent.py
1 parent 6bb2e83 commit bda36ee

File tree

1 file changed

+66
-94
lines changed

1 file changed

+66
-94
lines changed

src/agents/agent.py

Lines changed: 66 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import asyncio
1111
import inspect
1212
from collections.abc import Awaitable
13-
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast
13+
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, cast
1414

1515
from openai.types.responses.response_prompt_param import ResponsePromptParam
1616
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
@@ -36,29 +36,25 @@
3636
class ToolsToFinalOutputResult(BaseModel):
3737
"""
3838
Result object for tool-to-final output processing.
39-
39+
4040
Attributes:
4141
is_final_output: Whether this is the final output. If False, the LLM will run again.
4242
final_output: The final output value, if is_final_output is True.
4343
"""
44-
44+
4545
model_config = ConfigDict(
46-
arbitrary_types_allowed=True,
47-
validate_assignment=True,
48-
extra='forbid',
49-
frozen=True
46+
arbitrary_types_allowed=True, validate_assignment=True, extra="forbid", frozen=True
5047
)
51-
48+
5249
is_final_output: bool = Field(
5350
..., # Required field
54-
description="Whether this is the final output. If False, the LLM will run again."
51+
description="Whether this is the final output. If False, the LLM will run again.",
5552
)
5653
final_output: Any | None = Field(
57-
default=None,
58-
description="The final output value, if is_final_output is True."
54+
default=None, description="The final output value, if is_final_output is True."
5955
)
60-
61-
@field_validator('is_final_output')
56+
57+
@field_validator("is_final_output")
6258
@classmethod
6359
def validate_is_final_output(cls, v: Any) -> bool:
6460
"""Validate is_final_output is a boolean value."""
@@ -79,27 +75,29 @@ def validate_is_final_output(cls, v: Any) -> bool:
7975

8076
class StopAtTools(TypedDict):
8177
"""Configuration for stopping agent execution at specific tools."""
78+
8279
stop_at_tool_names: list[str]
8380

8481

8582
class MCPConfig(TypedDict):
8683
"""Configuration for Model Context Protocol servers."""
84+
8785
convert_schemas_to_strict: NotRequired[bool]
8886

8987

9088
def is_stop_at_tools_dict(v: dict) -> bool:
9189
"""
9290
Validate if a dictionary matches the StopAtTools structure.
93-
91+
9492
Args:
9593
v: Dictionary to validate
96-
94+
9795
Returns:
9896
bool: True if dictionary is valid StopAtTools structure
9997
"""
10098
return (
101-
isinstance(v, dict)
102-
and "stop_at_tool_names" in v
99+
isinstance(v, dict)
100+
and "stop_at_tool_names" in v
103101
and isinstance(v["stop_at_tool_names"], list)
104102
and all(isinstance(x, str) for x in v["stop_at_tool_names"])
105103
)
@@ -108,77 +106,71 @@ def is_stop_at_tools_dict(v: dict) -> bool:
108106
class AgentBase(BaseModel, Generic[TContext]):
109107
"""
110108
Base class for Agent implementations providing core functionality.
111-
109+
112110
This class implements the base agent functionality including tool management,
113111
MCP server configuration, and validation.
114-
112+
115113
Attributes:
116114
name: The name of the agent
117115
handoff_description: Optional description for handoff functionality
118116
tools: List of available tools
119117
mcp_servers: List of MCP servers
120118
mcp_config: Configuration for MCP servers
121119
"""
122-
120+
123121
model_config = ConfigDict(
124122
arbitrary_types_allowed=True,
125123
validate_assignment=True,
126-
extra='forbid',
124+
extra="forbid",
127125
frozen=True,
128-
defer_build=True
126+
defer_build=True,
129127
)
130128

131129
name: str = Field(
132130
..., # Required field
133-
description="The name of the agent."
131+
description="The name of the agent.",
134132
)
135133
handoff_description: str | None = Field(
136-
default=None,
137-
description="Description used when the agent is used as a handoff."
134+
default=None, description="Description used when the agent is used as a handoff."
138135
)
139136
tools: list[Tool] = Field(
140-
default_factory=list,
141-
description="List of tools available to the agent."
137+
default_factory=list, description="List of tools available to the agent."
142138
)
143139
mcp_servers: list[Any] = Field(
144-
default_factory=list,
145-
description="List of MCP servers available to the agent."
140+
default_factory=list, description="List of MCP servers available to the agent."
146141
)
147142
mcp_config: dict[str, Any] = Field(
148143
default_factory=lambda: {"convert_schemas_to_strict": False},
149-
description="MCP configuration settings."
144+
description="MCP configuration settings.",
150145
)
151146

152147
# ... (validators remain the same but with improved docstrings) ...
153148

154149
async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]:
155150
"""
156151
Fetch available tools from MCP servers.
157-
152+
158153
Args:
159154
run_context: Current run context wrapper
160-
155+
161156
Returns:
162157
list[Tool]: List of available MCP tools
163158
"""
164159
convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
165160
if not self.mcp_servers:
166161
return []
167-
162+
168163
return await MCPUtil.get_all_function_tools(
169-
self.mcp_servers,
170-
convert_schemas_to_strict,
171-
run_context,
172-
self
164+
self.mcp_servers, convert_schemas_to_strict, run_context, self
173165
)
174166

175167
async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]:
176168
"""
177169
Get all available tools, including MCP tools and function tools.
178-
170+
179171
Args:
180172
run_context: Current run context wrapper
181-
173+
182174
Returns:
183175
list[Tool]: Combined list of available tools
184176
"""
@@ -205,10 +197,10 @@ async def _check_tool_enabled(tool: Tool) -> bool:
205197
class Agent(AgentBase, Generic[TContext]):
206198
"""
207199
Primary agent implementation with full configuration and behavior support.
208-
200+
209201
This class extends AgentBase with additional functionality for instructions,
210202
prompts, guardrails, and tool behavior configuration.
211-
203+
212204
Attributes:
213205
instructions: Agent instructions/system prompt
214206
prompt: Prompt configuration
@@ -222,94 +214,73 @@ class Agent(AgentBase, Generic[TContext]):
222214
tool_use_behavior: Tool usage configuration
223215
reset_tool_choice: Tool choice reset behavior
224216
"""
225-
217+
226218
model_config = ConfigDict(
227219
arbitrary_types_allowed=True,
228220
validate_assignment=True,
229-
extra='forbid',
221+
extra="forbid",
230222
frozen=True,
231-
defer_build=True
223+
defer_build=True,
232224
)
233225

234226
instructions: (
235227
str
236228
| Callable[
237-
[RunContextWrapper[TContext], 'Agent[TContext]'],
229+
[RunContextWrapper[TContext], Agent[TContext]],
238230
MaybeAwaitable[str],
239231
]
240232
| None
241-
) = Field(
242-
default=None,
243-
description="The instructions/system prompt for the agent."
244-
)
233+
) = Field(default=None, description="The instructions/system prompt for the agent.")
245234

246235
prompt: Prompt | DynamicPromptFunction | None = Field(
247-
default=None,
248-
description="Prompt configuration for the agent."
236+
default=None, description="Prompt configuration for the agent."
249237
)
250238

251239
handoffs: list[Any] = Field(
252-
default_factory=list,
253-
description="List of sub-agents for delegation."
240+
default_factory=list, description="List of sub-agents for delegation."
254241
)
255242

256-
model: str | Model | None = Field(
257-
default=None,
258-
description="Model implementation to use."
259-
)
243+
model: str | Model | None = Field(default=None, description="Model implementation to use.")
260244

261245
model_settings: ModelSettings = Field(
262-
default_factory=ModelSettings,
263-
description="Model-specific configuration."
246+
default_factory=ModelSettings, description="Model-specific configuration."
264247
)
265248

266249
input_guardrails: list[InputGuardrail[TContext]] = Field(
267-
default_factory=list,
268-
description="Pre-execution validation checks."
250+
default_factory=list, description="Pre-execution validation checks."
269251
)
270252

271253
output_guardrails: list[OutputGuardrail[TContext]] = Field(
272-
default_factory=list,
273-
description="Post-execution validation checks."
254+
default_factory=list, description="Post-execution validation checks."
274255
)
275256

276257
output_type: type[Any] | AgentOutputSchemaBase | None = Field(
277-
default=None,
278-
description="Output type specification."
258+
default=None, description="Output type specification."
279259
)
280260

281-
hooks: Any | None = Field(
282-
default=None,
283-
description="Lifecycle event callbacks."
284-
)
261+
hooks: Any | None = Field(default=None, description="Lifecycle event callbacks.")
285262

286-
tool_use_behavior: Union[
287-
Literal["run_llm_again", "stop_on_first_tool"],
288-
StopAtTools,
289-
ToolsToFinalOutputFunction
290-
] = Field(
291-
default="run_llm_again",
292-
description="Tool usage behavior configuration."
293-
)
263+
tool_use_behavior: (
264+
Literal["run_llm_again", "stop_on_first_tool"] | StopAtTools | ToolsToFinalOutputFunction
265+
) = Field(default="run_llm_again", description="Tool usage behavior configuration.")
294266

295267
reset_tool_choice: bool = Field(
296-
default=True,
297-
description="Whether to reset tool choice after use."
268+
default=True, description="Whether to reset tool choice after use."
298269
)
299270

300271
# ... (validators remain the same but with improved docstrings) ...
301272

302-
@model_validator(mode='after')
303-
def validate_model_configuration(self) -> 'Agent':
273+
@model_validator(mode="after")
274+
def validate_model_configuration(self) -> Agent:
304275
"""
305276
Validate complete model configuration.
306-
277+
307278
This validator ensures that the model configuration is consistent,
308279
particularly the relationship between model and prompt settings.
309-
280+
310281
Returns:
311282
Agent: The validated agent instance
312-
283+
313284
Raises:
314285
ValueError: If validation fails
315286
"""
@@ -320,13 +291,13 @@ def validate_model_configuration(self) -> 'Agent':
320291
def clone(self, **kwargs: Any) -> Agent[TContext]:
321292
"""
322293
Create a copy of the agent with specified modifications.
323-
294+
324295
Args:
325296
**kwargs: Fields to override in the new instance
326-
297+
327298
Returns:
328299
Agent[TContext]: New agent instance with specified changes
329-
300+
330301
Note:
331302
This performs a shallow copy. Mutable attributes like tools and
332303
handoffs are shallow-copied.
@@ -341,15 +312,16 @@ def as_tool(
341312
) -> Tool:
342313
"""
343314
Convert this agent into a tool callable by other agents.
344-
315+
345316
Args:
346317
tool_name: Name for the tool (defaults to agent name)
347318
tool_description: Description of the tool's functionality
348319
custom_output_extractor: Custom function to extract output
349-
320+
350321
Returns:
351322
Tool: Tool instance wrapping this agent
352323
"""
324+
353325
@function_tool(
354326
name_override=tool_name or _transforms.transform_string_function_style(self.name),
355327
description_override=tool_description or "",
@@ -372,13 +344,13 @@ async def run_agent(context: RunContextWrapper, input: str) -> str:
372344
async def get_system_prompt(self, run_context: RunContextWrapper[TContext]) -> str | None:
373345
"""
374346
Get the system prompt for the agent.
375-
347+
376348
Args:
377349
run_context: Current run context
378-
350+
379351
Returns:
380352
str | None: System prompt if available
381-
353+
382354
Note:
383355
Handles both static strings and dynamic prompt generation.
384356
"""
@@ -397,10 +369,10 @@ async def get_prompt(
397369
) -> ResponsePromptParam | None:
398370
"""
399371
Get the prompt configuration for the agent.
400-
372+
401373
Args:
402374
run_context: Current run context
403-
375+
404376
Returns:
405377
ResponsePromptParam | None: Prompt configuration if available
406378
"""

0 commit comments

Comments
 (0)