4
4
import copy
5
5
import inspect
6
6
from dataclasses import dataclass , field
7
- from typing import Any , Generic , cast
7
+ from typing import Any , Callable , Generic , cast
8
8
9
9
from openai .types .responses import ResponseCompletedEvent
10
10
from openai .types .responses .response_prompt_param import (
56
56
from .tracing .span_data import AgentSpanData
57
57
from .usage import Usage
58
58
from .util import _coro , _error_tracing
59
+ from .util ._types import MaybeAwaitable
59
60
60
61
DEFAULT_MAX_TURNS = 10
61
62
@@ -81,6 +82,27 @@ def get_default_agent_runner() -> AgentRunner:
81
82
return DEFAULT_AGENT_RUNNER
82
83
83
84
85
+ @dataclass
86
+ class ModelInputData :
87
+ """Container for the data that will be sent to the model."""
88
+
89
+ input : list [TResponseInputItem ]
90
+ instructions : str | None
91
+
92
+
93
+ @dataclass
94
+ class CallModelData (Generic [TContext ]):
95
+ """Data passed to `RunConfig.call_model_input_filter` prior to model call."""
96
+
97
+ model_data : ModelInputData
98
+ agent : Agent [TContext ]
99
+ context : TContext | None
100
+
101
+
102
+ # Type alias for the optional input filter callback
103
+ CallModelInputFilter = Callable [[CallModelData [Any ]], MaybeAwaitable [ModelInputData ]]
104
+
105
+
84
106
@dataclass
85
107
class RunConfig :
86
108
"""Configures settings for the entire agent run."""
@@ -139,6 +161,16 @@ class RunConfig:
139
161
An optional dictionary of additional metadata to include with the trace.
140
162
"""
141
163
164
+ call_model_input_filter : CallModelInputFilter | None = None
165
+ """
166
+ Optional callback that is invoked immediately before calling the model. It receives the current
167
+ agent, context and the model input (instructions and input items), and must return a possibly
168
+ modified `ModelInputData` to use for the model call.
169
+
170
+ This allows you to edit the input sent to the model e.g. to stay within a token limit.
171
+ For example, you can use this to add a system prompt to the input.
172
+ """
173
+
142
174
143
175
class RunOptions (TypedDict , Generic [TContext ]):
144
176
"""Arguments for ``AgentRunner`` methods."""
@@ -593,6 +625,47 @@ def run_streamed(
593
625
)
594
626
return streamed_result
595
627
628
+ @classmethod
629
+ async def _maybe_filter_model_input (
630
+ cls ,
631
+ * ,
632
+ agent : Agent [TContext ],
633
+ run_config : RunConfig ,
634
+ context_wrapper : RunContextWrapper [TContext ],
635
+ input_items : list [TResponseInputItem ],
636
+ system_instructions : str | None ,
637
+ ) -> ModelInputData :
638
+ """Apply optional call_model_input_filter to modify model input.
639
+
640
+ Returns a `ModelInputData` that will be sent to the model.
641
+ """
642
+ effective_instructions = system_instructions
643
+ effective_input : list [TResponseInputItem ] = input_items
644
+
645
+ if run_config .call_model_input_filter is None :
646
+ return ModelInputData (input = effective_input , instructions = effective_instructions )
647
+
648
+ try :
649
+ model_input = ModelInputData (
650
+ input = copy .deepcopy (effective_input ),
651
+ instructions = effective_instructions ,
652
+ )
653
+ filter_payload : CallModelData [TContext ] = CallModelData (
654
+ model_data = model_input ,
655
+ agent = agent ,
656
+ context = context_wrapper .context ,
657
+ )
658
+ maybe_updated = run_config .call_model_input_filter (filter_payload )
659
+ updated = await maybe_updated if inspect .isawaitable (maybe_updated ) else maybe_updated
660
+ if not isinstance (updated , ModelInputData ):
661
+ raise UserError ("call_model_input_filter must return a ModelInputData instance" )
662
+ return updated
663
+ except Exception as e :
664
+ _error_tracing .attach_error_to_current_span (
665
+ SpanError (message = "Error in call_model_input_filter" , data = {"error" : str (e )})
666
+ )
667
+ raise
668
+
596
669
@classmethod
597
670
async def _run_input_guardrails_with_queue (
598
671
cls ,
@@ -863,10 +936,18 @@ async def _run_single_turn_streamed(
863
936
input = ItemHelpers .input_to_new_input_list (streamed_result .input )
864
937
input .extend ([item .to_input_item () for item in streamed_result .new_items ])
865
938
939
+ filtered = await cls ._maybe_filter_model_input (
940
+ agent = agent ,
941
+ run_config = run_config ,
942
+ context_wrapper = context_wrapper ,
943
+ input_items = input ,
944
+ system_instructions = system_prompt ,
945
+ )
946
+
866
947
# 1. Stream the output events
867
948
async for event in model .stream_response (
868
- system_prompt ,
869
- input ,
949
+ filtered . instructions ,
950
+ filtered . input ,
870
951
model_settings ,
871
952
all_tools ,
872
953
output_schema ,
@@ -1034,7 +1115,6 @@ async def _get_single_step_result_from_streamed_response(
1034
1115
run_config : RunConfig ,
1035
1116
tool_use_tracker : AgentToolUseTracker ,
1036
1117
) -> SingleStepResult :
1037
-
1038
1118
original_input = streamed_result .input
1039
1119
pre_step_items = streamed_result .new_items
1040
1120
event_queue = streamed_result ._event_queue
@@ -1161,13 +1241,22 @@ async def _get_new_response(
1161
1241
previous_response_id : str | None ,
1162
1242
prompt_config : ResponsePromptParam | None ,
1163
1243
) -> ModelResponse :
1244
+ # Allow user to modify model input right before the call, if configured
1245
+ filtered = await cls ._maybe_filter_model_input (
1246
+ agent = agent ,
1247
+ run_config = run_config ,
1248
+ context_wrapper = context_wrapper ,
1249
+ input_items = input ,
1250
+ system_instructions = system_prompt ,
1251
+ )
1252
+
1164
1253
model = cls ._get_model (agent , run_config )
1165
1254
model_settings = agent .model_settings .resolve (run_config .model_settings )
1166
1255
model_settings = RunImpl .maybe_reset_tool_choice (agent , tool_use_tracker , model_settings )
1167
1256
1168
1257
new_response = await model .get_response (
1169
- system_instructions = system_prompt ,
1170
- input = input ,
1258
+ system_instructions = filtered . instructions ,
1259
+ input = filtered . input ,
1171
1260
model_settings = model_settings ,
1172
1261
tools = all_tools ,
1173
1262
output_schema = output_schema ,
0 commit comments