Skip to content

Commit f760459

Browse files
committed
added test and bug fix
1 parent 2fc6595 commit f760459

File tree

3 files changed

+65
-7
lines changed

3 files changed

+65
-7
lines changed

sdk/ai/azure-ai-agents/azure/ai/agents/models/_patch.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1894,15 +1894,15 @@ def _start(self, runs_operations: "RunsOperations", run: ThreadRun, polling_inte
18941894
if any(tool_call.type == "function" for tool_call in tool_calls):
18951895
toolset = ToolSet()
18961896
toolset.add(runs_operations._function_tool)
1897-
tool_outputs = toolset._execute_tool_calls(tool_calls, run=run, run_handler=self)
1897+
tool_outputs = toolset._execute_tool_calls(tool_calls, run, self)
18981898

18991899
if _has_errors_in_toolcalls_output(tool_outputs):
19001900
if current_retry >= runs_operations._function_tool_max_retry: # pylint:disable=no-else-return
19011901
logger.warning(
19021902
"Tool outputs contain errors - reaching max retry %s",
19031903
runs_operations._function_tool_max_retry,
19041904
)
1905-
run = runs_operations.cancel(thread_id=run.thread_id, run_id=run.id)
1905+
return runs_operations.cancel(thread_id=run.thread_id, run_id=run.id)
19061906
else:
19071907
logger.warning("Tool outputs contain errors - retrying")
19081908
current_retry += 1
@@ -1927,9 +1927,9 @@ def _start(self, runs_operations: "RunsOperations", run: ThreadRun, polling_inte
19271927
tool_approval = self.submit_mcp_tool_approval(run, tool_call)
19281928
if not tool_approval:
19291929
logger.debug(
1930-
"submit_tool_approval in event handler returned None. Please override this function and return a valid ToolApproval."
1930+
"submit_tool_approval in run handler returned None. Please override this function and return a valid ToolApproval."
19311931
)
1932-
run = runs_operations.cancel(thread_id=run.thread_id, run_id=run.id)
1932+
return runs_operations.cancel(thread_id=run.thread_id, run_id=run.id)
19331933

19341934
tool_approvals.append(tool_approval)
19351935

@@ -1938,6 +1938,7 @@ def _start(self, runs_operations: "RunsOperations", run: ThreadRun, polling_inte
19381938
thread_id=run.thread_id, run_id=run.id, tool_approvals=tool_approvals
19391939
)
19401940

1941+
logger.debug("Current run ID: %s with status: %s", run.id, run.status)
19411942
return run
19421943

19431944
def submit_function_call_output(

sdk/ai/azure-ai-agents/samples/agents_tools/sample_agents_functions_in_create_and_process.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,11 @@
5353

5454
class MyRunHandler(RunHandler):
5555
def submit_function_call_output(
56-
self, run: ThreadRun, tool_call: RequiredFunctionToolCall, tool_call_details: RequiredFunctionToolCallDetails, **kwargs: Any
56+
self,
57+
run: ThreadRun,
58+
tool_call: RequiredFunctionToolCall,
59+
tool_call_details: RequiredFunctionToolCallDetails,
60+
**kwargs: Any,
5761
) -> Optional[Any]:
5862
print(f"Call function: {tool_call_details.name}")
5963
return functions.execute(tool_call)

sdk/ai/azure-ai-agents/tests/test_agents_mock.py

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# Copyright (c) Microsoft Corporation.
33
# Licensed under the MIT License.
44
# ------------------------------------
5-
from typing import Any, Iterator, List, MutableMapping, Optional, Dict
5+
from typing import Any, Callable, Iterator, List, MutableMapping, Optional, Dict
66

77
import json
88
import os
@@ -24,6 +24,7 @@
2424
ToolOutput,
2525
AgentEventHandler,
2626
ThreadRun,
27+
RunHandler,
2728
)
2829

2930
from user_functions import user_functions
@@ -88,7 +89,7 @@ def get_mock_client(self) -> AgentsClient:
8889
client.runs.submit_tool_outputs = MagicMock()
8990
return client
9091

91-
def get_toolset(self, file_id: Optional[str], function: Optional[str]) -> Optional[ToolSet]:
92+
def get_toolset(self, file_id: Optional[str], function: Optional[Callable[..., Any]]) -> Optional[ToolSet]:
9293
"""Get the tool set with given file id and function"""
9394
if file_id is None or function is None:
9495
return None
@@ -507,6 +508,58 @@ def test_create_run_tools_override(
507508
else:
508509
self._assert_tool_call(agents_client.runs.submit_tool_outputs, "run123", toolset1)
509510

511+
@patch("azure.ai.agents._client.PipelineClient")
512+
def test_create_and_process_with_manual_function_calls(
513+
self,
514+
mock_pipeline_client_gen: MagicMock,
515+
) -> None:
516+
"""Test that if user have set tool set in create create_and_process_run method, that tools are used."""
517+
toolset1 = self.get_toolset("file_for_agent_1", function1)
518+
mock_response = MagicMock()
519+
mock_response.status_code = 200
520+
side_effect = [self._get_agent_json("first", "123", toolset1)]
521+
side_effect.append(self._get_run("run123", toolset1)) # create_run
522+
side_effect.append(self._get_run("run123", toolset1)) # get_run
523+
side_effect.append(
524+
self._get_run("run123", toolset1, is_complete=True)
525+
) # get_run after resubmitting with tool results
526+
527+
class MyRunHandler(RunHandler):
528+
def submit_function_call_output(
529+
self,
530+
run: ThreadRun,
531+
tool_call: RequiredFunctionToolCall,
532+
tool_call_details: RequiredFunctionToolCallDetails,
533+
**kwargs: Any,
534+
) -> Optional[Any]:
535+
if tool_call_details.name == function1.__name__:
536+
return function1()
537+
538+
mock_response.json.side_effect = side_effect
539+
mock_pipeline_response = MagicMock()
540+
mock_pipeline_response.http_response = mock_response
541+
mock_pipeline = MagicMock()
542+
mock_pipeline._pipeline.run.return_value = mock_pipeline_response
543+
mock_pipeline_client_gen.return_value = mock_pipeline
544+
agents_client = self.get_mock_client()
545+
with agents_client:
546+
# Check that pipelines are created as expected.
547+
self._set_toolcalls(agents_client, toolset1, None)
548+
run_handler = MyRunHandler()
549+
agent1 = agents_client.create_agent(
550+
model="gpt-4-1106-preview",
551+
name="first",
552+
instructions="You are a helpful agent",
553+
toolset=toolset1,
554+
)
555+
self._assert_pipeline_and_reset(mock_pipeline._pipeline.run, tool_set=toolset1)
556+
557+
# Create run with new tool set, which also can be none.
558+
agents_client.runs.create_and_process(
559+
thread_id="some_thread_id", agent_id=agent1.id, polling_interval=0, run_handler=run_handler
560+
)
561+
self._assert_tool_call(agents_client.runs.submit_tool_outputs, "run123", toolset1)
562+
510563
@patch("azure.ai.agents._client.PipelineClient")
511564
@pytest.mark.parametrize(
512565
"file_agent_1,add_azure_fn",

0 commit comments

Comments
 (0)