
Commit e81bbbc

Merge branch 'main' into issue-1750-run-hooks
2 parents 61bbbe7 + a4c125e commit e81bbbc

17 files changed: +563 -25 lines

examples/basic/tools.py

Lines changed: 6 additions & 6 deletions
@@ -1,23 +1,23 @@
 import asyncio
+from typing import Annotated

-from pydantic import BaseModel
+from pydantic import BaseModel, Field

 from agents import Agent, Runner, function_tool


 class Weather(BaseModel):
-    city: str
-    temperature_range: str
-    conditions: str
+    city: str = Field(description="The city name")
+    temperature_range: str = Field(description="The temperature range in Celsius")
+    conditions: str = Field(description="The weather conditions")


 @function_tool
-def get_weather(city: str) -> Weather:
+def get_weather(city: Annotated[str, "The city to get the weather for"]) -> Weather:
     """Get the current weather information for a specified city."""
     print("[debug] get_weather called")
     return Weather(city=city, temperature_range="14-20C", conditions="Sunny with wind.")

-
 agent = Agent(
     name="Hello world",
     instructions="You are a helpful agent.",
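The Annotated description on the city parameter is consumed by the function_schema changes later in this commit. A quick way to verify the effect (a sketch, assuming the FunctionTool returned by @function_tool exposes params_json_schema):

# Hypothetical check, run after importing the example above:
print(get_weather.params_json_schema["properties"]["city"]["description"])
# Expected: "The city to get the weather for"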

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "openai-agents"
-version = "0.3.0"
+version = "0.3.1"
 description = "OpenAI Agents SDK"
 readme = "README.md"
 requires-python = ">=3.9"

src/agents/agent.py

Lines changed: 18 additions & 2 deletions
@@ -30,9 +30,11 @@
 from .util._types import MaybeAwaitable

 if TYPE_CHECKING:
-    from .lifecycle import AgentHooks
+    from .lifecycle import AgentHooks, RunHooks
     from .mcp import MCPServer
+    from .memory.session import Session
     from .result import RunResult
+    from .run import RunConfig


 @dataclass
@@ -384,6 +386,12 @@ def as_tool(
         custom_output_extractor: Callable[[RunResult], Awaitable[str]] | None = None,
         is_enabled: bool
         | Callable[[RunContextWrapper[Any], AgentBase[Any]], MaybeAwaitable[bool]] = True,
+        run_config: RunConfig | None = None,
+        max_turns: int | None = None,
+        hooks: RunHooks[TContext] | None = None,
+        previous_response_id: str | None = None,
+        conversation_id: str | None = None,
+        session: Session | None = None,
     ) -> Tool:
         """Transform this agent into a tool, callable by other agents.

@@ -410,12 +418,20 @@ def as_tool(
             is_enabled=is_enabled,
         )
         async def run_agent(context: RunContextWrapper, input: str) -> str:
-            from .run import Runner
+            from .run import DEFAULT_MAX_TURNS, Runner
+
+            resolved_max_turns = max_turns if max_turns is not None else DEFAULT_MAX_TURNS

             output = await Runner.run(
                 starting_agent=self,
                 input=input,
                 context=context.context,
+                run_config=run_config,
+                max_turns=resolved_max_turns,
+                hooks=hooks,
+                previous_response_id=previous_response_id,
+                conversation_id=conversation_id,
+                session=session,
             )
             if custom_output_extractor:
                 return await custom_output_extractor(output)
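The new as_tool() keyword arguments are forwarded to Runner.run for the nested run. A minimal usage sketch, assuming the SDK's existing RunConfig and SQLiteSession types (the agent names and values are hypothetical):

from agents import Agent, RunConfig, SQLiteSession

billing_agent = Agent(name="Billing agent", instructions="Handle billing questions.")

billing_tool = billing_agent.as_tool(
    tool_name="billing_agent",
    tool_description="Answer billing questions.",
    run_config=RunConfig(workflow_name="billing-as-tool"),  # forwarded to Runner.run
    max_turns=5,                               # falls back to DEFAULT_MAX_TURNS when None
    session=SQLiteSession("billing-session"),  # session memory for the nested run
)

triage_agent = Agent(
    name="Triage agent",
    instructions="Route billing questions to the billing tool.",
    tools=[billing_tool],
)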

src/agents/extensions/handoff_filters.py

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,7 @@
 from ..items import (
     HandoffCallItem,
     HandoffOutputItem,
+    ReasoningItem,
     RunItem,
     ToolCallItem,
     ToolCallOutputItem,
@@ -41,6 +42,7 @@ def _remove_tools_from_items(items: tuple[RunItem, ...]) -> tuple[RunItem, ...]:
             or isinstance(item, HandoffOutputItem)
             or isinstance(item, ToolCallItem)
             or isinstance(item, ToolCallOutputItem)
+            or isinstance(item, ReasoningItem)
         ):
             continue
         filtered_items.append(item)
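With this change, remove_all_tools also drops reasoning items from the handoff history. A minimal sketch of where the filter is typically attached (the handoff() input_filter hook is existing SDK API, not part of this diff):

from agents import Agent, handoff
from agents.extensions import handoff_filters

support_agent = Agent(name="Support agent", instructions="Help the user.")

support_handoff = handoff(
    agent=support_agent,
    # Strips tool calls, tool outputs, and now reasoning items from the handoff input.
    input_filter=handoff_filters.remove_all_tools,
)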

src/agents/extensions/models/litellm_model.py

Lines changed: 9 additions & 2 deletions
@@ -39,7 +39,7 @@
 from ...logger import logger
 from ...model_settings import ModelSettings
 from ...models.chatcmpl_converter import Converter
-from ...models.chatcmpl_helpers import HEADERS
+from ...models.chatcmpl_helpers import HEADERS, USER_AGENT_OVERRIDE
 from ...models.chatcmpl_stream_handler import ChatCmplStreamHandler
 from ...models.fake_id import FAKE_RESPONSES_ID
 from ...models.interface import Model, ModelTracing
@@ -353,7 +353,7 @@ async def _fetch_response(
             stream_options=stream_options,
             reasoning_effort=reasoning_effort,
             top_logprobs=model_settings.top_logprobs,
-            extra_headers={**HEADERS, **(model_settings.extra_headers or {})},
+            extra_headers=self._merge_headers(model_settings),
             api_key=self.api_key,
             base_url=self.base_url,
             **extra_kwargs,
@@ -384,6 +384,13 @@ def _remove_not_given(self, value: Any) -> Any:
             return None
         return value

+    def _merge_headers(self, model_settings: ModelSettings):
+        merged = {**HEADERS, **(model_settings.extra_headers or {})}
+        ua_ctx = USER_AGENT_OVERRIDE.get()
+        if ua_ctx is not None:
+            merged["User-Agent"] = ua_ctx
+        return merged
+

 class LitellmConverter:
     @classmethod

src/agents/function_schema.py

Lines changed: 45 additions & 3 deletions
@@ -5,7 +5,7 @@
 import logging
 import re
 from dataclasses import dataclass
-from typing import Any, Callable, Literal, get_args, get_origin, get_type_hints
+from typing import Annotated, Any, Callable, Literal, get_args, get_origin, get_type_hints

 from griffe import Docstring, DocstringSectionKind
 from pydantic import BaseModel, Field, create_model
@@ -185,6 +185,31 @@ def generate_func_documentation(
     )


+def _strip_annotated(annotation: Any) -> tuple[Any, tuple[Any, ...]]:
+    """Returns the underlying annotation and any metadata from typing.Annotated."""
+
+    metadata: tuple[Any, ...] = ()
+    ann = annotation
+
+    while get_origin(ann) is Annotated:
+        args = get_args(ann)
+        if not args:
+            break
+        ann = args[0]
+        metadata = (*metadata, *args[1:])
+
+    return ann, metadata
+
+
+def _extract_description_from_metadata(metadata: tuple[Any, ...]) -> str | None:
+    """Extracts a human readable description from Annotated metadata if present."""
+
+    for item in metadata:
+        if isinstance(item, str):
+            return item
+    return None
+
+
 def function_schema(
     func: Callable[..., Any],
     docstring_style: DocstringStyle | None = None,
@@ -219,17 +244,34 @@
     # 1. Grab docstring info
     if use_docstring_info:
         doc_info = generate_func_documentation(func, docstring_style)
-        param_descs = doc_info.param_descriptions or {}
+        param_descs = dict(doc_info.param_descriptions or {})
     else:
         doc_info = None
         param_descs = {}

+    type_hints_with_extras = get_type_hints(func, include_extras=True)
+    type_hints: dict[str, Any] = {}
+    annotated_param_descs: dict[str, str] = {}
+
+    for name, annotation in type_hints_with_extras.items():
+        if name == "return":
+            continue
+
+        stripped_ann, metadata = _strip_annotated(annotation)
+        type_hints[name] = stripped_ann
+
+        description = _extract_description_from_metadata(metadata)
+        if description is not None:
+            annotated_param_descs[name] = description
+
+    for name, description in annotated_param_descs.items():
+        param_descs.setdefault(name, description)
+
     # Ensure name_override takes precedence even if docstring info is disabled.
     func_name = name_override or (doc_info.name if doc_info else func.__name__)

     # 2. Inspect function signature and get type hints
     sig = inspect.signature(func)
-    type_hints = get_type_hints(func)
     params = list(sig.parameters.items())
     takes_context = False
     filtered_params = []
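Net effect: Annotated metadata now feeds parameter descriptions, with docstring descriptions taking precedence via setdefault. A small sketch of the behavior, assuming function_schema is called directly from agents.function_schema:

from typing import Annotated

from agents.function_schema import function_schema


def get_weather(city: Annotated[str, "The city to get the weather for"]) -> str:
    return f"Weather for {city}"


schema = function_schema(get_weather)
print(schema.params_json_schema["properties"]["city"]["description"])
# Expected: "The city to get the weather for"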

src/agents/models/chatcmpl_helpers.py

Lines changed: 6 additions & 0 deletions
@@ -1,5 +1,7 @@
 from __future__ import annotations

+from contextvars import ContextVar
+
 from openai import AsyncOpenAI

 from ..model_settings import ModelSettings
@@ -8,6 +10,10 @@
 _USER_AGENT = f"Agents/Python {__version__}"
 HEADERS = {"User-Agent": _USER_AGENT}

+USER_AGENT_OVERRIDE: ContextVar[str | None] = ContextVar(
+    "openai_chatcompletions_user_agent_override", default=None
+)
+

 class ChatCmplHelpers:
     @classmethod
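USER_AGENT_OVERRIDE lets an integration replace the default Agents/Python User-Agent for Chat Completions requests made within a context. A minimal sketch using standard contextvars set/reset (the override value is hypothetical):

from agents.models.chatcmpl_helpers import USER_AGENT_OVERRIDE

token = USER_AGENT_OVERRIDE.set("MyIntegration/1.2.3")
try:
    ...  # requests issued here send "User-Agent: MyIntegration/1.2.3"
finally:
    USER_AGENT_OVERRIDE.reset(token)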

src/agents/models/openai_chatcompletions.py

Lines changed: 9 additions & 2 deletions
@@ -25,7 +25,7 @@
 from ..usage import Usage
 from ..util._json import _to_dump_compatible
 from .chatcmpl_converter import Converter
-from .chatcmpl_helpers import HEADERS, ChatCmplHelpers
+from .chatcmpl_helpers import HEADERS, USER_AGENT_OVERRIDE, ChatCmplHelpers
 from .chatcmpl_stream_handler import ChatCmplStreamHandler
 from .fake_id import FAKE_RESPONSES_ID
 from .interface import Model, ModelTracing
@@ -306,7 +306,7 @@ async def _fetch_response(
             reasoning_effort=self._non_null_or_not_given(reasoning_effort),
             verbosity=self._non_null_or_not_given(model_settings.verbosity),
             top_logprobs=self._non_null_or_not_given(model_settings.top_logprobs),
-            extra_headers={**HEADERS, **(model_settings.extra_headers or {})},
+            extra_headers=self._merge_headers(model_settings),
             extra_query=model_settings.extra_query,
             extra_body=model_settings.extra_body,
             metadata=self._non_null_or_not_given(model_settings.metadata),
@@ -349,3 +349,10 @@ def _get_client(self) -> AsyncOpenAI:
         if self._client is None:
             self._client = AsyncOpenAI()
         return self._client
+
+    def _merge_headers(self, model_settings: ModelSettings):
+        merged = {**HEADERS, **(model_settings.extra_headers or {})}
+        ua_ctx = USER_AGENT_OVERRIDE.get()
+        if ua_ctx is not None:
+            merged["User-Agent"] = ua_ctx
+        return merged

src/agents/models/openai_responses.py

Lines changed: 14 additions & 1 deletion
@@ -2,6 +2,7 @@

 import json
 from collections.abc import AsyncIterator
+from contextvars import ContextVar
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Literal, cast, overload

@@ -49,6 +50,11 @@
 _USER_AGENT = f"Agents/Python {__version__}"
 _HEADERS = {"User-Agent": _USER_AGENT}

+# Override for the User-Agent header used by the Responses API.
+_USER_AGENT_OVERRIDE: ContextVar[str | None] = ContextVar(
+    "openai_responses_user_agent_override", default=None
+)
+

 class OpenAIResponsesModel(Model):
     """
@@ -312,7 +318,7 @@ async def _fetch_response(
             tool_choice=tool_choice,
             parallel_tool_calls=parallel_tool_calls,
             stream=stream,
-            extra_headers={**_HEADERS, **(model_settings.extra_headers or {})},
+            extra_headers=self._merge_headers(model_settings),
             extra_query=model_settings.extra_query,
             extra_body=model_settings.extra_body,
             text=response_format,
@@ -327,6 +333,13 @@ def _get_client(self) -> AsyncOpenAI:
             self._client = AsyncOpenAI()
         return self._client

+    def _merge_headers(self, model_settings: ModelSettings):
+        merged = {**_HEADERS, **(model_settings.extra_headers or {})}
+        ua_ctx = _USER_AGENT_OVERRIDE.get()
+        if ua_ctx is not None:
+            merged["User-Agent"] = ua_ctx
+        return merged
+

 @dataclass
 class ConvertedTools:

src/agents/util/_transforms.py

Lines changed: 12 additions & 2 deletions
@@ -1,11 +1,21 @@
 import re

+from ..logger import logger
+

 def transform_string_function_style(name: str) -> str:
     # Replace spaces with underscores
     name = name.replace(" ", "_")

     # Replace non-alphanumeric characters with underscores
-    name = re.sub(r"[^a-zA-Z0-9]", "_", name)
+    transformed_name = re.sub(r"[^a-zA-Z0-9_]", "_", name)
+
+    if transformed_name != name:
+        final_name = transformed_name.lower()
+        logger.warning(
+            f"Tool name {name!r} contains invalid characters for function calling and has been "
+            f"transformed to {final_name!r}. Please use only letters, digits, and underscores "
+            "to avoid potential naming conflicts."
+        )

-    return name.lower()
+    return transformed_name.lower()
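The widened character class keeps underscores intact, so only genuinely invalid names trigger the new warning. Expected behavior as a sketch:

from agents.util._transforms import transform_string_function_style

transform_string_function_style("Get Weather!")  # -> "get_weather_", logs a warning
transform_string_function_style("get_weather")   # -> "get_weather", no warning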
