Skip to content

Commit 13358f1

Browse files
committed
resolved comments from Jarno
1 parent 009e773 commit 13358f1

File tree

4 files changed

+95
-42
lines changed

4 files changed

+95
-42
lines changed

sdk/ai/azure-ai-agents/azure/ai/agents/models/_patch.py

Lines changed: 78 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1704,22 +1704,15 @@ def execute_tool_calls(self, tool_calls: List[Any]) -> Any:
17041704
"""
17051705
return self._execute_tool_calls(tool_calls)
17061706

1707-
def _execute_tool_calls(self, tool_calls: List[Any], run: Optional[ThreadRun] = None, required_action_handler: Optional['CreateAndProcessRequiredActionHandler'] = None) -> Any:
1708-
"""
1709-
Execute a tool of the specified type with the provided tool calls.
1710-
1711-
:param List[Any] tool_calls: A list of tool calls to execute.
1712-
:return: The output of the tool operations.
1713-
:rtype: Any
1714-
"""
1707+
def _execute_tool_calls(self, tool_calls: List[Any], run: Optional[ThreadRun] = None, run_handler: Optional['RunHandler'] = None) -> Any:
17151708
tool_outputs = []
17161709

17171710
for tool_call in tool_calls:
17181711
if tool_call.type == "function":
17191712
output: Optional[Any] = None
17201713

1721-
if required_action_handler and run:
1722-
output = required_action_handler.submit_function_call_output(run, tool_call, tool_call.function)
1714+
if run_handler and run:
1715+
output = run_handler.submit_function_call_output(run, tool_call, tool_call.function)
17231716
try:
17241717
if not output:
17251718
tool = self.get_tool(FunctionTool)
@@ -1782,7 +1775,7 @@ async def execute_tool_calls(self, tool_calls: List[Any]) -> Any:
17821775
EventFunctionReturnT = TypeVar("EventFunctionReturnT")
17831776
T = TypeVar("T")
17841777
BaseAsyncAgentEventHandlerT = TypeVar("BaseAsyncAgentEventHandlerT", bound="BaseAsyncAgentEventHandler")
1785-
BaseAgentEventHandlerT = TypeVar("BaseAgentEventHandlerT", bound="BaseAgentEventHandler")
1778+
# BaseAgentEventHandlerT is defined after BaseAgentEventHandler class to avoid forward reference during parsing.
17861779

17871780
async def async_chain(*iterators: AsyncIterator[T]) -> AsyncIterator[T]:
17881781
for iterator in iterators:
@@ -1856,10 +1849,26 @@ async def until_done(self) -> None:
18561849
pass
18571850

18581851

1859-
class CreateAndProcessRequiredActionHandler:
1852+
class RunHandler:
1853+
"""Helper that drives a run to completion for the "create and process" pattern.
1854+
1855+
Extension Points:
1856+
* ``submit_function_call_output`` -- override to customize how function tool results are produced.
1857+
* ``submit_mcp_tool_approval`` -- override to implement an approval workflow (UI prompt, policy, etc.).
1858+
"""
1859+
1860+
def _start(self, runs_operations: "RunsOperations", run: ThreadRun, polling_interval: int) -> ThreadRun:
1861+
"""Poll and process a run until it reaches a terminal state or is cancelled.
18601862
1861-
def _start(self, runs_operations: "RunsOperations", run: ThreadRun, polling_interval: int) -> ThreadRun:
1862-
# Monitor and process the run status
1863+
:param runs_operations: Operations client used to retrieve, cancel, and submit tool outputs/approvals.
1864+
:type runs_operations: RunsOperations
1865+
:param run: The initial run returned from create/process call.
1866+
:type run: ThreadRun
1867+
:param polling_interval: Delay (in seconds) between polling attempts.
1868+
:type polling_interval: int
1869+
:return: The final terminal ``ThreadRun`` object (completed, failed, cancelled, or expired).
1870+
:rtype: ThreadRun
1871+
"""
18631872
current_retry = 0
18641873
while run.status in [
18651874
RunStatus.QUEUED,
@@ -1882,7 +1891,7 @@ def _start(self, runs_operations: "RunsOperations", run: ThreadRun, polling_int
18821891
if any(tool_call.type == "function" for tool_call in tool_calls):
18831892
toolset = ToolSet()
18841893
toolset.add(runs_operations._function_tool)
1885-
tool_outputs = toolset._execute_tool_calls(tool_calls, run=run, required_action_handler=self)
1894+
tool_outputs = toolset._execute_tool_calls(tool_calls, run=run, run_handler=self)
18861895

18871896
if _has_errors_in_toolcalls_output(tool_outputs):
18881897
if current_retry >= runs_operations._function_tool_max_retry: # pylint:disable=no-else-return
@@ -1896,7 +1905,9 @@ def _start(self, runs_operations: "RunsOperations", run: ThreadRun, polling_int
18961905

18971906
logger.debug("Tool outputs: %s", tool_outputs)
18981907
if tool_outputs:
1899-
run2 = runs_operations.submit_tool_outputs(thread_id=run.thread_id, run_id=run.id, tool_outputs=tool_outputs)
1908+
run2 = runs_operations.submit_tool_outputs(
1909+
thread_id=run.thread_id, run_id=run.id, tool_outputs=tool_outputs
1910+
)
19001911
logger.debug("Tool outputs submitted to run: %s", run2.id)
19011912
elif isinstance(run.required_action, SubmitToolApprovalAction):
19021913
tool_calls = run.required_action.submit_tool_approval.tool_calls
@@ -1909,9 +1920,11 @@ def _start(self, runs_operations: "RunsOperations", run: ThreadRun, polling_int
19091920
for tool_call in tool_calls:
19101921
if isinstance(tool_call, RequiredMcpToolCall):
19111922
logger.info(f"Approving tool call: {tool_call}")
1912-
tool_approval = self.submit_tool_approval(run, tool_call)
1923+
tool_approval = self.submit_mcp_tool_approval(run, tool_call)
19131924
if not tool_approval:
1914-
logger.debug("submit_tool_approval in event handler returned None. Please override this function and return a valid ToolApproval.")
1925+
logger.debug(
1926+
"submit_tool_approval in event handler returned None. Please override this function and return a valid ToolApproval."
1927+
)
19151928
run = runs_operations.cancel(thread_id=run.thread_id, run_id=run.id)
19161929

19171930
tool_approvals.append(tool_approval)
@@ -1923,11 +1936,51 @@ def _start(self, runs_operations: "RunsOperations", run: ThreadRun, polling_int
19231936

19241937
return run
19251938

1926-
def submit_function_call_output(self, run: ThreadRun, tool_call: RequiredFunctionToolCall, tool_call_details: RequiredFunctionToolCallDetails) -> Optional[str]:
1939+
def submit_function_call_output(
1940+
self,
1941+
run: ThreadRun,
1942+
tool_call: RequiredFunctionToolCall,
1943+
tool_call_details: RequiredFunctionToolCallDetails,
1944+
**kwargs,
1945+
) -> Optional[Any]:
1946+
"""Produce (or override) the output for a required function tool call.
1947+
1948+
Override this to inject custom execution logic, caching, validation, or transformation.
1949+
Return ``None`` to fall back to the default execution path handled in ``_start``.
1950+
1951+
:param run: Current run requiring the function output.
1952+
:type run: ThreadRun
1953+
:param tool_call: The tool call metadata referencing the function tool.
1954+
:type tool_call: RequiredFunctionToolCall
1955+
:param tool_call_details: Function arguments/details object.
1956+
:type tool_call_details: RequiredFunctionToolCallDetails
1957+
:keyword kwargs: Additional keyword arguments for extensibility.
1958+
:return: Stringified result to send back to the service, or ``None`` to delegate to auto function calling.
1959+
:rtype: Optional[Any]
1960+
"""
19271961
return None
19281962

1929-
def submit_tool_approval(self, run: ThreadRun, tool_call: RequiredMcpToolCall) -> Optional[ToolApproval]:
1930-
return None
1963+
def submit_mcp_tool_approval(
1964+
self,
1965+
run: ThreadRun,
1966+
tool_call: RequiredMcpToolCall,
1967+
**kwargs,
1968+
) -> Optional[ToolApproval]:
1969+
# NOTE: Implementation intentionally returns None; override in subclasses for real approval logic.
1970+
"""Return a ``ToolApproval`` for an MCP tool call or ``None`` to indicate rejection/cancellation.
1971+
1972+
Override this to implement approval policies (interactive prompt, RBAC, heuristic checks, etc.).
1973+
Returning ``None`` triggers cancellation logic in ``_start``.
1974+
1975+
:param run: Current run containing the MCP approval request.
1976+
:type run: ThreadRun
1977+
:param tool_call: The MCP tool call requiring approval.
1978+
:type tool_call: RequiredMcpToolCall
1979+
:keyword kwargs: Additional keyword arguments for extensibility.
1980+
:return: A populated ``ToolApproval`` instance on approval, or ``None`` to decline.
1981+
:rtype: Optional[ToolApproval]
1982+
"""
1983+
return None
19311984

19321985

19331986

@@ -1991,6 +2044,9 @@ def until_done(self) -> None:
19912044
except StopIteration:
19922045
pass
19932046

2047+
# Now that BaseAgentEventHandler is defined, we can bind the TypeVar.
2048+
BaseAgentEventHandlerT = TypeVar("BaseAgentEventHandlerT", bound="BaseAgentEventHandler")
2049+
19942050

19952051
class AsyncAgentEventHandler(BaseAsyncAgentEventHandler[Tuple[str, StreamEventData, Optional[EventFunctionReturnT]]]):
19962052
def __init__(self) -> None:
@@ -2350,7 +2406,7 @@ def _is_valid_connection_id(connection_id: str) -> bool:
23502406
"MessageTextFileCitationAnnotation",
23512407
"MessageDeltaChunk",
23522408
"MessageAttachment",
2353-
"CreateAndProcessRequiredActionHandler",
2409+
"RunHandler",
23542410
] # Add all objects you want publicly available to users at this package level
23552411

23562412

sdk/ai/azure-ai-agents/azure/ai/agents/operations/_patch.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,7 @@ def create_and_process(
441441
response_format: Optional["_types.AgentsResponseFormatOption"] = None,
442442
parallel_tool_calls: Optional[bool] = None,
443443
metadata: Optional[Dict[str, str]] = None,
444-
required_action_handler: Optional[_models.CreateAndProcessRequiredActionHandler] = None,
444+
run_handler: Optional[_models.RunHandler] = None,
445445
polling_interval: int = 1,
446446
**kwargs: Any,
447447
) -> _models.ThreadRun:
@@ -553,10 +553,10 @@ def create_and_process(
553553
)
554554

555555
# Monitor and process the run status
556-
if not required_action_handler:
557-
required_action_handler = _models.CreateAndProcessRequiredActionHandler(self)
556+
if not run_handler:
557+
run_handler = _models.RunHandler()
558558

559-
return required_action_handler._start(self, run, polling_interval)
559+
return run_handler._start(self, run, polling_interval)
560560

561561
@overload
562562
def stream(

sdk/ai/azure-ai-agents/samples/agents_tools/sample_agents_functions_in_create_and_process.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,14 @@
2222
the "Models + endpoints" tab in your Azure AI Foundry project.
2323
"""
2424
import os, time, sys
25-
from typing import Optional
25+
from typing import Any, Optional
2626
from azure.ai.projects import AIProjectClient
2727
from azure.identity import DefaultAzureCredential
2828
from azure.ai.agents.models import (
2929
FunctionTool,
3030
ListSortOrder,
3131
RequiredFunctionToolCall,
32-
SubmitToolOutputsAction,
33-
ToolOutput,
34-
CreateAndProcessRequiredActionHandler,
32+
RunHandler,
3533
ThreadRun,
3634
RequiredFunctionToolCall,
3735
RequiredFunctionToolCallDetails,
@@ -51,8 +49,9 @@
5149
# Initialize function tool with user functions
5250
functions = FunctionTool(functions=user_functions)
5351

54-
class MyCreateAndProcessRequiredActionHandler(CreateAndProcessRequiredActionHandler):
55-
def submit_function_call_output(self, run: ThreadRun, tool_call: RequiredFunctionToolCall, tool_call_details: RequiredFunctionToolCallDetails) -> Optional[str]:
52+
class MyRunHandler(RunHandler):
53+
def submit_function_call_output(self, run: ThreadRun, tool_call: RequiredFunctionToolCall, tool_call_details: RequiredFunctionToolCallDetails) -> Optional[Any]:
54+
print(f"Call function: {tool_call_details.name}")
5655
return functions.execute(tool_call)
5756

5857
with project_client:
@@ -77,8 +76,8 @@ def submit_function_call_output(self, run: ThreadRun, tool_call: RequiredFunctio
7776
)
7877
print(f"Created message, ID: {message.id}")
7978

80-
required_action_handler = MyCreateAndProcessRequiredActionHandler()
81-
run = agents_client.runs.create_and_process(thread_id=thread.id, agent_id=agent.id, required_action_handler=required_action_handler)
79+
run_handler = MyRunHandler()
80+
run = agents_client.runs.create_and_process(thread_id=thread.id, agent_id=agent.id, run_handler=run_handler)
8281

8382
print(f"Run completed with status: {run.status}")
8483

sdk/ai/azure-ai-agents/samples/agents_tools/sample_agents_mcp_in_create_and_process.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
4) MCP_SERVER_LABEL - A label for your MCP server.
2727
"""
2828

29-
import os, time
29+
import os
3030
from typing import Optional
3131
from azure.ai.projects import AIProjectClient
3232
from azure.identity import DefaultAzureCredential
@@ -35,9 +35,8 @@
3535
McpTool,
3636
RequiredMcpToolCall,
3737
RunStepActivityDetails,
38-
SubmitToolApprovalAction,
3938
ToolApproval,
40-
CreateAndProcessRequiredActionHandler,
39+
RunHandler,
4140
ThreadRun,
4241
ToolSet,
4342
)
@@ -66,15 +65,14 @@
6665
mcp_tool.allow_tool(search_api_code)
6766
print(f"Allowed tools: {mcp_tool.allowed_tools}")
6867

69-
class MyCreateAndProcessRequiredActionHandler(CreateAndProcessRequiredActionHandler):
70-
def submit_tool_approval(self, run: ThreadRun, tool_call: RequiredMcpToolCall) -> Optional[ToolApproval]:
68+
class MyRunHandler(RunHandler):
69+
def submit_mcp_tool_approval(self, run: ThreadRun, tool_call: RequiredMcpToolCall, **kwargs) -> Optional[ToolApproval]:
7170
return ToolApproval(
7271
tool_call_id=tool_call.id,
7372
approve=True,
7473
headers=mcp_tool.headers,
7574
)
7675

77-
7876
# Create agent with MCP tool and process agent run
7977
with project_client:
8078
agents_client = project_client.agents
@@ -108,9 +106,9 @@ def submit_tool_approval(self, run: ThreadRun, tool_call: RequiredMcpToolCall) -
108106
# Create and process agent run in thread with MCP tools
109107
mcp_tool.update_headers("SuperSecret", "123456")
110108

111-
required_action_handler = MyCreateAndProcessRequiredActionHandler()
109+
run_handler = MyRunHandler()
112110
# mcp_tool.set_approval_mode("never") # Uncomment to disable approval requirement
113-
run = agents_client.runs.create_and_process(thread_id=thread.id, agent_id=agent.id, required_action_handler=required_action_handler)
111+
run = agents_client.runs.create_and_process(thread_id=thread.id, agent_id=agent.id, run_handler=run_handler)
114112
print(f"Created run, ID: {run.id}")
115113

116114
print(f"Run completed with status: {run.status}")

0 commit comments

Comments
 (0)