Merged
25 changes: 20 additions & 5 deletions src/agents/extensions/models/litellm_model.py
@@ -48,6 +48,7 @@
from ...tracing.span_data import GenerationSpanData
from ...tracing.spans import Span
from ...usage import Usage
from ...util._json import _to_dump_compatible


class InternalChatCompletionMessage(ChatCompletionMessage):
@@ -265,6 +266,8 @@ async def _fetch_response(
"role": "system",
},
)
converted_messages = _to_dump_compatible(converted_messages)

if tracing.include_data():
span.span_data.input = converted_messages

@@ -283,13 +286,25 @@
for handoff in handoffs:
converted_tools.append(Converter.convert_handoff_tool(handoff))

converted_tools = _to_dump_compatible(converted_tools)

if _debug.DONT_LOG_MODEL_DATA:
logger.debug("Calling LLM")
else:
messages_json = json.dumps(
converted_messages,
indent=2,
ensure_ascii=False,
)
tools_json = json.dumps(
converted_tools,
indent=2,
ensure_ascii=False,
)
logger.debug(
f"Calling Litellm model: {self.model}\n"
f"{json.dumps(converted_messages, indent=2, ensure_ascii=False)}\n"
f"Tools:\n{json.dumps(converted_tools, indent=2, ensure_ascii=False)}\n"
f"{messages_json}\n"
f"Tools:\n{tools_json}\n"
f"Stream: {stream}\n"
f"Tool choice: {tool_choice}\n"
f"Response format: {response_format}\n"
@@ -369,9 +384,9 @@ def convert_message_to_openai(
if message.role != "assistant":
raise ModelBehaviorError(f"Unsupported role: {message.role}")

-        tool_calls: list[
-            ChatCompletionMessageFunctionToolCall | ChatCompletionMessageCustomToolCall
-        ] | None = (
+        tool_calls: (
+            list[ChatCompletionMessageFunctionToolCall | ChatCompletionMessageCustomToolCall] | None
+        ) = (
[LitellmConverter.convert_tool_call_to_openai(tool) for tool in message.tool_calls]
if message.tool_calls
else None
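The materialization above is what keeps the debug log from breaking the request: `json.dumps` cannot serialize a generator at all, so a lazy content part would crash the log line before the API call was ever made. A minimal stdlib-only sketch of that failure mode (names are illustrative):

import json

def lazy_content():
    # e.g. content parts produced on the fly by an upstream converter
    yield {"type": "text", "text": "hi"}

message = {"role": "user", "content": lazy_content()}

try:
    json.dumps(message, indent=2, ensure_ascii=False)
except TypeError as exc:
    print(exc)  # "Object of type generator is not JSON serializable"

Running `converted_messages` through `_to_dump_compatible` first copies the generator into a plain list, so the same object can be logged and then sent.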
19 changes: 17 additions & 2 deletions src/agents/models/openai_chatcompletions.py
@@ -23,6 +23,7 @@
from ..tracing.span_data import GenerationSpanData
from ..tracing.spans import Span
from ..usage import Usage
from ..util._json import _to_dump_compatible
from .chatcmpl_converter import Converter
from .chatcmpl_helpers import HEADERS, ChatCmplHelpers
from .chatcmpl_stream_handler import ChatCmplStreamHandler
@@ -237,6 +238,8 @@ async def _fetch_response(
"role": "system",
},
)
converted_messages = _to_dump_compatible(converted_messages)

if tracing.include_data():
span.span_data.input = converted_messages

@@ -255,12 +258,24 @@
for handoff in handoffs:
converted_tools.append(Converter.convert_handoff_tool(handoff))

converted_tools = _to_dump_compatible(converted_tools)

if _debug.DONT_LOG_MODEL_DATA:
logger.debug("Calling LLM")
else:
messages_json = json.dumps(
converted_messages,
indent=2,
ensure_ascii=False,
)
tools_json = json.dumps(
converted_tools,
indent=2,
ensure_ascii=False,
)
logger.debug(
f"{json.dumps(converted_messages, indent=2, ensure_ascii=False)}\n"
f"Tools:\n{json.dumps(converted_tools, indent=2, ensure_ascii=False)}\n"
f"{messages_json}\n"
f"Tools:\n{tools_json}\n"
f"Stream: {stream}\n"
f"Tool choice: {tool_choice}\n"
f"Response format: {response_format}\n"
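Once materialized, the payload is plain data, so it can be serialized any number of times: once for the tracing span, once for the debug log, and the identical object still goes out in the request. A small sketch of that reuse, assuming the package layout from this PR (`_to_dump_compatible` is a private helper, imported here only for demonstration):

import json
from agents.util._json import _to_dump_compatible

def parts():
    yield {"type": "text", "text": "hello"}

messages = _to_dump_compatible([{"role": "user", "content": parts()}])

# Repeated serialization is now safe; nothing is consumed between dumps.
first = json.dumps(messages, indent=2, ensure_ascii=False)
second = json.dumps(messages, indent=2, ensure_ascii=False)
assert first == second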
19 changes: 16 additions & 3 deletions src/agents/models/openai_responses.py
@@ -38,6 +38,7 @@
)
from ..tracing import SpanError, response_span
from ..usage import Usage
from ..util._json import _to_dump_compatible
from ..version import __version__
from .interface import Model, ModelTracing

@@ -240,6 +241,7 @@ async def _fetch_response(
prompt: ResponsePromptParam | None = None,
) -> Response | AsyncStream[ResponseStreamEvent]:
list_input = ItemHelpers.input_to_new_input_list(input)
list_input = _to_dump_compatible(list_input)

parallel_tool_calls = (
True
@@ -251,6 +253,7 @@

tool_choice = Converter.convert_tool_choice(model_settings.tool_choice)
converted_tools = Converter.convert_tools(tools, handoffs)
converted_tools_payload = _to_dump_compatible(converted_tools.tools)
response_format = Converter.get_response_format(output_schema)

include_set: set[str] = set(converted_tools.includes)
@@ -263,10 +266,20 @@
if _debug.DONT_LOG_MODEL_DATA:
logger.debug("Calling LLM")
else:
input_json = json.dumps(
list_input,
indent=2,
ensure_ascii=False,
)
tools_json = json.dumps(
converted_tools_payload,
indent=2,
ensure_ascii=False,
)
logger.debug(
f"Calling LLM {self.model} with input:\n"
f"{json.dumps(list_input, indent=2, ensure_ascii=False)}\n"
f"Tools:\n{json.dumps(converted_tools.tools, indent=2, ensure_ascii=False)}\n"
f"{input_json}\n"
f"Tools:\n{tools_json}\n"
f"Stream: {stream}\n"
f"Tool choice: {tool_choice}\n"
f"Response format: {response_format}\n"
@@ -290,7 +303,7 @@
model=self.model,
input=list_input,
include=include,
-            tools=converted_tools.tools,
+            tools=converted_tools_payload,
prompt=self._non_null_or_not_given(prompt),
temperature=self._non_null_or_not_given(model_settings.temperature),
top_p=self._non_null_or_not_given(model_settings.top_p),
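In this file the materialized `converted_tools_payload` backs both the debug log and the `responses.create` call, so an iterable that can only be walked once is drained exactly once, inside `_to_dump_compatible`. A sketch of that invariant (the `SingleUse` class is illustrative, mirroring the test helper added below):

import json

from agents.util._json import _to_dump_compatible

class SingleUse:
    """Iterable that raises if walked more than once."""

    def __init__(self, values):
        self._values = list(values)
        self._consumed = False

    def __iter__(self):
        if self._consumed:
            raise RuntimeError("iterated twice")
        self._consumed = True
        return iter(self._values)

tools = [{"name": "dummy", "parameters": {"properties": SingleUse([{"type": "string"}])}}]
payload = _to_dump_compatible(tools)

# The single permitted walk happened inside _to_dump_compatible; the payload
# is plain lists now, so logging and the request body cannot conflict.
json.dumps(payload, indent=2)
json.dumps(payload, indent=2)
assert isinstance(payload[0]["parameters"]["properties"], list)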
20 changes: 19 additions & 1 deletion src/agents/util/_json.py
@@ -1,6 +1,7 @@
from __future__ import annotations

-from typing import Literal
+from collections.abc import Iterable
+from typing import Any, Literal

from pydantic import TypeAdapter, ValidationError
from typing_extensions import TypeVar
@@ -29,3 +30,20 @@ def validate_json(json_str: str, type_adapter: TypeAdapter[T], partial: bool) ->
raise ModelBehaviorError(
f"Invalid JSON when parsing {json_str} for {type_adapter}; {e}"
) from e


def _to_dump_compatible(obj: Any) -> Any:
return _to_dump_compatible_internal(obj)


def _to_dump_compatible_internal(obj: Any) -> Any:
if isinstance(obj, dict):
return {k: _to_dump_compatible_internal(v) for k, v in obj.items()}

if isinstance(obj, (list, tuple)):
return [_to_dump_compatible_internal(x) for x in obj]

if isinstance(obj, Iterable) and not isinstance(obj, (str, bytes, bytearray)):
return [_to_dump_compatible_internal(x) for x in obj]

return obj
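The helper recurses through dicts and lists, copies tuples and any other non-string iterable into a plain list, and passes scalars through unchanged; strings, bytes, and bytearrays are deliberately exempt from the iterable rule. A quick demonstration of those rules (importing the private helper purely for illustration):

from agents.util._json import _to_dump_compatible

result = _to_dump_compatible(
    {
        "a": (1, 2),             # tuple -> list
        "b": iter([3, 4]),       # one-shot iterator -> materialized list
        "c": "text",             # strings are not treated as iterables
        "d": b"bytes",           # neither are bytes / bytearray
        "e": [{"nested": (5,)}], # recursion applies at every level
    }
)
assert result == {
    "a": [1, 2],
    "b": [3, 4],
    "c": "text",
    "d": b"bytes",
    "e": [{"nested": [5]}],
}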
187 changes: 187 additions & 0 deletions tests/test_model_payload_iterators.py
@@ -0,0 +1,187 @@
from __future__ import annotations

from collections.abc import Iterable, Iterator
from typing import Any, cast

import httpx
import pytest
from openai import NOT_GIVEN
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.responses import ToolParam

from agents import (
ModelSettings,
ModelTracing,
OpenAIChatCompletionsModel,
OpenAIResponsesModel,
generation_span,
)
from agents.models import (
openai_chatcompletions as chat_module,
openai_responses as responses_module,
)


class _SingleUseIterable:
"""Helper iterable that raises if iterated more than once."""

def __init__(self, values: list[object]) -> None:
self._values = list(values)
self.iterations = 0

def __iter__(self) -> Iterator[object]:
if self.iterations:
raise RuntimeError("Iterable should have been materialized exactly once.")
self.iterations += 1
yield from self._values


def _force_materialization(value: object) -> None:
if isinstance(value, dict):
for nested in value.values():
_force_materialization(nested)
elif isinstance(value, list):
for nested in value:
_force_materialization(nested)
elif isinstance(value, Iterable) and not isinstance(value, (str, bytes, bytearray)):
list(value)


@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_chat_completions_materializes_iterator_payload(
monkeypatch: pytest.MonkeyPatch,
) -> None:
message_iter = _SingleUseIterable([{"type": "text", "text": "hi"}])
tool_iter = _SingleUseIterable([{"type": "string"}])

chat_converter = cast(Any, chat_module).Converter

monkeypatch.setattr(
chat_converter,
"items_to_messages",
classmethod(lambda _cls, _input: [{"role": "user", "content": message_iter}]),
)
monkeypatch.setattr(
chat_converter,
"tool_to_openai",
classmethod(
lambda _cls, _tool: {
"type": "function",
"function": {
"name": "dummy",
"parameters": {"properties": tool_iter},
},
}
),
)

captured_kwargs: dict[str, Any] = {}

class DummyCompletions:
async def create(self, **kwargs):
captured_kwargs.update(kwargs)
_force_materialization(kwargs["messages"])
if kwargs["tools"] is not NOT_GIVEN:
_force_materialization(kwargs["tools"])
return ChatCompletion(
id="dummy-id",
created=0,
model="gpt-4",
object="chat.completion",
choices=[],
usage=None,
)

class DummyClient:
def __init__(self) -> None:
self.chat = type("_Chat", (), {"completions": DummyCompletions()})()
self.base_url = httpx.URL("http://example.test")

model = OpenAIChatCompletionsModel(model="gpt-4", openai_client=DummyClient()) # type: ignore[arg-type]

with generation_span(disabled=True) as span:
await cast(Any, model)._fetch_response(
system_instructions=None,
input="ignored",
model_settings=ModelSettings(),
tools=[object()],
output_schema=None,
handoffs=[],
span=span,
tracing=ModelTracing.DISABLED,
stream=False,
)

assert message_iter.iterations == 1
assert tool_iter.iterations == 1
assert isinstance(captured_kwargs["messages"][0]["content"], list)
assert isinstance(captured_kwargs["tools"][0]["function"]["parameters"]["properties"], list)


@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_responses_materializes_iterator_payload(monkeypatch: pytest.MonkeyPatch) -> None:
input_iter = _SingleUseIterable([{"type": "input_text", "text": "hello"}])
tool_iter = _SingleUseIterable([{"type": "string"}])

responses_item_helpers = cast(Any, responses_module).ItemHelpers
responses_converter = cast(Any, responses_module).Converter

monkeypatch.setattr(
responses_item_helpers,
"input_to_new_input_list",
classmethod(lambda _cls, _input: [{"role": "user", "content": input_iter}]),
)

converted_tools = responses_module.ConvertedTools(
tools=cast(
list[ToolParam],
[
{
"type": "function",
"name": "dummy",
"parameters": {"properties": tool_iter},
}
],
),
includes=[],
)
monkeypatch.setattr(
responses_converter,
"convert_tools",
classmethod(lambda _cls, _tools, _handoffs: converted_tools),
)

captured_kwargs: dict[str, Any] = {}

class DummyResponses:
async def create(self, **kwargs):
captured_kwargs.update(kwargs)
_force_materialization(kwargs["input"])
_force_materialization(kwargs["tools"])
return object()

class DummyClient:
def __init__(self) -> None:
self.responses = DummyResponses()

model = OpenAIResponsesModel(model="gpt-4.1", openai_client=DummyClient()) # type: ignore[arg-type]

await cast(Any, model)._fetch_response(
system_instructions=None,
input="ignored",
model_settings=ModelSettings(),
tools=[],
output_schema=None,
handoffs=[],
previous_response_id=None,
conversation_id=None,
stream=False,
prompt=None,
)

assert input_iter.iterations == 1
assert tool_iter.iterations == 1
assert isinstance(captured_kwargs["input"][0]["content"], list)
assert isinstance(captured_kwargs["tools"][0]["parameters"]["properties"], list)