Skip to content

Commit c913454

Browse files
authored
Allow modifying the input sent to the model (#1483)
Some customers have reported that the agent loop can go on for a long time and use up the entire context window. This PR allows modifying the data sent to the model.
1 parent ee12d42 commit c913454

File tree

3 files changed

+281
-6
lines changed

3 files changed

+281
-6
lines changed

src/agents/run.py

Lines changed: 95 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import copy
55
import inspect
66
from dataclasses import dataclass, field
7-
from typing import Any, Generic, cast
7+
from typing import Any, Callable, Generic, cast
88

99
from openai.types.responses import ResponseCompletedEvent
1010
from openai.types.responses.response_prompt_param import (
@@ -56,6 +56,7 @@
5656
from .tracing.span_data import AgentSpanData
5757
from .usage import Usage
5858
from .util import _coro, _error_tracing
59+
from .util._types import MaybeAwaitable
5960

6061
DEFAULT_MAX_TURNS = 10
6162

@@ -81,6 +82,27 @@ def get_default_agent_runner() -> AgentRunner:
8182
return DEFAULT_AGENT_RUNNER
8283

8384

85+
@dataclass
class ModelInputData:
    """Container for the data that will be sent to the model.

    An instance of this is wrapped in a `CallModelData` and handed to
    `RunConfig.call_model_input_filter`, which may return a (possibly
    modified) `ModelInputData` to use for the actual model call.
    """

    # Input items for the upcoming model call (the accumulated run input).
    input: list[TResponseInputItem]
    # System instructions (system prompt) for the call; None when the agent has none.
    instructions: str | None
91+
92+
93+
@dataclass
class CallModelData(Generic[TContext]):
    """Data passed to `RunConfig.call_model_input_filter` prior to model call."""

    # The instructions + input items about to be sent to the model; the filter
    # returns a `ModelInputData` (this one, or a replacement) to use instead.
    model_data: ModelInputData
    # The agent whose turn is about to run.
    agent: Agent[TContext]
    # The user-provided run context, if any (taken from the RunContextWrapper).
    context: TContext | None
100+
101+
102+
# Type alias for the optional input filter callback: receives a `CallModelData`
# and returns a `ModelInputData`, either directly or as an awaitable (both sync
# and async callables are accepted).
CallModelInputFilter = Callable[[CallModelData[Any]], MaybeAwaitable[ModelInputData]]
104+
105+
84106
@dataclass
85107
class RunConfig:
86108
"""Configures settings for the entire agent run."""
@@ -139,6 +161,16 @@ class RunConfig:
139161
An optional dictionary of additional metadata to include with the trace.
140162
"""
141163

164+
call_model_input_filter: CallModelInputFilter | None = None
165+
"""
166+
Optional callback that is invoked immediately before calling the model. It receives the current
167+
agent, context and the model input (instructions and input items), and must return a possibly
168+
modified `ModelInputData` to use for the model call.
169+
170+
This allows you to edit the input sent to the model e.g. to stay within a token limit.
171+
For example, you can use this to add a system prompt to the input.
172+
"""
173+
142174

143175
class RunOptions(TypedDict, Generic[TContext]):
144176
"""Arguments for ``AgentRunner`` methods."""
@@ -593,6 +625,47 @@ def run_streamed(
593625
)
594626
return streamed_result
595627

628+
@classmethod
async def _maybe_filter_model_input(
    cls,
    *,
    agent: Agent[TContext],
    run_config: RunConfig,
    context_wrapper: RunContextWrapper[TContext],
    input_items: list[TResponseInputItem],
    system_instructions: str | None,
) -> ModelInputData:
    """Apply the optional `call_model_input_filter` to modify model input.

    Returns the `ModelInputData` that will be sent to the model. When no
    filter is configured, the original items/instructions are returned
    untouched; otherwise the filter receives a deep copy of the input items
    and must return a `ModelInputData` instance.
    """
    filter_fn = run_config.call_model_input_filter
    if filter_fn is None:
        # No filter configured: pass the data through unchanged.
        return ModelInputData(input=input_items, instructions=system_instructions)

    try:
        # Hand the filter a deep copy of the items so it cannot mutate the
        # run's own input list in place.
        payload: CallModelData[TContext] = CallModelData(
            model_data=ModelInputData(
                input=copy.deepcopy(input_items),
                instructions=system_instructions,
            ),
            agent=agent,
            context=context_wrapper.context,
        )
        result = filter_fn(payload)
        if inspect.isawaitable(result):
            # Support async filters transparently.
            result = await result
        if not isinstance(result, ModelInputData):
            raise UserError("call_model_input_filter must return a ModelInputData instance")
        return result
    except Exception as e:
        # Record the failure on the current tracing span before propagating.
        _error_tracing.attach_error_to_current_span(
            SpanError(message="Error in call_model_input_filter", data={"error": str(e)})
        )
        raise
668+
596669
@classmethod
597670
async def _run_input_guardrails_with_queue(
598671
cls,
@@ -863,10 +936,18 @@ async def _run_single_turn_streamed(
863936
input = ItemHelpers.input_to_new_input_list(streamed_result.input)
864937
input.extend([item.to_input_item() for item in streamed_result.new_items])
865938

939+
filtered = await cls._maybe_filter_model_input(
940+
agent=agent,
941+
run_config=run_config,
942+
context_wrapper=context_wrapper,
943+
input_items=input,
944+
system_instructions=system_prompt,
945+
)
946+
866947
# 1. Stream the output events
867948
async for event in model.stream_response(
868-
system_prompt,
869-
input,
949+
filtered.instructions,
950+
filtered.input,
870951
model_settings,
871952
all_tools,
872953
output_schema,
@@ -1034,7 +1115,6 @@ async def _get_single_step_result_from_streamed_response(
10341115
run_config: RunConfig,
10351116
tool_use_tracker: AgentToolUseTracker,
10361117
) -> SingleStepResult:
1037-
10381118
original_input = streamed_result.input
10391119
pre_step_items = streamed_result.new_items
10401120
event_queue = streamed_result._event_queue
@@ -1161,13 +1241,22 @@ async def _get_new_response(
11611241
previous_response_id: str | None,
11621242
prompt_config: ResponsePromptParam | None,
11631243
) -> ModelResponse:
1244+
# Allow user to modify model input right before the call, if configured
1245+
filtered = await cls._maybe_filter_model_input(
1246+
agent=agent,
1247+
run_config=run_config,
1248+
context_wrapper=context_wrapper,
1249+
input_items=input,
1250+
system_instructions=system_prompt,
1251+
)
1252+
11641253
model = cls._get_model(agent, run_config)
11651254
model_settings = agent.model_settings.resolve(run_config.model_settings)
11661255
model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings)
11671256

11681257
new_response = await model.get_response(
1169-
system_instructions=system_prompt,
1170-
input=input,
1258+
system_instructions=filtered.instructions,
1259+
input=filtered.input,
11711260
model_settings=model_settings,
11721261
tools=all_tools,
11731262
output_schema=output_schema,

tests/test_call_model_input_filter.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
from __future__ import annotations
2+
3+
from typing import Any
4+
5+
import pytest
6+
7+
from agents import Agent, RunConfig, Runner, UserError
8+
from agents.run import CallModelData, ModelInputData
9+
10+
from .fake_model import FakeModel
11+
from .test_responses import get_text_input_item, get_text_message
12+
13+
14+
@pytest.mark.asyncio
async def test_call_model_input_filter_sync_non_streamed() -> None:
    """A synchronous filter can rewrite both instructions and input items."""
    fake_model = FakeModel()
    agent = Agent(name="test", model=fake_model)
    fake_model.set_next_output([get_text_message("ok")])

    def append_item(data: CallModelData[Any]) -> ModelInputData:
        # Extend the pending input and replace the system instructions.
        extended = [*data.model_data.input, get_text_input_item("added-sync")]
        return ModelInputData(input=extended, instructions="filtered-sync")

    run_config = RunConfig(call_model_input_filter=append_item)
    await Runner.run(agent, input="start", run_config=run_config)

    # The model must have received the filtered instructions and input.
    sent_input = fake_model.last_turn_args["input"]
    assert fake_model.last_turn_args["system_instructions"] == "filtered-sync"
    assert isinstance(sent_input, list)
    assert len(sent_input) == 2
    assert sent_input[-1]["content"] == "added-sync"
37+
38+
39+
@pytest.mark.asyncio
async def test_call_model_input_filter_async_streamed() -> None:
    """An async filter is awaited and applied on the streamed-run path."""
    fake_model = FakeModel()
    agent = Agent(name="test", model=fake_model)
    fake_model.set_next_output([get_text_message("ok")])

    async def append_item(data: CallModelData[Any]) -> ModelInputData:
        # Extend the pending input and replace the system instructions.
        extended = [*data.model_data.input, get_text_input_item("added-async")]
        return ModelInputData(input=extended, instructions="filtered-async")

    streamed = Runner.run_streamed(
        agent,
        input="start",
        run_config=RunConfig(call_model_input_filter=append_item),
    )
    # Drain the stream so the turn actually executes.
    async for _ in streamed.stream_events():
        pass

    sent_input = fake_model.last_turn_args["input"]
    assert fake_model.last_turn_args["system_instructions"] == "filtered-async"
    assert isinstance(sent_input, list)
    assert len(sent_input) == 2
    assert sent_input[-1]["content"] == "added-async"
64+
65+
66+
@pytest.mark.asyncio
async def test_call_model_input_filter_invalid_return_type_raises() -> None:
    """A filter returning anything other than ModelInputData raises UserError."""
    fake_model = FakeModel()
    agent = Agent(name="test", model=fake_model)

    def bad_filter(_data: CallModelData[Any]):
        # Deliberately wrong return type.
        return "bad"

    with pytest.raises(UserError):
        await Runner.run(
            agent,
            input="start",
            run_config=RunConfig(call_model_input_filter=bad_filter),
        )
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
from __future__ import annotations
2+
3+
import sys
4+
from pathlib import Path
5+
from typing import Any
6+
7+
import pytest
8+
from openai.types.responses import ResponseOutputMessage, ResponseOutputText
9+
10+
# Make the repository tests helpers importable from this unit test
11+
sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "tests"))
12+
from fake_model import FakeModel # type: ignore
13+
14+
# Import directly from submodules to avoid heavy __init__ side effects
15+
from agents.agent import Agent
16+
from agents.exceptions import UserError
17+
from agents.run import CallModelData, ModelInputData, RunConfig, Runner
18+
19+
20+
@pytest.mark.asyncio
async def test_call_model_input_filter_sync_non_streamed_unit() -> None:
    """A synchronous filter can rewrite instructions and input (unit variant)."""
    fake_model = FakeModel()
    agent = Agent(name="test", model=fake_model)

    reply = ResponseOutputMessage(
        id="1",
        type="message",
        role="assistant",
        content=[ResponseOutputText(text="ok", type="output_text", annotations=[])],
        status="completed",
    )
    fake_model.set_next_output([reply])

    def rewrite_input(data: CallModelData[Any]) -> ModelInputData:
        # Append a raw user item and replace the system instructions.
        extended = [*data.model_data.input, {"content": "added-sync", "role": "user"}]
        return ModelInputData(input=extended, instructions="filtered-sync")

    await Runner.run(
        agent,
        input="start",
        run_config=RunConfig(call_model_input_filter=rewrite_input),
    )

    sent_input = fake_model.last_turn_args["input"]
    assert fake_model.last_turn_args["system_instructions"] == "filtered-sync"
    assert isinstance(sent_input, list)
    assert len(sent_input) == 2
    assert sent_input[-1]["content"] == "added-sync"
54+
55+
56+
@pytest.mark.asyncio
async def test_call_model_input_filter_async_streamed_unit() -> None:
    """An async filter is awaited on the streamed-run path (unit variant)."""
    fake_model = FakeModel()
    agent = Agent(name="test", model=fake_model)

    reply = ResponseOutputMessage(
        id="1",
        type="message",
        role="assistant",
        content=[ResponseOutputText(text="ok", type="output_text", annotations=[])],
        status="completed",
    )
    fake_model.set_next_output([reply])

    async def rewrite_input(data: CallModelData[Any]) -> ModelInputData:
        # Append a raw user item and replace the system instructions.
        extended = [*data.model_data.input, {"content": "added-async", "role": "user"}]
        return ModelInputData(input=extended, instructions="filtered-async")

    streamed = Runner.run_streamed(
        agent,
        input="start",
        run_config=RunConfig(call_model_input_filter=rewrite_input),
    )
    # Drain the stream so the turn actually executes.
    async for _ in streamed.stream_events():
        pass

    sent_input = fake_model.last_turn_args["input"]
    assert fake_model.last_turn_args["system_instructions"] == "filtered-async"
    assert isinstance(sent_input, list)
    assert len(sent_input) == 2
    assert sent_input[-1]["content"] == "added-async"
92+
93+
94+
@pytest.mark.asyncio
async def test_call_model_input_filter_invalid_return_type_raises_unit() -> None:
    """A filter with a wrong return type raises UserError (unit variant)."""
    fake_model = FakeModel()
    agent = Agent(name="test", model=fake_model)

    def bad_filter(_data: CallModelData[Any]):
        # Deliberately wrong return type.
        return "bad"

    with pytest.raises(UserError):
        await Runner.run(
            agent,
            input="start",
            run_config=RunConfig(call_model_input_filter=bad_filter),
        )

0 commit comments

Comments
 (0)