
Commit 5b758bd: Add logprobs to ModelSettings (#971)
Parent: 0df7903
6 files changed: +70 -4 lines

src/agents/extensions/models/litellm_model.py

Lines changed: 1 addition & 0 deletions

@@ -321,6 +321,7 @@ async def _fetch_response(
             stream=stream,
             stream_options=stream_options,
             reasoning_effort=reasoning_effort,
+            top_logprobs=model_settings.top_logprobs,
             extra_headers={**HEADERS, **(model_settings.extra_headers or {})},
             api_key=self.api_key,
             base_url=self.base_url,
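For reference, a minimal sketch of where this kwarg lands in the underlying LiteLLM call. The model name is illustrative, and pairing with logprobs=True is an assumption: OpenAI-style chat endpoints typically require logprobs=True alongside top_logprobs, which this diff leaves to LiteLLM/provider defaults:

    import asyncio

    import litellm

    async def main() -> None:
        response = await litellm.acompletion(
            model="gpt-4o-mini",  # assumption: any chat model with logprob support
            messages=[{"role": "user", "content": "hi"}],
            logprobs=True,   # OpenAI-style APIs pair this with top_logprobs
            top_logprobs=2,  # the kwarg this commit forwards from ModelSettings
        )
        print(response.choices[0].logprobs)

    asyncio.run(main())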

src/agents/model_settings.py

Lines changed: 5 additions & 0 deletions

@@ -55,6 +55,7 @@ class MCPToolChoice:
 ToolChoice: TypeAlias = Union[Literal["auto", "required", "none"], str, MCPToolChoice, None]
 
 
+
 @dataclass
 class ModelSettings:
     """Settings to use when calling an LLM.
@@ -116,6 +117,10 @@ class ModelSettings:
     """Additional output data to include in the model response.
     [include parameter](https://platform.openai.com/docs/api-reference/responses/create#responses-create-include)"""
 
+    top_logprobs: int | None = None
+    """Number of top tokens to return logprobs for. Setting this will
+    automatically include ``"message.output_text.logprobs"`` in the response."""
+
     extra_query: Query | None = None
     """Additional query fields to provide with the request.
     Defaults to None if not provided."""
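A minimal usage sketch of the new field (agent name and prompt are illustrative; Agent and Runner are the SDK's standard entry points):

    from agents import Agent, ModelSettings, Runner

    # Ask for logprobs over the top 2 candidate tokens for each output token.
    agent = Agent(
        name="assistant",  # illustrative
        model_settings=ModelSettings(top_logprobs=2),
    )

    result = Runner.run_sync(agent, "Say hello.")
    print(result.final_output)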

src/agents/models/openai_chatcompletions.py

Lines changed: 1 addition & 0 deletions

@@ -287,6 +287,7 @@ async def _fetch_response(
             stream_options=self._non_null_or_not_given(stream_options),
             store=self._non_null_or_not_given(store),
             reasoning_effort=self._non_null_or_not_given(reasoning_effort),
+            top_logprobs=self._non_null_or_not_given(model_settings.top_logprobs),
             extra_headers={**HEADERS, **(model_settings.extra_headers or {})},
             extra_query=model_settings.extra_query,
             extra_body=model_settings.extra_body,
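The _non_null_or_not_given helper is not part of this diff; presumably it maps None to the OpenAI SDK's NOT_GIVEN sentinel so unset settings are omitted from the request rather than sent as explicit nulls. A sketch of that assumption:

    from openai import NOT_GIVEN

    def _non_null_or_not_given(value):
        # Assumed behavior: drop the parameter entirely when unset.
        return value if value is not None else NOT_GIVEN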

src/agents/models/openai_responses.py

Lines changed: 11 additions & 4 deletions

@@ -3,7 +3,7 @@
 import json
 from collections.abc import AsyncIterator
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Literal, overload
+from typing import TYPE_CHECKING, Any, Literal, cast, overload
 
 from openai import NOT_GIVEN, APIStatusError, AsyncOpenAI, AsyncStream, NotGiven
 from openai.types import ChatModel
@@ -247,9 +247,12 @@ async def _fetch_response(
         converted_tools = Converter.convert_tools(tools, handoffs)
         response_format = Converter.get_response_format(output_schema)
 
-        include: list[ResponseIncludable] = converted_tools.includes
+        include_set: set[str] = set(converted_tools.includes)
         if model_settings.response_include is not None:
-            include = list({*include, *model_settings.response_include})
+            include_set.update(model_settings.response_include)
+        if model_settings.top_logprobs is not None:
+            include_set.add("message.output_text.logprobs")
+        include = cast(list[ResponseIncludable], list(include_set))
 
         if _debug.DONT_LOG_MODEL_DATA:
             logger.debug("Calling LLM")
@@ -264,6 +267,10 @@ async def _fetch_response(
                 f"Previous response id: {previous_response_id}\n"
             )
 
+        extra_args = dict(model_settings.extra_args or {})
+        if model_settings.top_logprobs is not None:
+            extra_args["top_logprobs"] = model_settings.top_logprobs
+
         return await self._client.responses.create(
             previous_response_id=self._non_null_or_not_given(previous_response_id),
             instructions=self._non_null_or_not_given(system_instructions),
@@ -286,7 +293,7 @@ async def _fetch_response(
             store=self._non_null_or_not_given(model_settings.store),
             reasoning=self._non_null_or_not_given(model_settings.reasoning),
             metadata=self._non_null_or_not_given(model_settings.metadata),
-            **(model_settings.extra_args or {}),
+            **extra_args,
         )
 
     def _get_client(self) -> AsyncOpenAI:
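The include handling above collects includes into a set so duplicates are dropped, adds "message.output_text.logprobs" whenever top_logprobs is set, and casts back to the list type the SDK expects. A standalone sketch of the same merge logic, with plain strings standing in for ResponseIncludable values:

    def merge_includes(
        tool_includes: list[str],
        response_include: list[str] | None,
        top_logprobs: int | None,
    ) -> list[str]:
        include_set = set(tool_includes)
        if response_include is not None:
            include_set.update(response_include)
        if top_logprobs is not None:
            # Setting top_logprobs implies requesting logprobs in the output.
            include_set.add("message.output_text.logprobs")
        return list(include_set)

    assert "message.output_text.logprobs" in merge_includes([], None, top_logprobs=2)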

tests/model_settings/test_serialization.py

Lines changed: 2 additions & 0 deletions

@@ -58,6 +58,7 @@ def test_all_fields_serialization() -> None:
        store=False,
        include_usage=False,
        response_include=["reasoning.encrypted_content"],
+        top_logprobs=1,
        extra_query={"foo": "bar"},
        extra_body={"foo": "bar"},
        extra_headers={"foo": "bar"},
@@ -164,6 +165,7 @@ def test_pydantic_serialization() -> None:
        metadata={"foo": "bar"},
        store=False,
        include_usage=False,
+        top_logprobs=1,
        extra_query={"foo": "bar"},
        extra_body={"foo": "bar"},
        extra_headers={"foo": "bar"},

tests/test_logprobs.py

Lines changed: 50 additions & 0 deletions

@@ -0,0 +1,50 @@
+import pytest
+from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
+
+from agents import ModelSettings, ModelTracing, OpenAIResponsesModel
+
+
+class DummyResponses:
+    async def create(self, **kwargs):
+        self.kwargs = kwargs
+
+        class DummyResponse:
+            id = "dummy"
+            output = []
+            usage = type(
+                "Usage",
+                (),
+                {
+                    "input_tokens": 0,
+                    "output_tokens": 0,
+                    "total_tokens": 0,
+                    "input_tokens_details": InputTokensDetails(cached_tokens=0),
+                    "output_tokens_details": OutputTokensDetails(reasoning_tokens=0),
+                },
+            )()
+
+        return DummyResponse()
+
+
+class DummyClient:
+    def __init__(self):
+        self.responses = DummyResponses()
+
+
+@pytest.mark.allow_call_model_methods
+@pytest.mark.asyncio
+async def test_top_logprobs_param_passed():
+    client = DummyClient()
+    model = OpenAIResponsesModel(model="gpt-4", openai_client=client)  # type: ignore
+    await model.get_response(
+        system_instructions=None,
+        input="hi",
+        model_settings=ModelSettings(top_logprobs=2),
+        tools=[],
+        output_schema=None,
+        handoffs=[],
+        tracing=ModelTracing.DISABLED,
+        previous_response_id=None,
+    )
+    assert client.responses.kwargs["top_logprobs"] == 2
+    assert "message.output_text.logprobs" in client.responses.kwargs["include"]
