@@ -1704,22 +1704,15 @@ def execute_tool_calls(self, tool_calls: List[Any]) -> Any:
17041704 """
17051705 return self ._execute_tool_calls (tool_calls )
17061706
1707- def _execute_tool_calls (self , tool_calls : List [Any ], run : Optional [ThreadRun ] = None , required_action_handler : Optional ['CreateAndProcessRequiredActionHandler' ] = None ) -> Any :
1708- """
1709- Execute a tool of the specified type with the provided tool calls.
1710-
1711- :param List[Any] tool_calls: A list of tool calls to execute.
1712- :return: The output of the tool operations.
1713- :rtype: Any
1714- """
1707+ def _execute_tool_calls (self , tool_calls : List [Any ], run : Optional [ThreadRun ] = None , run_handler : Optional ['RunHandler' ] = None ) -> Any :
17151708 tool_outputs = []
17161709
17171710 for tool_call in tool_calls :
17181711 if tool_call .type == "function" :
17191712 output : Optional [Any ] = None
17201713
1721- if required_action_handler and run :
1722- output = required_action_handler .submit_function_call_output (run , tool_call , tool_call .function )
1714+ if run_handler and run :
1715+ output = run_handler .submit_function_call_output (run , tool_call , tool_call .function )
17231716 try :
17241717 if not output :
17251718 tool = self .get_tool (FunctionTool )
@@ -1782,7 +1775,7 @@ async def execute_tool_calls(self, tool_calls: List[Any]) -> Any:
17821775EventFunctionReturnT = TypeVar ("EventFunctionReturnT" )
17831776T = TypeVar ("T" )
17841777BaseAsyncAgentEventHandlerT = TypeVar ("BaseAsyncAgentEventHandlerT" , bound = "BaseAsyncAgentEventHandler" )
1785- BaseAgentEventHandlerT = TypeVar ( "BaseAgentEventHandlerT" , bound = " BaseAgentEventHandler" )
1778+ # BaseAgentEventHandlerT is defined after BaseAgentEventHandler class to avoid forward reference during parsing.
17861779
17871780async def async_chain (* iterators : AsyncIterator [T ]) -> AsyncIterator [T ]:
17881781 for iterator in iterators :
@@ -1856,10 +1849,26 @@ async def until_done(self) -> None:
18561849 pass
18571850
18581851
1859- class CreateAndProcessRequiredActionHandler :
1852+ class RunHandler :
1853+ """Helper that drives a run to completion for the "create and process" pattern.
1854+
1855+ Extension Points:
1856+ * ``submit_function_call_output`` -- override to customize how function tool results are produced.
1857+ * ``submit_mcp_tool_approval`` -- override to implement an approval workflow (UI prompt, policy, etc.).
1858+ """
1859+
1860+ def _start (self , runs_operations : "RunsOperations" , run : ThreadRun , polling_interval : int ) -> ThreadRun :
1861+ """Poll and process a run until it reaches a terminal state or is cancelled.
18601862
1861- def _start (self , runs_operations : "RunsOperations" , run : ThreadRun , polling_interval : int ) -> ThreadRun :
1862- # Monitor and process the run status
1863+ :param runs_operations: Operations client used to retrieve, cancel, and submit tool outputs/approvals.
1864+ :type runs_operations: RunsOperations
1865+ :param run: The initial run returned from create/process call.
1866+ :type run: ThreadRun
1867+ :param polling_interval: Delay (in seconds) between polling attempts.
1868+ :type polling_interval: int
1869+ :return: The final terminal ``ThreadRun`` object (completed, failed, cancelled, or expired).
1870+ :rtype: ThreadRun
1871+ """
18631872 current_retry = 0
18641873 while run .status in [
18651874 RunStatus .QUEUED ,
@@ -1882,7 +1891,7 @@ def _start(self, runs_operations: "RunsOperations", run: ThreadRun, polling_int
18821891 if any (tool_call .type == "function" for tool_call in tool_calls ):
18831892 toolset = ToolSet ()
18841893 toolset .add (runs_operations ._function_tool )
1885- tool_outputs = toolset ._execute_tool_calls (tool_calls , run = run , required_action_handler = self )
1894+ tool_outputs = toolset ._execute_tool_calls (tool_calls , run = run , run_handler = self )
18861895
18871896 if _has_errors_in_toolcalls_output (tool_outputs ):
18881897 if current_retry >= runs_operations ._function_tool_max_retry : # pylint:disable=no-else-return
@@ -1896,7 +1905,9 @@ def _start(self, runs_operations: "RunsOperations", run: ThreadRun, polling_int
18961905
18971906 logger .debug ("Tool outputs: %s" , tool_outputs )
18981907 if tool_outputs :
1899- run2 = runs_operations .submit_tool_outputs (thread_id = run .thread_id , run_id = run .id , tool_outputs = tool_outputs )
1908+ run2 = runs_operations .submit_tool_outputs (
1909+ thread_id = run .thread_id , run_id = run .id , tool_outputs = tool_outputs
1910+ )
19001911 logger .debug ("Tool outputs submitted to run: %s" , run2 .id )
19011912 elif isinstance (run .required_action , SubmitToolApprovalAction ):
19021913 tool_calls = run .required_action .submit_tool_approval .tool_calls
@@ -1909,9 +1920,11 @@ def _start(self, runs_operations: "RunsOperations", run: ThreadRun, polling_int
19091920 for tool_call in tool_calls :
19101921 if isinstance (tool_call , RequiredMcpToolCall ):
19111922 logger .info (f"Approving tool call: { tool_call } " )
1912- tool_approval = self .submit_tool_approval (run , tool_call )
1923+ tool_approval = self .submit_mcp_tool_approval (run , tool_call )
19131924 if not tool_approval :
1914- logger .debug ("submit_tool_approval in event handler returned None. Please override this function and return a valid ToolApproval." )
1925+ logger .debug (
1926+ "submit_tool_approval in event handler returned None. Please override this function and return a valid ToolApproval."
1927+ )
19151928 run = runs_operations .cancel (thread_id = run .thread_id , run_id = run .id )
19161929
19171930 tool_approvals .append (tool_approval )
@@ -1923,11 +1936,51 @@ def _start(self, runs_operations: "RunsOperations", run: ThreadRun, polling_int
19231936
19241937 return run
19251938
1926- def submit_function_call_output (self , run : ThreadRun , tool_call : RequiredFunctionToolCall , tool_call_details : RequiredFunctionToolCallDetails ) -> Optional [str ]:
1939+ def submit_function_call_output (
1940+ self ,
1941+ run : ThreadRun ,
1942+ tool_call : RequiredFunctionToolCall ,
1943+ tool_call_details : RequiredFunctionToolCallDetails ,
1944+ ** kwargs ,
1945+ ) -> Optional [Any ]:
1946+ """Produce (or override) the output for a required function tool call.
1947+
1948+ Override this to inject custom execution logic, caching, validation, or transformation.
1949+ Return ``None`` to fall back to the default execution path handled in ``_start``.
1950+
1951+ :param run: Current run requiring the function output.
1952+ :type run: ThreadRun
1953+ :param tool_call: The tool call metadata referencing the function tool.
1954+ :type tool_call: RequiredFunctionToolCall
1955+ :param tool_call_details: Function arguments/details object.
1956+ :type tool_call_details: RequiredFunctionToolCallDetails
1957+ :keyword kwargs: Additional keyword arguments for extensibility.
1958+ :return: Stringified result to send back to the service, or ``None`` to delegate to auto function calling.
1959+ :rtype: Optional[Any]
1960+ """
19271961 return None
19281962
1929- def submit_tool_approval (self , run : ThreadRun , tool_call : RequiredMcpToolCall ) -> Optional [ToolApproval ]:
1930- return None
1963+ def submit_mcp_tool_approval (
1964+ self ,
1965+ run : ThreadRun ,
1966+ tool_call : RequiredMcpToolCall ,
1967+ ** kwargs ,
1968+ ) -> Optional [ToolApproval ]:
1969+ # NOTE: Implementation intentionally returns None; override in subclasses for real approval logic.
1970+ """Return a ``ToolApproval`` for an MCP tool call or ``None`` to indicate rejection/cancellation.
1971+
1972+ Override this to implement approval policies (interactive prompt, RBAC, heuristic checks, etc.).
1973+ Returning ``None`` triggers cancellation logic in ``_start``.
1974+
1975+ :param run: Current run containing the MCP approval request.
1976+ :type run: ThreadRun
1977+ :param tool_call: The MCP tool call requiring approval.
1978+ :type tool_call: RequiredMcpToolCall
1979+ :keyword kwargs: Additional keyword arguments for extensibility.
1980+ :return: A populated ``ToolApproval`` instance on approval, or ``None`` to decline.
1981+ :rtype: Optional[ToolApproval]
1982+ """
1983+ return None
19311984
19321985
19331986
@@ -1991,6 +2044,9 @@ def until_done(self) -> None:
19912044 except StopIteration :
19922045 pass
19932046
2047+ # Now that BaseAgentEventHandler is defined, we can bind the TypeVar.
2048+ BaseAgentEventHandlerT = TypeVar ("BaseAgentEventHandlerT" , bound = "BaseAgentEventHandler" )
2049+
19942050
19952051class AsyncAgentEventHandler (BaseAsyncAgentEventHandler [Tuple [str , StreamEventData , Optional [EventFunctionReturnT ]]]):
19962052 def __init__ (self ) -> None :
@@ -2350,7 +2406,7 @@ def _is_valid_connection_id(connection_id: str) -> bool:
23502406 "MessageTextFileCitationAnnotation" ,
23512407 "MessageDeltaChunk" ,
23522408 "MessageAttachment" ,
2353- "CreateAndProcessRequiredActionHandler " ,
2409+ "RunHandler " ,
23542410] # Add all objects you want publicly available to users at this package level
23552411
23562412
0 commit comments